# Source code for mldft.ml.models.components.graphformer

"""Implements the full model as described in [M-OFDFT]_."""

from typing import Tuple, Union

import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from pyscf.data.elements import ELEMENTS
from torch import Tensor
from torch_geometric.nn.norm import LayerNorm
from torch_geometric.nn.pool import global_add_pool

from mldft.ml.data.components.of_batch import OFBatch
from mldft.ml.data.components.of_data import OFData
from mldft.ml.models.components.atom_ref import AtomRef
from mldft.ml.models.components.dimension_wise_rescaling import DimensionWiseRescaling
from mldft.ml.models.components.g3d_stack import G3DStack
from mldft.ml.models.components.gbf_module import GaussianLayer, GBFModule
from mldft.ml.models.components.initial_guess_delta_module import (
    InitialGuessDeltaModule,
)
from mldft.ml.models.components.mlp import MLP
from mldft.ml.models.components.node_embedding import NodeEmbedding
from mldft.utils.log_utils.logging_mixin import LoggingMixin

# set tensorframes to None if not available
try:
    from tensorframes.lframes import LFrames
except ImportError:
    LFrames = None


class MLPStack(nn.Module):
    """A stack of identically-configured MLPs combined by a learnable weighted mean.

    Each MLP processes one slice ``x[i]`` of a stacked input (e.g. one
    intermediate readout per group of GNN layers); the outputs are summed with
    learnable scalar weights and divided by the number of MLPs.
    """

    def __init__(self, mlp_class: type[nn.Module], n_mlps: int, **mlp_kwargs) -> None:
        """Initializes the MLPStack.

        Args:
            mlp_class: The module class (or factory) instantiated ``n_mlps`` times.
                Each instance must be callable as ``mlp(x, batch)``.
            n_mlps: The number of MLPs in the stack.
            **mlp_kwargs: Keyword arguments forwarded to every ``mlp_class(...)`` call.
        """
        super().__init__()
        self.n_mlps = n_mlps
        # One learnable scalar weight per MLP, initialized to 1.
        self.weights = nn.Parameter(torch.ones(n_mlps))
        self.mlps = nn.ModuleList([mlp_class(**mlp_kwargs) for _ in range(n_mlps)])

    def forward(self, x: Tensor, batch: Tensor) -> Tensor:
        """Applies each MLP to its input slice and returns the weighted mean.

        Args:
            x: Stacked inputs with leading dimension ``n_mlps``; slice ``x[i]``
                is fed to MLP ``i``.
            batch: Batch assignment tensor, passed through to every MLP.

        Returns:
            ``sum_i weights[i] * mlps[i](x[i], batch) / n_mlps``.
        """
        out = self.weights[0] * self.mlps[0](x[0], batch)
        for i in range(1, self.n_mlps):
            out = out + self.weights[i] * self.mlps[i](x[i], batch)
        return out / self.n_mlps
