local_frames_module

Calculates the local frames of the atoms in the molecule.

class LocalBasisModule(ignore_hydrogen: bool = True, use_three_atoms_for_basis: bool = False)

MessagePassing module that calculates the local basis at each node of the graph.

__init__(ignore_hydrogen: bool = True, use_three_atoms_for_basis: bool = False) → None

Initializes a MessagePassing module that calculates the local basis of each atom in the molecule from the atom positions on a graph.

Parameters:
  • ignore_hydrogen (bool, optional) – If True, the two closest heavy (non-hydrogen) atoms are used to construct the local bases. Defaults to True.

  • use_three_atoms_for_basis (bool, optional) – If True, the three closest atoms are used to construct the basis. Defaults to False.
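The documentation does not say how a basis is built from the closest atoms. One common construction, sketched below as an assumption rather than this module's confirmed method, is Gram-Schmidt orthogonalization: the first axis points to the closest neighbour, the second is the orthogonalized direction to the second-closest neighbour, and their cross product completes a right-handed frame. The function name `local_frame_from_two_neighbors` is hypothetical.

```python
import numpy as np

def local_frame_from_two_neighbors(center, nb1, nb2):
    """Hypothetical sketch: right-handed orthonormal frame at `center` built
    from its closest neighbour nb1 and second-closest neighbour nb2.

    Rows of the returned (3, 3) matrix are the local x, y, z axes.
    Assumes the three points are not collinear (otherwise e2 is undefined).
    """
    v1 = nb1 - center
    v2 = nb2 - center
    # x axis: unit vector towards the closest neighbour
    e1 = v1 / np.linalg.norm(v1)
    # y axis: part of v2 orthogonal to e1 (one Gram-Schmidt step)
    u2 = v2 - np.dot(v2, e1) * e1
    e2 = u2 / np.linalg.norm(u2)
    # z axis: cross product gives a right-handed frame
    e3 = np.cross(e1, e2)
    return np.stack([e1, e2, e3])

# Neighbours along x and in the xy-plane give the identity frame.
frame = local_frame_from_two_neighbors(
    np.zeros(3), np.array([1.0, 0.0, 0.0]), np.array([1.0, 1.0, 0.0])
)
```

Because every axis is derived from relative positions, rotating the molecule by R rotates the frame with it (the frame matrix becomes F Rᵀ), which is what makes quantities expressed in such a frame rotation-invariant.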

aggregate(inputs: Tensor) → Tensor

Aggregates the messages from the neighboring atoms.

Parameters:

inputs (Tensor) – The relative positions of the neighboring atoms. Shape (n_edges, 3), where n_edges = k * n_atom and k is the number of neighbors per atom.

Returns:

The local frames of the atoms in the molecule.

Return type:

Tensor
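A vectorized sketch of what an aggregation step of this kind can do: group the k relative positions of each atom, sort them by distance, and turn the two nearest into a frame. It assumes, hypothetically, that the k messages of each atom are stored contiguously and that k ≥ 2; the actual edge ordering and frame construction used by aggregate() are not documented here.

```python
import numpy as np

def aggregate_frames(inputs, k):
    """Hypothetical sketch: relative positions (n_edges = k * n_atom, 3)
    -> local frames (n_atom, 3, 3), rows being the local axes.

    Assumes each atom's k messages are contiguous and k >= 2.
    """
    rel = inputs.reshape(-1, k, 3)                           # (n_atom, k, 3)
    order = np.argsort(np.linalg.norm(rel, axis=-1), axis=1)
    rel = np.take_along_axis(rel, order[..., None], axis=1)  # nearest first
    v1, v2 = rel[:, 0], rel[:, 1]
    e1 = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)
    u2 = v2 - np.sum(v2 * e1, axis=-1, keepdims=True) * e1   # Gram-Schmidt
    e2 = u2 / np.linalg.norm(u2, axis=-1, keepdims=True)
    e3 = np.cross(e1, e2)
    return np.stack([e1, e2, e3], axis=1)
```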

forward(pos: Tensor, atomic_numbers: Tensor | None = None, batch: Tensor | None = None) → Tensor

Calculates the forward pass of the module.

Parameters:
  • pos (Tensor) – The positions of the atoms. Shape (n_atom, 3).

  • atomic_numbers (Tensor, optional) – The atomic numbers of the atoms in the molecule. Defaults to None.

  • batch (Tensor, optional) – The PyTorch Geometric batch vector. Defaults to None.

Returns:

The local frames of the atoms in the molecule.

Return type:

Tensor
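The two optional inputs of forward() restrict which atoms can serve as neighbors: batch keeps neighbors inside the same molecule, and atomic_numbers lets ignore_hydrogen exclude hydrogens. The brute-force sketch below shows that selection logic under stated assumptions: the helper name `two_nearest_neighbors` is hypothetical, hydrogen is identified as Z == 1, and the real module presumably builds a graph rather than a full distance matrix.

```python
import numpy as np

def two_nearest_neighbors(pos, atomic_numbers=None, batch=None, ignore_hydrogen=True):
    """Hypothetical sketch: indices of the two closest candidate atoms per atom,
    shape (n_atom, 2). O(n_atom**2) brute force, for clarity only.
    If fewer than two valid candidates exist, the extra indices are meaningless."""
    n = pos.shape[0]
    if batch is None:
        batch = np.zeros(n, dtype=int)                # one molecule
    dist = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
    dist[np.arange(n), np.arange(n)] = np.inf         # not its own neighbour
    dist[batch[:, None] != batch[None, :]] = np.inf   # same molecule only
    if ignore_hydrogen and atomic_numbers is not None:
        dist[:, atomic_numbers == 1] = np.inf         # hydrogens are not candidates
    return np.argsort(dist, axis=1)[:, :2]
```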

message(pos_i, pos_j) → Tensor

Returns the relative position of the neighboring atom as a message.

Parameters:
  • pos_i (Tensor) – The position of the atom, from the second row of the edge index.

  • pos_j (Tensor) – The position of the neighboring atom, from the first row of the edge index.

Returns:

The relative position of the neighboring atom.

Return type:

Tensor
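The _i/_j convention above (neighbor j from the first row of the edge index, receiving atom i from the second) amounts to a gather per edge. This sketch assumes the message is pos_j - pos_i; the sign is not stated in the documentation, and the function name is hypothetical.

```python
import numpy as np

def relative_position_messages(pos, edge_index):
    """Sketch: one message per edge, pointing from atom i to neighbour j.
    The sign pos_j - pos_i is an assumption."""
    pos_j = pos[edge_index[0]]   # neighbouring atoms, first row
    pos_i = pos[edge_index[1]]   # receiving atoms, second row
    return pos_j - pos_i         # shape (n_edges, 3)
```

Since only differences of positions enter, the messages are unchanged when the whole molecule is translated.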

class LocalFramesModule

An nn.Module that transforms the coefficients into the local frames of the atoms.

It splits the coefficients by atom and transforms each atom's coefficients into its local frame individually. LocalFramesTransformMatrixSparse can be used to perform the transformation for all atoms in parallel using sparse matrices.
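The per-atom loop can be sketched as follows. This is a simplified stand-in, not the module's code: irreps are modelled as hypothetical (multiplicity, l) pairs with l ∈ {0, 1}, l = 1 components are assumed to be ordered (x, y, z), and l > 1 (which needs Wigner D-matrices) is left out.

```python
import numpy as np

def transform_atom(coeffs, irreps, frame):
    """Hypothetical sketch: express one atom's coefficients in its local frame.
    irreps: list of (mul, l) with l in {0, 1}; frame: (3, 3), rows = local axes."""
    out, start = [], 0
    for mul, l in irreps:
        dim = 2 * l + 1
        block = coeffs[start:start + mul * dim].reshape(mul, dim)
        if l == 0:
            out.append(block)             # scalars are rotation-invariant
        elif l == 1:
            out.append(block @ frame.T)   # project each vector onto the local axes
        else:
            raise NotImplementedError("l > 1 requires Wigner D-matrices")
        start += mul * dim
    return np.concatenate([b.ravel() for b in out])

def transform_all(coeffs_per_atom, irreps_per_atom, frames):
    # One atom at a time, mirroring the per-atom splitting described above.
    return [transform_atom(c, ir, f) for c, ir, f in zip(coeffs_per_atom, irreps_per_atom, frames)]
```

If the molecule is rotated by R, each frame becomes F Rᵀ and each l = 1 vector becomes R v, so the local coefficients F Rᵀ R v = F v do not change.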

__init__() → None

Initializes the module.

forward(coeffs: list[Tensor], irreps: list[Irreps], pos: Tensor, atomic_numbers: Tensor | None = None, batch: Tensor | None = None) → list[Tensor]

Calculates the forward pass of the module.

Parameters:
  • coeffs (list[Tensor]) – Coefficients of the atoms in the molecule.

  • irreps (list[Irreps]) – Irreps of the atoms in the molecule.

  • pos (Tensor) – Positions of the atoms in the molecule.

  • atomic_numbers (Tensor, optional) – The atomic numbers of the atoms in the molecule. Defaults to None.

  • batch (Tensor, optional) – The PyTorch Geometric batch vector. Defaults to None.

Returns:

The coefficients transformed into the local frames.

Return type:

list[Tensor]

class LocalFramesTransformMatrixDense

Module to calculate the (dense) transformation matrix from the standard basis to the local basis.

__init__() → None

Initializes the module.

forward(irreps_per_atom: ndarray[Irreps], pos: Tensor, atomic_numbers: Tensor | None = None, batch: Tensor | None = None, return_lframes: bool = False) → Tensor | tuple[Tensor, Tensor]

Calculates the transformation matrix from the standard basis to the local basis.

Parameters:
  • irreps_per_atom – Irreps of the basis functions per atom. Shape (n_atom,).

  • pos – Positions of the atoms. Shape (n_atom, 3).

  • atomic_numbers (Tensor, optional) – The atomic numbers of the atoms in the molecule. Defaults to None.

  • batch – Batch vector. Shape (n_atom,). Defaults to None.

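  • return_lframes – If True, the local frames are returned as well. Defaults to False.

Returns:

The transformation matrix or, if return_lframes is True, a tuple that also contains the local frames.

Return type:

Tensor | tuple[Tensor, Tensor]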
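A dense transformation matrix of this kind is block-diagonal: each atom contributes one square block acting on its own basis functions, so multiplying the concatenated coefficient vector by the matrix transforms every atom at once. The sketch below assumes the same simplified irreps as before (hypothetical (mul, l) pairs with l ∈ {0, 1}, l = 1 ordered (x, y, z)); all helper names are hypothetical.

```python
import numpy as np

def block_diag(*blocks):
    """Place square blocks along the diagonal of a dense zero matrix."""
    n = sum(b.shape[0] for b in blocks)
    out = np.zeros((n, n))
    i = 0
    for b in blocks:
        k = b.shape[0]
        out[i:i + k, i:i + k] = b
        i += k
    return out

def atom_block(irreps, frame):
    """Per-atom block for simplified irreps [(mul, l)], l in {0, 1}.
    Assumption: l = 1 components are ordered (x, y, z)."""
    blocks = []
    for mul, l in irreps:
        if l == 0:
            blocks += [np.eye(1)] * mul   # scalars are left unchanged
        elif l == 1:
            blocks += [frame] * mul       # vectors are projected onto local axes
        else:
            raise NotImplementedError("l > 1 requires Wigner D-matrices")
    return block_diag(*blocks)

def dense_transform_matrix(irreps_per_atom, frames):
    """(n_basis, n_basis) matrix with one diagonal block per atom."""
    return block_diag(*[atom_block(ir, f) for ir, f in zip(irreps_per_atom, frames)])
```

The dense form is simple but stores n_basis² entries, almost all of them zero, which is the motivation for a sparse variant.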

sample_forward(sample: OFData, return_lframes: bool = False) → Tensor | tuple[Tensor, Tensor]

Wrapper around forward that takes an OFData object instead of the individual arguments.

Parameters:
  • sample – The sample.

  • return_lframes – If True, the local frames are returned as well. Defaults to False.

class LocalFramesTransformMatrixSparse

Module to calculate the (sparse) transformation matrix from the standard basis to the local basis.

__init__() → None

Initializes the module.

forward(n_basis: int, irreps_per_atom: ndarray[Irreps], pos: Tensor, atom_coo_indices: Tensor, atomic_numbers: Tensor | None = None, batch: Tensor | None = None, return_lframes: bool = False) → Tensor | tuple[Tensor, Tensor]

Calculates the transformation matrix from the standard basis to the local basis.

Parameters:
  • n_basis – Total number of basis functions in the molecule.

  • irreps_per_atom – Irreps of the basis functions per atom. Shape (n_atom,).

  • pos – Positions of the atoms. Shape (n_atom, 3).

  • atom_coo_indices – Indices that can be used to construct a per-atom block-diagonal sparse COOrdinate matrix, as returned by add_atom_coo_indices(). Shape (2, n_basis).

  • atomic_numbers (Tensor, optional) – The atomic numbers of the atoms in the molecule. Defaults to None.

  • batch – Batch vector. Shape (n_atom,). Defaults to None.

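  • return_lframes – If True, the local frames are returned as well. Defaults to False.

Returns:

The transformation matrix or, if return_lframes is True, a tuple that also contains the local frames.

Return type:

Tensor | tuple[Tensor, Tensor]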
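In coordinate (COO) format, a block-diagonal matrix is stored as (row, column, value) triplets for the entries inside the per-atom blocks only, so memory grows with the sum of squared block sizes instead of n_basis². The sketch below builds such triplets and a matrix-vector product from them. It is a hypothetical illustration of the storage idea and does not reproduce the index layout returned by add_atom_coo_indices().

```python
import numpy as np

def block_diag_coo(blocks):
    """Hypothetical sketch: COO triplets of a block-diagonal matrix built from
    square per-atom blocks. Returns (rows, cols, vals, n_basis)."""
    rows, cols, vals = [], [], []
    offset = 0
    for b in blocks:
        k = b.shape[0]
        r, c = np.meshgrid(np.arange(k), np.arange(k), indexing="ij")
        rows.append(r.ravel() + offset)   # global row indices of this block
        cols.append(c.ravel() + offset)   # global column indices of this block
        vals.append(b.ravel())            # row-major values, matching r and c
        offset += k
    return np.concatenate(rows), np.concatenate(cols), np.concatenate(vals), offset

def coo_matvec(rows, cols, vals, x):
    # y[r] += v * x[c] for each stored entry; np.add.at accumulates repeated rows.
    y = np.zeros(x.shape[0], dtype=np.result_type(vals, x))
    np.add.at(y, rows, vals * x[cols])
    return y
```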

sample_forward(sample: OFData, return_lframes: bool = False) → Tensor | tuple[Tensor, Tensor]

Wrapper around forward that takes an OFData object instead of the individual arguments.

Parameters:
  • sample – The sample.

  • return_lframes – If True, the local frames are returned as well. Defaults to False.