Source code for mldft.ml.models.components.g3d_stack

"""Module to wrap a stack of :class:`~mldft.ml.models.components.g3d_layer.G3DLayer`."""
import torch
import torch.nn as nn
from torch import Tensor

from mldft.ml.models.components.g3d_layer import G3DLayer

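# tensorframes is an optional dependency: fall back to None placeholders for
# G3DLayerTF and LFrames when it (or the tensorframes-based layer) cannot be imported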
try:
    from tensorframes.lframes import LFrames

    from mldft.ml.models.components.g3d_layer_tf import G3DLayerTF
except ImportError:
    G3DLayerTF = None
    LFrames = None


[docs] class G3DStack(nn.Module): """Module to wrap a stack of G3D layers."""
[docs] def __init__( self, g3d_class: G3DLayer, n_layers: int, energy_readout_every: int | None = None, **g3d_kwargs, ) -> None: """The G3DStack module is a stack of :class:`~mldft.ml.models.components.g3d_layer.G3DLayer`. Args: n_layers (int): number of G3D layers. energy_readout_every (int, optional): The frequency of energy readout. If `None`, the energy is read out every `n_layers` which means only the last layer will be read out. **g3d_kwargs: Arguments of `G3DLayer`. """ super().__init__() self.n_layers = n_layers self.energy_readout_every = ( energy_readout_every if energy_readout_every is not None else n_layers ) self.g3d_layers = nn.ModuleList([g3d_class(**g3d_kwargs) for _ in range(n_layers)])
[docs] def forward( self, x: Tensor, edge_index: Tensor, batch: Tensor, lframes: LFrames = None, edge_attr: Tensor = None, length: Tensor = None, ) -> Tensor: """Calculates the forward pass of the module by going through all the G3D layers. Args: x (Tensor): The input node features. edge_index (Tensor): The edge indices. batch (Tensor): Batch assigning each node to a specific graph. lframes (LFrames, optional): The LFrames object containing the local frames. (default: `None`) edge_attr (Tensor, optional): The edge features. (default: `None`) length (Tensor, optional): The length of the edges sorted as in the edge index. Returns: Tensor: The output node features. """ inter_nodes = [] energy_readout_every = ( self.energy_readout_every if hasattr(self, "energy_readout_every") else self.n_layers ) for i in range(self.n_layers): g3d_layer = self.g3d_layers[i] if isinstance(g3d_layer, G3DLayer): x = self.g3d_layers[i]( x=x, edge_index=edge_index, batch=batch, edge_attr=edge_attr, length=length ) elif isinstance(g3d_layer, G3DLayerTF): x = self.g3d_layers[i]( x=x, edge_index=edge_index, batch=batch, lframes=lframes, edge_attr=edge_attr, length=length, ) else: raise NotImplementedError("G3DLayer class not recognized.") if (i + 1) % energy_readout_every == 0: inter_nodes.append(x) return torch.stack(inter_nodes)