class Graphformer(nn.Module, LoggingMixin):
    """The Graphformer module as described in [M-OFDFT]_."""

    def __init__(
        self,
        edge_mlp: MLP,
        energy_mlp: MLP | MLPStack,
        gbf_module: GBFModule,
        node_embedding_module: NodeEmbedding,
        gnn_module: G3DStack,
        atom_ref_module: AtomRef,
        initial_guess_module: InitialGuessDeltaModule,
        dimension_wise_rescaling_module: DimensionWiseRescaling,
        final_energy_factor: float = 1.0,
        final_norm_layer: nn.Module | None = None,
    ) -> None:
        """Initializes the Graphformer class.

        Args:
            edge_mlp (MLP): The MLP predicting edge attributes as input for the G3D Layer.
            energy_mlp (MLP | MLPStack): The MLP (or stack of MLPs, one per configured
                readout) predicting the energy per atom.
            gbf_module (GBFModule): The GBF module.
            node_embedding_module (NodeEmbedding): The node embedding module.
            gnn_module (G3DStack): The stack of G3DLayers, which make up the main part of
                the module. Can be replaced by any Graph NN module with the same
                signature.
            atom_ref_module (AtomRef): The atomic reference module.
            initial_guess_module (InitialGuessDeltaModule): The Module mapping the final
                node_features to the initial guess differences.
            dimension_wise_rescaling_module: The dimension wise rescaling module.
            final_energy_factor: Constant factor applied to the predicted energies
                before the atomic reference energies are added.
            final_norm_layer: Optional normalization layer applied to the node features
                after the GNN stack and before the energy readout.
        """
        super().__init__()
        self.edge_mlp = edge_mlp
        self.gbf_module = gbf_module
        self.node_embedding_module = node_embedding_module
        self.gnn_module = gnn_module
        self.energy_mlp = energy_mlp
        if isinstance(energy_mlp, MLPStack):
            # With an MLPStack there must be exactly one MLP per intermediate readout.
            n_readouts = gnn_module.n_layers / gnn_module.energy_readout_every
            assert (
                n_readouts == energy_mlp.n_mlps
            ), f"Energy MLP stack must have the same number of MLPs ({energy_mlp.n_mlps}) as the number of configured readouts ({n_readouts})."
        elif isinstance(energy_mlp, MLP):
            # A single MLP only makes sense if the GNN reads out once, at the last layer.
            assert (
                gnn_module.energy_readout_every == gnn_module.n_layers
            ), "Energy should only be readout at the last layer when using a single MLP as the energy readout."
        self.atom_ref_module = atom_ref_module
        self.initial_guess_module = initial_guess_module
        self.dimension_wise_rescaling_module = dimension_wise_rescaling_module
        self.final_energy_factor = final_energy_factor
        self.final_norm_layer = final_norm_layer
    def forward(self, batch: Union[OFData, OFBatch]) -> Tuple[Tensor, Tensor]:
        """Calculates the forward pass of the module according to the Figure 6 in [M-OFDFT]_.

        Args:
            batch (OFData): The batch / data object containing the input data.

        Returns:
            Tuple[Tensor, Tensor]: The energy and the initial guess delta, in this order.
        """
        edge_index = batch.edge_index
        # upper part of figure 6 in the M-OFDFT paper:
        # Gaussian basis function embedding of pairwise distances.
        gbf_embedding, length = self.gbf_module(batch)
        # these are scalar attributes
        g3d_edge_attr = self.edge_mlp(gbf_embedding, batch.batch)
        # middle row of figure 6: rescale coefficients, embed nodes, run the GNN stack.
        coeffs_rescaled = self.dimension_wise_rescaling_module(batch)
        node_features = self.node_embedding_module(
            coeffs=coeffs_rescaled,
            atom_ind=batch.atom_ind,
            basis_function_ind=batch.basis_function_ind,
            coeff_ind_to_node_ind=batch.coeff_ind_to_node_ind,
            n_basis_dim_per_atom=batch.n_basis_per_atom,
            distance_embedding=gbf_embedding,
            edge_index=edge_index,
            length=length,
        )
        # Local frames are optional: only built if the batch carries them AND the
        # tensorframes package is importable (LFrames is None otherwise).
        if getattr(batch, "lframes", None) is not None and LFrames is not None:
            lframes = LFrames(matrices=batch.lframes.reshape(-1, 3, 3))
        else:
            # If the gnn_module requires lframes it will raise an error
            lframes = None
        node_features = self.gnn_module(
            x=node_features,
            edge_index=edge_index,
            batch=batch.batch,
            edge_attr=g3d_edge_attr,
            length=length,
            lframes=lframes,
        )
        # hasattr check: older checkpoints may lack the final_norm_layer attribute.
        if hasattr(self, "final_norm_layer") and self.final_norm_layer is not None:
            if isinstance(self.final_norm_layer, LayerNorm):
                # torch_geometric's LayerNorm needs the batch vector for per-graph norm.
                node_features = self.final_norm_layer(node_features, batch.batch)
            else:
                node_features = self.final_norm_layer(node_features)
        energies_per_atom = self.energy_mlp(node_features, batch.batch)
        # If we use an MLP the output will still have the first n_readouts dimension with size 1
        if energies_per_atom.ndim == 3:
            energies_per_atom.squeeze_(0)
        # Sum per-atom energies over each graph in the batch.
        energies = global_add_pool(energies_per_atom, batch.batch)[:, 0]
        if hasattr(self, "final_energy_factor"):
            # for compatibility with older models
            energies *= self.final_energy_factor
        # lower part of figure 6: add the atomic reference energies.
        energies_atom_ref = self.atom_ref_module.sample_forward(batch)
        self.log(energy_atom_ref=energies_atom_ref.sum(), energy_g3d=energies.sum())
        energies += energies_atom_ref
        # output branch initial guess
        # NOTE(review): node_features[-1] presumably selects the last readout of the
        # stacked per-readout features returned by gnn_module — confirm against G3DStack.
        initial_guess_delta = self.initial_guess_module(
            x=node_features[-1],
            basis_function_ind=batch.basis_function_ind,
            coeff_ind_to_node_ind=batch.coeff_ind_to_node_ind,
            batch=batch.batch,
        )
        return energies, initial_guess_delta
