node_embedding

The NodeEmbedding module initializes the hidden node features \(h\).

It combines three parts. First, a simple embedding layer assigns learnable weights to each atomic number \(Z\). Second, the density coefficients \(\tilde{p}\) are embedded in an ‘atom-hot’ way, which maps the coefficients of every atom to the same dimension; a ShrinkGateModule then maps the coefficients into a bounded space, which stabilizes optimization, and an MLP projects the result to the hidden node dimension. Third, to encode the chemical environment, the pairwise distances, expanded in Gaussian basis functions (GBFs), are aggregated per atom and passed through an MLP. The three parts are summed to yield the hidden node features \(h\), which serve as input to the subsequent stack of G3D layers.
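The following NumPy sketch illustrates how the three terms are summed into \(h\). It is not the module's implementation: the embedding table and the two MLPs (each reduced to a single random linear map followed by SiLU) are hypothetical, the shrink gate is a scaled `tanh` stand-in because the ShrinkGateModule formula is not given here, and the atom-hot coefficients and aggregated distance features are taken as precomputed inputs.

```python
import numpy as np

rng = np.random.default_rng(0)

n_atoms, atom_hot_dim, dst_in_channels, out_channels = 3, 8, 4, 16
n_atom_types = 2  # hypothetical number of distinct atom types


def silu(x):
    return x / (1.0 + np.exp(-x))


def shrink_gate(x, lambda_co=1.0, lambda_mul=1.0):
    # Hypothetical stand-in for the ShrinkGateModule: a scaled tanh that
    # maps unbounded coefficients into the bounded range (-lambda_co, lambda_co).
    return lambda_co * np.tanh(lambda_mul * x)


# "Learnable" parameters, random here for illustration only.
z_table = rng.normal(size=(n_atom_types, out_channels))   # one row per atom type
w_p = rng.normal(size=(atom_hot_dim, out_channels))       # simplified coefficient MLP
w_d = rng.normal(size=(dst_in_channels, out_channels))    # simplified distance MLP

# Inputs.
atom_ind = np.array([0, 1, 1])                                   # atom-type index per atom
p_atom_hot = rng.normal(size=(n_atoms, atom_hot_dim))            # atom-hot coefficients
aggregated_dist = rng.normal(size=(n_atoms, dst_in_channels))    # aggregated GBF features

# h = atomic-number embedding + coefficient term + distance term
h = (
    z_table[atom_ind]
    + silu(shrink_gate(p_atom_hot) @ w_p)
    + silu(aggregated_dist @ w_d)
)
print(h.shape)  # (3, 16)
```

Summing rather than concatenating keeps every term in the same `out_channels` space, so each part can be switched off or rescaled independently.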

class NodeEmbedding(n_atoms: int, basis_dim_per_atom: ~torch.Tensor | ~numpy.ndarray, basis_atomic_numbers: ~torch.Tensor | ~numpy.ndarray, atomic_number_to_atom_index: ~torch.Tensor | ~numpy.ndarray, out_channels: int, dst_in_channels: int, p_hidden_channels: int = 32, p_num_layers: int = 3, p_activation: callable = <class 'torch.nn.modules.activation.SiLU'>, p_dropout: float = 0.0, dst_hidden_channels: int = 32, dst_num_layers: int = 3, dst_activation: callable = <class 'torch.nn.modules.activation.SiLU'>, dst_dropout: float = 0.0, lambda_co: float = None, lambda_mul: float = None, use_per_basis_func_shrink_gate: bool = False, cutoff: float | None = None, cutoff_start: float = 0.0)

The NodeEmbedding module creates node features from density, atomic number and distance embeddings.

__init__(n_atoms: int, basis_dim_per_atom: ~torch.Tensor | ~numpy.ndarray, basis_atomic_numbers: ~torch.Tensor | ~numpy.ndarray, atomic_number_to_atom_index: ~torch.Tensor | ~numpy.ndarray, out_channels: int, dst_in_channels: int, p_hidden_channels: int = 32, p_num_layers: int = 3, p_activation: callable = <class 'torch.nn.modules.activation.SiLU'>, p_dropout: float = 0.0, dst_hidden_channels: int = 32, dst_num_layers: int = 3, dst_activation: callable = <class 'torch.nn.modules.activation.SiLU'>, dst_dropout: float = 0.0, lambda_co: float = None, lambda_mul: float = None, use_per_basis_func_shrink_gate: bool = False, cutoff: float | None = None, cutoff_start: float = 0.0) → None

Initialize the NodeEmbedding module.

Parameters:
  • n_atoms (int) – Number of atom types in the dataset.

  • basis_dim_per_atom (Tensor[int] | ndarray[int]) – Basis dimensions per atomic number.

  • basis_atomic_numbers (Tensor[int] | ndarray[int]) – Atomic numbers in the basis.

  • atomic_number_to_atom_index (Tensor[int] | ndarray[int]) – Mapping from atomic number to atom index, e.g. basis_dim_per_atom[atomic_number_to_atom_index[1]] yields the basis dimensions for atomic number 1.

  • out_channels (int) – Number of output channels for the hidden node representation \(h\).

  • dst_in_channels (int) – Number of input channels of GBF-transformed pairwise distances \(\mathcal{E}\).

  • p_hidden_channels (int) – Number of hidden channels for the MLP of density coefficients. Defaults to 32.

  • p_num_layers (int) – Number of hidden layers for the MLP of density coefficients. Defaults to 3.

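  • p_activation (callable) – Activation function for the MLP of density coefficients. Defaults to SiLU.

  • p_dropout (float) – Dropout rate for the MLP of density coefficients. Defaults to 0.0.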

  • dst_hidden_channels (int) – Number of hidden channels for the MLP of GBF-transformed distances. Defaults to 32.

  • dst_num_layers (int) – Number of hidden layers for the MLP of GBF-transformed distances. Defaults to 3.

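  • dst_activation (callable) – Activation function for the MLP of GBF-transformed distances. Defaults to SiLU.

  • dst_dropout (float) – Dropout rate for the MLP of GBF-transformed distances. Defaults to 0.0.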

  • lambda_co (float) – lambda_co parameter for the ShrinkGateModule.

  • lambda_mul (float) – lambda_mul parameter for the ShrinkGateModule.

  • use_per_basis_func_shrink_gate (bool) – Whether to use a per-basis-function shrink gate or not. Mainly to allow backwards compatibility for older checkpoints.

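  • cutoff (float | None) – The cutoff radius for the distance embedding. If None, no cutoff is applied. Defaults to None.

  • cutoff_start (float) – The distance at which the smooth cutoff of the distance embedding begins (see smooth_falloff()). Defaults to 0.0.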
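The atomic_number_to_atom_index lookup from the parameter list can be sketched in NumPy as follows. The basis sizes are made up, and filling atomic numbers that are absent from the basis with `-1` is an assumption for illustration, not necessarily what the library stores.

```python
import numpy as np

# Hypothetical basis containing H, C and O with made-up basis sizes.
basis_atomic_numbers = np.array([1, 6, 8])
basis_dim_per_atom = np.array([4, 10, 10])

# Lookup table indexed by atomic number; -1 marks elements not in the basis (assumption).
atomic_number_to_atom_index = np.full(basis_atomic_numbers.max() + 1, -1, dtype=int)
atomic_number_to_atom_index[basis_atomic_numbers] = np.arange(len(basis_atomic_numbers))

# Basis dimension for carbon (Z = 6).
print(basis_dim_per_atom[atomic_number_to_atom_index[6]])  # 10
```

Indexing by atomic number makes the lookup a single array access, at the cost of a table whose length is set by the largest atomic number in the basis.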

__setstate__(state: dict) → None

This method is called during unpickling.

If ‘cutoff_start’ is missing (as would be the case with an older checkpoint), it will be added with a default value.
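A minimal sketch of this backwards-compatibility pattern on a toy class (`Embedding` is hypothetical, and the fill value 0.0 is assumed from the constructor default of cutoff_start). Unpickling creates the instance without calling `__init__` and then passes the saved state dict to `__setstate__`; the code mimics that directly.

```python
class Embedding:
    """Toy stand-in for a module that gained a new attribute in a later version."""

    def __init__(self, cutoff=None, cutoff_start=0.0):
        self.cutoff = cutoff
        self.cutoff_start = cutoff_start

    def __setstate__(self, state: dict) -> None:
        # Checkpoints written before cutoff_start existed lack the key;
        # fill in the constructor default so later code can rely on it.
        state.setdefault("cutoff_start", 0.0)
        self.__dict__.update(state)


# State dict as an older version of the class would have saved it.
old_state = {"cutoff": 5.0}

# Mimic unpickling: allocate without __init__, then restore the state.
restored = Embedding.__new__(Embedding)
restored.__setstate__(old_state)
print(restored.cutoff, restored.cutoff_start)  # 5.0 0.0
```

In a `torch.nn.Module` subclass, one would patch the dict and then call `super().__setstate__(state)` so that `nn.Module` can restore its own internal attributes.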

aggregate_distances(edge_attributes: Tensor, edge_index: Tensor, n_atoms: int) → Tensor

Aggregate distances (edge features) for each atom over connected atoms.

Parameters:
  • edge_attributes (Tensor[float]) – Edge features (GBF transformed distances) of shape (num_edges, dst_in_channels).

  • edge_index (Tensor[int]) – Edge indices of shape (2, num_edges).

  • n_atoms (int) – Number of atoms in the sample.

Returns:

Aggregated edge features of shape (n_atoms, dst_in_channels).

Return type:

aggregated_features (Tensor[float])
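A NumPy sketch of the GBF expansion and the per-atom aggregation, under two assumptions: (i) the GBF features are \(\exp(-\gamma (d - \mu_k)^2)\) with evenly spaced centers \(\mu_k\) and a hypothetical width parameter `gamma`, and (ii) edge features are summed onto the receiving atom `edge_index[1]`. The GBF parameters, the aggregation (sum or mean) and the edge direction used by the module may differ.

```python
import numpy as np


def gbf_expand(distances, n_gaussians=4, d_max=5.0, gamma=1.0):
    # Evenly spaced Gaussian centers on [0, d_max]; gamma sets their width.
    centers = np.linspace(0.0, d_max, n_gaussians)
    return np.exp(-gamma * (distances[:, None] - centers[None, :]) ** 2)


def aggregate_distances(edge_attributes, edge_index, n_atoms):
    # Sum the features of all edges that point to each atom.
    aggregated = np.zeros((n_atoms, edge_attributes.shape[1]))
    np.add.at(aggregated, edge_index[1], edge_attributes)  # unbuffered, so repeated indices accumulate
    return aggregated


# Three atoms with bidirectional edges 0<->1 and 1<->2.
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])
lengths = np.array([1.0, 1.0, 1.5, 1.5])

edge_attributes = gbf_expand(lengths)  # shape (num_edges, dst_in_channels) = (4, 4)
agg = aggregate_distances(edge_attributes, edge_index, n_atoms=3)
print(agg.shape)  # (3, 4)
```

`np.add.at` is used instead of `aggregated[idx] += ...` because fancy-index assignment would keep only one contribution when an atom receives several edges.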

forward(coeffs: Tensor | ndarray, atom_ind: Tensor | ndarray, basis_function_ind: Tensor | ndarray, n_basis_dim_per_atom: Tensor | ndarray, coeff_ind_to_node_ind: Tensor, distance_embedding: Tensor, edge_index: Tensor, batch: Tensor = None, length: Tensor = None) → Tensor

Forward pass of the NodeEmbedding module.

Passes the list of density coefficients \(\tilde{p}\) through an embedding and the ShrinkGate, embeds the atomic numbers \(Z\) and aggregates the distance features \(\mathcal{E}\) over edges. After the density coefficients and distance features pass through their MLPs, the node features \(h\) are computed as the sum of the three terms.

Parameters:
  • coeffs (Tensor[float]) – Density coefficients \(\tilde{p}\) of varying shape but length n_atoms.

  • atom_ind (Tensor[int]) – Atomic indices of shape (n_atoms, 1).

  • basis_function_ind (Tensor[int]) – Array holding the OFData’s basis function indices. Will be used to embed coefficients in an atom-hot way.

  • n_basis_dim_per_atom (Tensor[int]) – Number of basis functions per atom.

  • coeff_ind_to_node_ind (Tensor[int]) – Tensor mapping coefficient indices to node (=atom) indices.

  • distance_embedding (Tensor[float]) – GBF-transformed distances of shape (num_edges, dst_in_channels). dst_in_channels corresponds to edge_channels.

  • edge_index (Tensor[int]) – Edge indices of shape (2, num_edges).

  • batch (Tensor, optional) – Batch tensor for LayerNorm inside the MLPs.

  • length (Tensor, optional) – Edge lengths for the cutoff function.

Returns:

Node features \(h\) of shape (n_atoms, out_channels).

Return type:

h (Tensor[float])
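One plausible reading of the atom-hot embedding, shown as a NumPy sketch: each atom gets a row whose width equals the total number of basis functions across all elements, and its coefficients are scattered into the columns given by basis_function_ind, so atoms of different elements occupy disjoint column blocks. This layout is an assumption inferred from the parameter names; the basis sizes and coefficient values are made up.

```python
import numpy as np

# Hypothetical basis: H uses columns 0-1, O uses columns 2-4.
total_basis_dim = 5

# Flattened coefficients of a two-atom sample (one H, one O).
coeffs = np.array([0.3, -0.1, 1.2, 0.4, -0.7])
coeff_ind_to_node_ind = np.array([0, 0, 1, 1, 1])  # atom that owns each coefficient
basis_function_ind = np.array([0, 1, 2, 3, 4])     # column in the atom-hot layout

n_atoms = coeff_ind_to_node_ind.max() + 1
atom_hot = np.zeros((n_atoms, total_basis_dim))
atom_hot[coeff_ind_to_node_ind, basis_function_ind] = coeffs
# Row 0 holds the H coefficients in columns 0-1; row 1 holds the O coefficients in columns 2-4.
print(atom_hot)
```

Because every row now has the same width, a single MLP can process atoms of all elements, while the column position still tells it which element and basis function each coefficient belongs to.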

classmethod from_basis_info(basis_info, out_channels: int, dst_in_channels: int, p_hidden_channels: int = 32, p_num_layers: int = 3, p_activation: callable = <class 'torch.nn.modules.activation.SiLU'>, p_dropout: float = 0.0, dst_hidden_channels: int = 32, dst_num_layers: int = 3, dst_activation: callable = <class 'torch.nn.modules.activation.SiLU'>, dst_dropout: float = 0.0, lambda_co: float = None, lambda_mul: float = None, use_per_basis_func_shrink_gate: bool = False, cutoff: float | None = None, cutoff_start: float = 0.0) → NodeEmbedding

Initialize the NodeEmbedding module from a BasisInfo object.

The arguments pertaining to the basis, i.e. basis_dim_per_atom, basis_atomic_numbers and atomic_number_to_atom_index, are extracted from basis_info. See __init__() for details on the remaining arguments.
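A sketch of this alternate-constructor pattern with hypothetical stand-ins: `BasisInfo` here is a toy dataclass whose field names are assumed to match the constructor arguments, `NodeEmbeddingSketch` only records its arguments, and deriving n_atoms from `len(basis_atomic_numbers)` is an assumption.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class BasisInfo:
    """Hypothetical stand-in; the real BasisInfo may store other or differently named fields."""

    basis_dim_per_atom: np.ndarray
    basis_atomic_numbers: np.ndarray
    atomic_number_to_atom_index: np.ndarray


class NodeEmbeddingSketch:
    """Toy class that only records its constructor arguments."""

    def __init__(self, n_atoms, basis_dim_per_atom, basis_atomic_numbers,
                 atomic_number_to_atom_index, out_channels, dst_in_channels, **kwargs):
        self.n_atoms = n_atoms
        self.basis_dim_per_atom = basis_dim_per_atom
        self.basis_atomic_numbers = basis_atomic_numbers
        self.atomic_number_to_atom_index = atomic_number_to_atom_index
        self.out_channels = out_channels
        self.dst_in_channels = dst_in_channels
        self.hparams = kwargs  # remaining keyword arguments, passed through unchanged

    @classmethod
    def from_basis_info(cls, basis_info, out_channels, dst_in_channels, **kwargs):
        # Unpack the basis-related arguments and forward everything else to __init__.
        return cls(
            n_atoms=len(basis_info.basis_atomic_numbers),  # assumption: one type per basis element
            basis_dim_per_atom=basis_info.basis_dim_per_atom,
            basis_atomic_numbers=basis_info.basis_atomic_numbers,
            atomic_number_to_atom_index=basis_info.atomic_number_to_atom_index,
            out_channels=out_channels,
            dst_in_channels=dst_in_channels,
            **kwargs,
        )


info = BasisInfo(
    basis_dim_per_atom=np.array([4, 10, 10]),
    basis_atomic_numbers=np.array([1, 6, 8]),
    atomic_number_to_atom_index=np.array([-1, 0, -1, -1, -1, -1, 1, -1, 2]),
)
emb = NodeEmbeddingSketch.from_basis_info(info, out_channels=64, dst_in_channels=16, p_num_layers=2)
print(emb.n_atoms, emb.hparams)  # 3 {'p_num_layers': 2}
```

The sketch forwards extra options via `**kwargs` for brevity; the documented signature lists every argument explicitly, which keeps them visible to IDEs and to the generated documentation.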

reset_parameters() → None

Reset all parameters of the NodeEmbedding module.

smooth_falloff(value: Tensor, falloff_end: float, falloff_start: float = 0.0) → Tensor

Calculates a smooth falloff value using a cosine function.

The function returns a tensor with values in the range [0, 1] such that:
  • For values less than or equal to falloff_start, the function returns 1.

  • For values greater than or equal to falloff_end, the function returns 0.

  • For values between falloff_start and falloff_end, it returns a smoothly interpolated value between 1 and 0 using a cosine function.

The cosine interpolation is computed as:

0.5 * (cos(pi * (value - falloff_start) / (falloff_end - falloff_start)) + 1)

Parameters:
  • value (Tensor) – The input tensor.

  • falloff_end (float) – The value at which the falloff reaches 0.

  • falloff_start (float) – The value at which the falloff begins. Defaults to 0.0.

Returns:

A tensor representing the smooth falloff value, between 0 and 1.
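A NumPy sketch of the cosine falloff defined above. Clipping the normalized position to [0, 1] covers the two constant regions and the interpolation in one expression. The final lines show one hypothetical use: damping GBF edge features by edge length, with cutoff as falloff_end and cutoff_start as falloff_start. That pairing is an assumption based on the parameter names.

```python
import numpy as np


def smooth_falloff(value, falloff_end, falloff_start=0.0):
    # Normalized position inside the falloff window, clipped to [0, 1] so that
    # values <= falloff_start give 1 and values >= falloff_end give 0.
    t = np.clip((value - falloff_start) / (falloff_end - falloff_start), 0.0, 1.0)
    return 0.5 * (np.cos(np.pi * t) + 1.0)


lengths = np.array([0.5, 1.0, 3.0, 5.0, 6.0])
weights = smooth_falloff(lengths, falloff_end=5.0, falloff_start=1.0)
print(weights)  # approximately [1, 1, 0.5, 0, 0]

# Hypothetical use: damp edge features of long edges before aggregation.
edge_attributes = np.ones((5, 4))
damped = edge_attributes * weights[:, None]
```

Unlike a hard step at the cutoff, the cosine form has zero slope at both ends of the window, so features fade out continuously as an edge length crosses the cutoff.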