graphformer
Implements the full model as described in [M-OFDFT].
- class Graphformer(edge_mlp: MLP, energy_mlp: MLP | MLPStack, gbf_module: GBFModule, node_embedding_module: NodeEmbedding, gnn_module: G3DStack, atom_ref_module: AtomRef, initial_guess_module: InitialGuessDeltaModule, dimension_wise_rescaling_module: DimensionWiseRescaling, final_energy_factor: float = 1.0, final_norm_layer: Module = None)
The Graphformer module as described in [M-OFDFT].
- __init__(edge_mlp: MLP, energy_mlp: MLP | MLPStack, gbf_module: GBFModule, node_embedding_module: NodeEmbedding, gnn_module: G3DStack, atom_ref_module: AtomRef, initial_guess_module: InitialGuessDeltaModule, dimension_wise_rescaling_module: DimensionWiseRescaling, final_energy_factor: float = 1.0, final_norm_layer: Module = None) → None
Initializes the Graphformer class.
- Parameters:
edge_mlp (MLP) – The MLP predicting edge attributes as input for the G3D layer.
energy_mlp (MLP | MLPStack) – The MLP predicting the energy per atom.
gbf_module (GBFModule) – The GBF module.
node_embedding_module (NodeEmbedding) – The node embedding module.
gnn_module (G3DStack) – The stack of G3DLayers, which make up the main part of the module. Can be replaced by any graph NN module with the same signature.
atom_ref_module (AtomRef) – The atomic reference module.
initial_guess_module (InitialGuessDeltaModule) – The module mapping the final node_features to the initial guess differences.
dimension_wise_rescaling_module (DimensionWiseRescaling) – The dimension-wise rescaling module.
- forward(batch: OFData | OFBatch) → Tuple[Tensor, Tensor]
Calculates the forward pass of the module according to Figure 6 in [M-OFDFT].
- Parameters:
batch (OFData | OFBatch) – The batch or data object containing the input data.
- Returns:
The energy and the initial guess delta, in this order.
- Return type:
Tuple[Tensor, Tensor]
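The Returns entry only states that an energy is produced, and energy_mlp is documented as predicting the energy per atom. The pure-Python sketch below illustrates one common way such per-atom outputs can be turned into one energy per molecule in a batched graph. Everything about the wiring is an assumption, not the documented behavior of Graphformer.forward: the hypothetical `batch` list maps each atom to its molecule index (the usual graph-batching convention), the hypothetical `atom_ref` dict stands in for atom_ref_module as a fixed per-element offset, and final_energy_factor is assumed to rescale the summed energy.

```python
from typing import Dict, List, Sequence


def molecule_energies(
    atom_energies: Sequence[float],
    atomic_numbers: Sequence[int],
    batch: Sequence[int],
    atom_ref: Dict[int, float],
    final_energy_factor: float = 1.0,
) -> List[float]:
    """Sum per-atom energies into one energy per molecule (illustrative only).

    Assumptions: batch[i] is the molecule index of atom i, atom_ref is a
    hypothetical per-element offset table standing in for atom_ref_module,
    and final_energy_factor scales the summed molecular energy.
    """
    if not (len(atom_energies) == len(atomic_numbers) == len(batch)):
        raise ValueError("per-atom inputs must all have the same length")
    n_molecules = max(batch) + 1 if batch else 0
    totals = [0.0] * n_molecules
    for energy, z, mol in zip(atom_energies, atomic_numbers, batch):
        # Per-atom prediction plus the (assumed) per-element reference offset.
        totals[mol] += energy + atom_ref[z]
    return [final_energy_factor * e for e in totals]


# Two molecules in one batch: atoms 0-1 belong to molecule 0, atoms 2-4 to molecule 1.
energies = molecule_energies(
    atom_energies=[0.5, 0.25, -1.0, 0.25, 0.25],
    atomic_numbers=[1, 1, 8, 1, 1],
    batch=[0, 0, 1, 1, 1],
    atom_ref={1: -0.5, 8: -75.0},
    final_energy_factor=2.0,
)
```

Summing per atom before scaling keeps the energy size-extensive: adding an atom to a molecule adds one term to its sum, independent of how many other molecules share the batch.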
- get_distance_embeddings(distances: Tensor) → Tuple[Tensor, Tensor]
Get the distance embeddings for the given distances.
- Parameters:
distances (Tensor) – The distances for which to calculate the embeddings.
- Returns:
gbf_embedding, g3d_edge_attr – The distance embeddings, with shapes (n_distances, n_embedding_dims) and (n_distances, 1), respectively.
- Return type:
Tuple[Tensor, Tensor]
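To make the two documented output shapes concrete, here is a minimal pure-Python sketch of a Gaussian basis function (GBF) expansion of distances followed by a projection to one edge attribute per distance. It is not the GBFModule or edge_mlp implementation: it assumes fixed, evenly spaced Gaussian centers on [0, max_distance] with a shared width, whereas a real GBF module may learn its centers and widths, and the hypothetical `edge_weights` list replaces edge_mlp with a single linear map.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def gaussian_basis_embedding(
    distances: List[float], n_embedding_dims: int = 8, max_distance: float = 10.0
) -> Matrix:
    """Expand each distance into n_embedding_dims normalized Gaussian features.

    Assumption: fixed centers evenly spaced on [0, max_distance], all sharing a
    width equal to the center spacing.
    """
    if n_embedding_dims < 2:
        raise ValueError("need at least two basis functions to space the centers")
    spacing = max_distance / (n_embedding_dims - 1)
    centers = [k * spacing for k in range(n_embedding_dims)]
    width = spacing
    norm = 1.0 / (math.sqrt(2.0 * math.pi) * width)
    return [
        [norm * math.exp(-0.5 * ((d - mu) / width) ** 2) for mu in centers]
        for d in distances
    ]


def toy_distance_embeddings(
    distances: List[float], edge_weights: List[float]
) -> Tuple[Matrix, Matrix]:
    """Return (gbf_embedding, g3d_edge_attr) with shapes (n, dims) and (n, 1).

    edge_weights is a hypothetical stand-in for edge_mlp: one linear map from
    the GBF embedding to a single scalar edge attribute per distance.
    """
    gbf_embedding = gaussian_basis_embedding(distances, len(edge_weights))
    g3d_edge_attr = [
        [sum(w * f for w, f in zip(edge_weights, row))] for row in gbf_embedding
    ]
    return gbf_embedding, g3d_edge_attr


gbf, edge_attr = toy_distance_embeddings([0.0, 1.5, 3.0], edge_weights=[0.1] * 8)
```

With three distances and eight basis functions, gbf has shape (3, 8) and edge_attr has shape (3, 1), matching the (n_distances, n_embedding_dims) and (n_distances, 1) layout documented above.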
- plot_distance_embeddings(max_distance: float = 10.0, n_distances: int = 1000) → Figure
Plot the distance embeddings for a range of distances.
- Parameters:
max_distance (float) – The maximum distance to consider.
n_distances (int) – The number of distances to consider.
- Returns:
The plot.
- Return type:
plt.Figure
- class MLPStack(mlp_class: Sequential, n_mlps: int, **mlp_kwargs)
- __init__(mlp_class: Sequential, n_mlps: int, **mlp_kwargs) → None
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor, batch: Tensor) → Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.