[docs] def get_distance_embeddings(self, distances: Tensor) -> Tuple[Tensor, Tensor]: """Get the distance embeddings for the given distances. Args: distances (Tensor): The distances for which to calculate the embeddings. Returns: gbf_embedding, g3d_edge_attr: The distance embeddings, shapes (n_distances, n_embedding_dims) and (n_distances, 1), respectively. """ assert distances.dim() == 1, f"Distances must be 1D but got shape {distances.shape}." # create positions: first atom at origin, other atom on x-axis at given distances pos = torch.stack( [distances, torch.zeros_like(distances), torch.zeros_like(distances)], dim=1 ) # (n_distances, 3) pos = torch.cat([pos.new_zeros((1, 3)), pos], dim=0) # (n_distances + 1, 3) # create edge index connecting the first atom with all others device = distances.device edge_index = torch.stack( [ torch.zeros_like(distances, dtype=torch.long), torch.arange(1, distances.shape[0] + 1, device=device), ], dim=0, ) dummy_sample = OFData( pos=pos, edge_index=edge_index, atom_ind=torch.zeros_like( pos[:, 0], dtype=torch.long ), # dummy atom indices: all H (probably, depends on basis_info) ) gbf_embedding, _ = self.gbf_module(dummy_sample) g3d_edge_attr = self.edge_mlp(gbf_embedding) return gbf_embedding, g3d_edge_attr
    def plot_distance_embeddings(
        self, max_distance: float = 10.0, n_distances: int = 1000
    ) -> plt.Figure:
        """Plot the distance embeddings for a range of distances.

        Args:
            max_distance (float): The maximum distance to consider.
            n_distances (int): The number of distances to consider.

        Returns:
            plt.Figure: The plot.
        """
        device = next(self.parameters()).device  # infer device from first parameter
        distances = torch.linspace(0, max_distance, n_distances, device=device)
        gbf_embedding, g3d_edge_attr = self.get_distance_embeddings(distances)
        # Move everything to CPU numpy for matplotlib.
        distances = distances.cpu().numpy()
        gbf_embedding = gbf_embedding.detach().cpu().numpy()
        g3d_edge_attr = g3d_edge_attr.detach().cpu().numpy()
        cmap = plt.cm.viridis
        n_ax = 2
        with_edge_type = False
        # GaussianLayer carries per-edge-type parameters — add a third panel for them.
        if isinstance(self.gbf_module, GaussianLayer):
            n_ax += 1
            with_edge_type = True
        fig, axs = plt.subplots(n_ax, 1, figsize=(15, 5 * n_ax))
        # plot gbf_embedding as multiple lines in upper plot, colored by their index using the viridis colormap
        for i in range(gbf_embedding.shape[1]):
            axs[0].plot(
                distances,
                gbf_embedding[:, i],
                color=cmap(i / gbf_embedding.shape[1]),
                linewidth=1,
            )
        axs[0].set_ylabel("Gaussian embedding value")
        axs[0].set_xlabel("Distance [Bohr]")
        axs[0].set_xlim(0, max_distance)
        axs[1].plot(distances, g3d_edge_attr)
        axs[1].set_ylabel("Edge attribute value")
        axs[1].set_xlabel("Distance [Bohr]")
        axs[1].set_xlim(0, max_distance)
        if with_edge_type:
            ax = axs[2]
            basis_info = self.gbf_module.basis_info
            # Human-readable labels like "H-C" for each ordered element pair.
            edge_type_labels = [
                f"{ELEMENTS[basis_info.atomic_numbers[i]]}-{ELEMENTS[basis_info.atomic_numbers[j]]}"
                for i in range(basis_info.n_types)
                for j in range(basis_info.n_types)
            ]
            edge_type_labels = edge_type_labels[
                : self.gbf_module.n_edge_types
            ]  # will be shorter in undirected case
            x = torch.arange(len(edge_type_labels))
            # Per-edge-type multiplicative and additive distance parameters.
            ax.scatter(
                x,
                self.gbf_module.mul.weight.view(-1).detach().cpu().numpy(),
                label="mul",
                color="red",
            )
            ax.scatter(
                x,
                self.gbf_module.bias.weight.view(-1).detach().cpu().numpy(),
                label="bias",
                color="blue",
            )
            ax.set_xticks(range(len(edge_type_labels)))
            ax.set_xticklabels(edge_type_labels, rotation=45)
            ax.legend()
            ax.set_xlabel("Edge type")
            ax.set_axisbelow(True)
            ax.grid(True)
        plt.tight_layout()
        return fig