"""The NodeEmbedding module initializes the hidden node features :math:`h`.
It combines three parts: A simple embedding layer assigns learnable weights to each of the
atomic numbers :math:`Z`. Density coefficients :math:`\\tilde{p}` are embedded in an 'atom-hot' way,
mapping all coefficients to the same dimension. A further processing of a ShrinkGateModule maps
the coefficients into a bounded space, facilitating a stable optimization. Afterwards, an MLP
projects these features to the hidden node dimension. To encode the chemical environment, the
pairwise distances encoded by Gaussian-Basis-Functions are aggregated and passed through an MLP.
The three parts are then summed to yield the hidden node features :math:`h`, ready to serve as
input for the following stack of G3D-layers.
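
Example (a minimal sketch, assuming ``basis_info`` is the project's BasisInfo object;
the channel sizes are illustrative)::

    embedding = NodeEmbedding.from_basis_info(
        basis_info, out_channels=128, dst_in_channels=64
    )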
"""
from __future__ import annotations
import torch
from numpy import ndarray
from torch import Tensor
from torch.nn import Embedding, Module, SiLU
from torch_geometric.utils import scatter
from mldft.ml.models.components.density_coeff_embedding import AtomHotEmbedding
from mldft.ml.models.components.mlp import MLP
from mldft.ml.models.components.shrink_gate_module import (
PerBasisFuncShrinkGateModule,
ShrinkGateModule,
)
from mldft.utils.log_utils.logging_mixin import LoggingMixin
def smooth_falloff(
value: torch.Tensor, falloff_end: float, falloff_start: float = 0.0
) -> torch.Tensor:
"""Calculates a smooth falloff value using a cosine function.
The function returns a tensor with values in the range [0, 1] such that:
- For values less than or equal to `falloff_start`, the function returns 1.
- For values greater than or equal to `falloff_end`, the function returns 0.
- For values between `falloff_start` and `falloff_end`, it returns a smoothly
interpolated value between 1 and 0 using a cosine function.
The cosine interpolation is computed as:
0.5 * (cos(pi * (value - falloff_start) / (falloff_end - falloff_start)) + 1)
Args:
value: The input tensor.
falloff_end: The value at which the falloff reaches 0.
falloff_start: The value at which the falloff begins (default is 0).
Returns:
A tensor representing the smooth falloff value, between 0 and 1.
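
    Example (a small sketch; output values shown approximately)::

        x = torch.tensor([0.0, 2.5, 5.0])
        smooth_falloff(x, falloff_end=5.0)
        # approximately tensor([1.0, 0.5, 0.0])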
"""
if falloff_start >= falloff_end:
raise ValueError("falloff_start must be less than falloff_end")
# Compute the interpolation factor for values in the decay range.
fraction = (value - falloff_start) / (falloff_end - falloff_start)
cosine_falloff = 0.5 * (torch.cos(torch.pi * fraction) + 1)
# Use torch.where to assign:
# - 1 for values <= falloff_start,
# - 0 for values >= falloff_end, and
# - cosine_falloff for values in between.
result = torch.where(
value <= falloff_start, 1, torch.where(value >= falloff_end, 0, cosine_falloff)
)
return result
class NodeEmbedding(Module, LoggingMixin):
"""The NodeEmbedding module creates node features from density, atomic number and distance
embeddings."""
def __init__(
self,
n_atoms: int,
basis_dim_per_atom: Tensor | ndarray,
basis_atomic_numbers: Tensor | ndarray,
atomic_number_to_atom_index: Tensor | ndarray,
out_channels: int,
dst_in_channels: int,
p_hidden_channels: int = 32,
p_num_layers: int = 3,
p_activation: callable = SiLU,
p_dropout: float = 0.0,
dst_hidden_channels: int = 32,
dst_num_layers: int = 3,
dst_activation: callable = SiLU,
dst_dropout: float = 0.0,
        lambda_co: float | None = None,
        lambda_mul: float | None = None,
use_per_basis_func_shrink_gate: bool = False,
cutoff: float | None = None,
cutoff_start: float = 0.0,
) -> None:
r"""Initialize the NodeEmbedding module.
Args:
n_atoms (int): Number of atom types in the dataset.
basis_dim_per_atom (Tensor[int] | ndarray[int]): Basis Dimensions per atomic number.
basis_atomic_numbers (Tensor[int] | ndarray[int]): Atomic numbers in the basis.
atomic_number_to_atom_index (Tensor[int] | ndarray[int]): Mapping from atomic number
to atom index, e.g. basis_dim_per_atom[atomic_number_to_atom_index[1]] yields the
basis dimensions for atomic number 1.
out_channels (int): Number of output channels for the hidden node representation :math:`h`.
dst_in_channels (int): Number of input channels of GBF-transformed pairwise distances
:math:`\mathcal{E}`.
p_hidden_channels (int): Number of hidden channels for the MLP of density coefficients.
Defaults to 32.
            p_num_layers (int): Number of hidden layers for the MLP of density coefficients.
                Defaults to 3.
            p_activation (callable): Activation function for the MLP of density coefficients.
                Defaults to SiLU.
            p_dropout (float): Dropout probability for the MLP of density coefficients.
                Defaults to 0.0.
            dst_hidden_channels (int): Number of hidden channels for the MLP of GBF-transformed
                distances. Defaults to 32.
            dst_num_layers (int): Number of hidden layers for the MLP of GBF-transformed
                distances. Defaults to 3.
            dst_activation (callable): Activation function for the MLP of GBF-transformed
                distances. Defaults to SiLU.
            dst_dropout (float): Dropout probability for the MLP of GBF-transformed distances.
                Defaults to 0.0.
            lambda_co (float | None): lambda_co parameter for the ShrinkGateModule.
            lambda_mul (float | None): lambda_mul parameter for the ShrinkGateModule.
            use_per_basis_func_shrink_gate (bool): Whether to use a per-basis-function shrink
                gate. Mainly kept for backwards compatibility with older checkpoints.
            cutoff (float | None): The cutoff radius for the distance embedding. If None, no
                cutoff is applied.
            cutoff_start (float): The distance at which the falloff begins. Defaults to 0.0.
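
        Example (a minimal, hedged sketch with illustrative sizes for a basis covering
        H and O; in practice these arrays come from a BasisInfo object, see
        :meth:`from_basis_info`)::

            import numpy as np

            a2i = np.full(9, -1)   # maps atomic number -> atom-type index
            a2i[1], a2i[8] = 0, 1  # H -> 0, O -> 1; other entries are placeholders
            embedding = NodeEmbedding(
                n_atoms=2,
                basis_dim_per_atom=np.array([5, 14]),
                basis_atomic_numbers=np.array([1, 8]),
                atomic_number_to_atom_index=a2i,
                out_channels=128,
                dst_in_channels=64,
            )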
"""
super().__init__()
self.n_atoms = n_atoms # Maybe only at forward pass?
self.out_channels = out_channels
self.basis_dim_per_atom = basis_dim_per_atom
self.basis_atomic_numbers = basis_atomic_numbers
self.atomic_number_to_atom_index = atomic_number_to_atom_index
self.atom_hot_embedding = AtomHotEmbedding(sum(basis_dim_per_atom))
self.dst_in_channels = dst_in_channels
self.p_in_channels = sum(basis_dim_per_atom)
self.z_embed = Embedding(n_atoms, out_channels)
self.cutoff = cutoff
self.cutoff_start = cutoff_start
if use_per_basis_func_shrink_gate:
self.shrink_gate = PerBasisFuncShrinkGateModule(
sum(basis_dim_per_atom), lambda_co, lambda_mul
)
else:
self.shrink_gate = ShrinkGateModule(lambda_co, lambda_mul)
self.mlp_p = MLP(
in_channels=self.p_in_channels,
hidden_channels=[p_hidden_channels for _ in range(p_num_layers - 1)] + [out_channels],
activation_layer=p_activation,
dropout=p_dropout,
)
self.mlp_dst = MLP(
in_channels=self.dst_in_channels,
hidden_channels=[dst_hidden_channels for _ in range(dst_num_layers - 1)]
+ [out_channels],
activation_layer=dst_activation,
dropout=dst_dropout,
)
self.reset_parameters()
@classmethod
def from_basis_info(
cls,
basis_info,
out_channels: int,
dst_in_channels: int,
p_hidden_channels: int = 32,
p_num_layers: int = 3,
p_activation: callable = SiLU,
p_dropout: float = 0.0,
dst_hidden_channels: int = 32,
dst_num_layers: int = 3,
dst_activation: callable = SiLU,
dst_dropout: float = 0.0,
        lambda_co: float | None = None,
        lambda_mul: float | None = None,
use_per_basis_func_shrink_gate: bool = False,
cutoff: float | None = None,
cutoff_start: float = 0.0,
) -> NodeEmbedding:
"""Initialize the NodeEmbedding module from a BasisInfo object.
        The basis-related arguments, i.e. ``basis_dim_per_atom``, ``basis_atomic_numbers``
        and ``atomic_number_to_atom_index``, are extracted from ``basis_info``. See
        :meth:`__init__` for details on the remaining arguments.
"""
return cls(
n_atoms=basis_info.n_types,
basis_atomic_numbers=basis_info.atomic_numbers,
basis_dim_per_atom=basis_info.basis_dim_per_atom,
atomic_number_to_atom_index=basis_info.atomic_number_to_atom_index,
out_channels=out_channels,
dst_in_channels=dst_in_channels,
p_hidden_channels=p_hidden_channels,
p_num_layers=p_num_layers,
p_activation=p_activation,
p_dropout=p_dropout,
dst_hidden_channels=dst_hidden_channels,
dst_num_layers=dst_num_layers,
dst_activation=dst_activation,
dst_dropout=dst_dropout,
lambda_co=lambda_co,
lambda_mul=lambda_mul,
use_per_basis_func_shrink_gate=use_per_basis_func_shrink_gate,
cutoff=cutoff,
cutoff_start=cutoff_start,
)
def reset_parameters(self) -> None:
"""Reset all parameters of the NodeEmbedding module."""
for layer in self.mlp_p:
if hasattr(layer, "reset_parameters"):
layer.reset_parameters()
for layer in self.mlp_dst:
if hasattr(layer, "reset_parameters"):
layer.reset_parameters()
self.z_embed.reset_parameters()
def aggregate_distances(
self, edge_attributes: Tensor, edge_index: Tensor, n_atoms: int
) -> Tensor:
"""Aggregate distances (edge features) for each atom over connected atoms.
Args:
edge_attributes (Tensor[float]): Edge features (GBF transformed distances) of shape
(num_edges, dst_in_channels).
edge_index (Tensor[int]): Edge indices of shape (2, num_edges).
n_atoms: Number of atoms in the sample.
Returns:
aggregated_features (Tensor[float]): Aggregated edge features of
shape (n_atoms, dst_in_channels).
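
        Example (a minimal sketch; the aggregation does not depend on module state, so
        the underlying ``scatter`` call is shown directly)::

            from torch_geometric.utils import scatter

            edge_attributes = torch.ones(4, 2)         # four edges, two channels
            edge_index = torch.tensor([[0, 0, 1, 2],   # source nodes
                                       [1, 2, 0, 0]])  # target nodes
            scatter(src=edge_attributes, index=edge_index[0], dim_size=3)
            # node 0 sums two edges, nodes 1 and 2 one each:
            # tensor([[2., 2.], [1., 1.], [1., 1.]])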
"""
source_nodes = edge_index[0] # 'source' atoms
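        # torch_geometric's scatter defaults to sum reduction, so each atom
        # accumulates the features of all edges it is the source of.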
aggregated_features = scatter(src=edge_attributes, index=source_nodes, dim_size=n_atoms)
return aggregated_features
def forward(
self,
coeffs: Tensor | ndarray,
atom_ind: Tensor | ndarray,
basis_function_ind: Tensor | ndarray,
n_basis_dim_per_atom: Tensor | ndarray,
coeff_ind_to_node_ind: Tensor,
distance_embedding: Tensor,
edge_index: Tensor,
batch: Tensor = None,
length: Tensor = None,
) -> Tensor:
r"""Forward pass of the NodeEmbedding module.
        Passes the density coefficients :math:`\tilde{p}` through an embedding and the
        ShrinkGate, embeds the atomic numbers :math:`Z` and aggregates the distance
        features :math:`\mathcal{E}` over edges. After passing the density coefficients
        and distance features through MLPs, the node features :math:`h` are calculated
        as the sum of the three terms.
Args:
coeffs (Tensor[float]): Density coefficients :math:`\tilde{p}` of varying shape
but length n_atoms.
atom_ind (Tensor[int]): Atomic indices of shape (n_atoms, 1).
basis_function_ind (Tensor[int]): Array holding the OFData's basis function indices.
Will be used to embed coefficients in an atom-hot way.
n_basis_dim_per_atom (Tensor[int]): Number of basis functions per atom.
coeff_ind_to_node_ind (Tensor[int]): Tensor mapping coefficient indices to
node (=atom) indices.
distance_embedding (Tensor[float]): GBF transformed distances of shape
(num_edges, dst_in_channels). dst_in_channels correspond to edge_channels.
edge_index (Tensor[int]): Edge indices of shape (2, num_edges).
batch (Tensor, optional): Batch tensor for LayerNorm inside the MLPs.
length (Tensor, optional): Edge lengths for the cutoff function.
Returns:
h (Tensor[float]): Node features :math:`h` of shape (n_atoms, out_channels).
"""
if isinstance(coeffs, ndarray):
coeffs = torch.from_numpy(coeffs).float()
if isinstance(atom_ind, ndarray):
atom_ind = torch.from_numpy(atom_ind).int()
if isinstance(basis_function_ind, ndarray):
basis_function_ind = torch.from_numpy(basis_function_ind).int()
p = self.atom_hot_embedding(
coeffs=coeffs,
basis_function_ind=basis_function_ind,
n_basis_per_atom=n_basis_dim_per_atom,
coeff_ind_to_node_ind=coeff_ind_to_node_ind,
)
p = self.shrink_gate(p)
p = self.mlp_p(p, batch)
z = self.z_embed(atom_ind.squeeze())
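        # Smoothly gate the GBF distance features towards zero as the edge length
        # approaches the cutoff radius.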
if self.cutoff is not None:
distance_embedding = distance_embedding * smooth_falloff(
length, self.cutoff, self.cutoff_start
)
dst_sum = self.aggregate_distances(
edge_attributes=distance_embedding,
edge_index=edge_index,
n_atoms=atom_ind.shape[0],
)
dst_sum = self.mlp_dst(dst_sum, batch)
self.log(p=p, z=z, dst_sum=dst_sum)
h = p + z + dst_sum
return h
def __setstate__(self, state: dict) -> None:
"""This method is called during unpickling.
        If ``cutoff`` or ``cutoff_start`` is missing (as would be the case with an older
        checkpoint), it will be added with its default value.
"""
# Update the state dictionary first
self.__dict__.update(state)
# If the new attribute is missing, add it with the default value.
if not hasattr(self, "cutoff"):
self.cutoff = None
if not hasattr(self, "cutoff_start"):
self.cutoff_start = 0.0