local_frames_module
Calculates the local frames of the atoms in the molecule.
- class LocalBasisModule(ignore_hydrogen: bool = True, use_three_atoms_for_basis: bool = False)
A PyTorch Geometric MessagePassing module that calculates the local basis on each node of the graph.
- __init__(ignore_hydrogen: bool = True, use_three_atoms_for_basis: bool = False) → None
Initializes the MessagePassing module. The module calculates the local basis of each atom in the molecule from the atom positions on a graph.
- Parameters:
ignore_hydrogen (bool, optional) – If True, the two closest heavy atoms are used to construct the local bases. Defaults to True.
use_three_atoms_for_basis (bool, optional) – If True, the three closest atoms are used to construct the basis. Defaults to False.
- aggregate(inputs: Tensor) → Tensor
Aggregates the messages from the neighboring atoms.
- Parameters:
inputs (Tensor) – The relative positions of the neighboring atoms. Shape: (n_edges=k*n_atom, 3).
- Returns:
The local frames of the atoms in the molecule.
- Return type:
Tensor
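The documented input shape (n_edges=k*n_atom, 3) implies that every atom receives exactly k messages. The sketch below shows the grouping step that aggregation then needs. It assumes, hypothetically, that the messages are ordered by target atom, so that rows i*k to i*k + k - 1 belong to atom i. It also sorts each group by distance so that the closest neighbors come first. The helper `group_messages` is illustrative and not part of the module.

```python
Vec = tuple[float, float, float]


def group_messages(inputs: list[Vec], k: int) -> list[list[Vec]]:
    """Hypothetical helper: split k*n_atom messages into per-atom groups.

    Assumes rows i*k .. i*k + k - 1 belong to atom i. Each group is returned
    sorted by squared distance, closest neighbor first.
    """
    if k <= 0 or len(inputs) % k != 0:
        raise ValueError("number of messages must be a multiple of a positive k")
    groups = []
    for start in range(0, len(inputs), k):
        group = inputs[start:start + k]
        groups.append(sorted(group, key=lambda r: r[0] ** 2 + r[1] ** 2 + r[2] ** 2))
    return groups


# Two atoms with k = 2 neighbors each.
msgs = [(0.0, 2.0, 0.0), (1.0, 0.0, 0.0), (0.0, 0.0, -0.5), (3.0, 0.0, 0.0)]
print(group_messages(msgs, k=2))
# [[(1.0, 0.0, 0.0), (0.0, 2.0, 0.0)], [(0.0, 0.0, -0.5), (3.0, 0.0, 0.0)]]
```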
- forward(pos: Tensor, atomic_numbers: Tensor | None = None, batch: Tensor | None = None) → Tensor
Calculates the forward pass of the module.
- Parameters:
pos (Tensor) – The positions of the atoms. Shape (n_atom, 3).
atomic_numbers (Tensor, optional) – The atomic numbers of the atoms in the molecule. Defaults to None.
batch (Tensor, optional) – The PyTorch Geometric batch vector. Defaults to None.
- Returns:
The local frames of the atoms in the molecule.
- Return type:
Tensor
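The forward pass needs a graph over the atoms before any messages can be passed. One plausible construction is a k-nearest-neighbor graph restricted to atoms of the same molecule (same batch entry), with hydrogens (atomic number 1) skipped as neighbors when ignore_hydrogen is set. This is sketched here as an assumption, not as the module's actual implementation. The helper `knn_edges` is hypothetical. Its edge_index layout (row 0 holds the neighbor, row 1 the central atom) follows the pos_j / pos_i description of message().

```python
from __future__ import annotations

Vec = tuple[float, float, float]


def knn_edges(
    pos: list[Vec],
    k: int,
    atomic_numbers: list[int] | None = None,
    batch: list[int] | None = None,
    ignore_hydrogen: bool = True,
) -> list[list[int]]:
    """Hypothetical helper: k-nearest-neighbor edge_index as [sources, targets].

    For every atom (target), the k closest other atoms (sources) of the same
    molecule are selected; hydrogens are skipped as neighbors if requested.
    """
    n = len(pos)
    batch = batch if batch is not None else [0] * n  # single molecule by default
    sources: list[int] = []
    targets: list[int] = []
    for i in range(n):
        candidates = []
        for j in range(n):
            if j == i or batch[j] != batch[i]:
                continue
            if ignore_hydrogen and atomic_numbers is not None and atomic_numbers[j] == 1:
                continue
            dist2 = sum((pos[j][c] - pos[i][c]) ** 2 for c in range(3))
            candidates.append((dist2, j))
        if len(candidates) < k:
            raise ValueError(f"atom {i} has fewer than {k} eligible neighbors")
        for _, j in sorted(candidates)[:k]:
            sources.append(j)
            targets.append(i)
    return [sources, targets]


# C, O, H, C in one molecule; the hydrogen (index 2) is never used as a neighbor.
pos = [(0.0, 0.0, 0.0), (1.4, 0.0, 0.0), (-0.6, 0.9, 0.0), (3.0, 0.0, 0.0)]
atomic_numbers = [6, 8, 1, 6]
print(knn_edges(pos, k=2, atomic_numbers=atomic_numbers))
# [[1, 3, 0, 3, 0, 1, 1, 0], [0, 0, 1, 1, 2, 2, 3, 3]]
```

In this sketch the hydrogen still receives a frame built from its heavy-atom neighbors. It is only excluded as a neighbor of other atoms.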
- message(pos_i, pos_j) → Tensor
Returns the relative positions of the neighboring atoms as messages.
- Parameters:
pos_i (Tensor) – The position of the atom, taken from the second row of the edge index.
pos_j (Tensor) – The position of the neighboring atom, taken from the first row of the edge index.
- Returns:
The relative position of the neighboring atom.
- Return type:
Tensor
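PyTorch Geometric's default source-to-target convention, described above, takes pos_j from the first row of the edge index and pos_i from the second. Under that convention, computing the messages for all edges reduces to a gather followed by a subtraction. The sketch below assumes the sign convention pos_j - pos_i, that is, the neighbor's position as seen from the central atom. `relative_positions` is a hypothetical helper.

```python
Vec = tuple[float, float, float]


def relative_positions(pos: list[Vec], edge_index: list[list[int]]) -> list[Vec]:
    """Hypothetical helper mirroring message(pos_i, pos_j) on every edge."""
    sources, targets = edge_index
    messages = []
    for j, i in zip(sources, targets):
        pos_j, pos_i = pos[j], pos[i]  # neighbor (row 0) and central atom (row 1)
        messages.append(tuple(pj - pi for pj, pi in zip(pos_j, pos_i)))
    return messages


pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0)]
edge_index = [[1, 2, 0], [0, 0, 1]]
print(relative_positions(pos, edge_index))
# [(1.0, 0.0, 0.0), (0.0, 2.0, 0.0), (-1.0, 0.0, 0.0)]
```

On tensors, the same computation is pos[edge_index[0]] - pos[edge_index[1]].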
- class LocalFramesModule
An nn.Module that transforms the coefficients into the local frames.
It splits the coefficients by atom and transforms each atom's coefficients into its local frame individually.
LocalFramesTransformMatrix can be used to do this in parallel using sparse matrices.
- forward(coeffs: list[Tensor], irreps: list[Irreps], pos: Tensor, atomic_numbers: Tensor | None = None, batch: Tensor | None = None) → list[Tensor]
Calculates the forward pass of the module.
- Parameters:
coeffs (list[Tensor]) – The coefficients of the atoms in the molecule.
irreps (list[Irreps]) – The irreps of the atoms in the molecule.
pos (Tensor) – The positions of the atoms in the molecule.
atomic_numbers (Tensor, optional) – The atomic numbers of the atoms in the molecule. Defaults to None.
batch (Tensor, optional) – The PyTorch Geometric batch vector. Defaults to None.
- Returns:
The coefficients transformed into the local frames.
- Return type:
list[Tensor]
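The per-atom loop described for LocalFramesModule can be sketched as follows. For illustration only, the sketch assumes that each atom's irreps contain just two kinds of block. l = 0 blocks are left unchanged. l = 1 blocks are stored in (x, y, z) order and multiplied by the 3x3 frame matrix, whose rows are the local axes. Real Irreps objects, other component orderings, and l > 1 (which requires Wigner D-matrices) are not handled. `transform_atom` and `transform_all` are hypothetical helpers.

```python
Matrix = list[list[float]]


def transform_atom(coeffs: list[float], ls: list[int], frame: Matrix) -> list[float]:
    """Hypothetical helper: express one atom's coefficients in its local frame.

    `ls` lists the angular momentum of each irrep block in storage order.
    Only l = 0 (unchanged) and l = 1 in (x, y, z) order (multiplied by the
    frame matrix, whose rows are the local axes) are supported.
    """
    out: list[float] = []
    offset = 0
    for ell in ls:
        if ell == 0:
            out.append(coeffs[offset])
            offset += 1
        elif ell == 1:
            v = coeffs[offset:offset + 3]
            if len(v) != 3:
                raise ValueError("not enough coefficients for an l = 1 block")
            out.extend(sum(frame[r][c] * v[c] for c in range(3)) for r in range(3))
            offset += 3
        else:
            raise NotImplementedError("l > 1 requires Wigner D-matrices")
    if offset != len(coeffs):
        raise ValueError("irreps do not match the number of coefficients")
    return out


def transform_all(
    coeffs_per_atom: list[list[float]],
    ls_per_atom: list[list[int]],
    frames: list[Matrix],
) -> list[list[float]]:
    """Transform atom by atom, mirroring the per-atom loop described above."""
    return [transform_atom(c, ls, f) for c, ls, f in zip(coeffs_per_atom, ls_per_atom, frames)]


# Local axes are the global axes rotated by 90 degrees about z.
rot_z90 = [[0.0, 1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
print(transform_all([[0.5, 1.0, 0.0, 0.0]], [[0, 1]], [rot_z90]))  # [[0.5, 0.0, -1.0, 0.0]]
```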
- class LocalFramesTransformMatrixDense
Module to calculate the (dense) transformation matrix from the standard basis to the local basis.
- forward(irreps_per_atom: ndarray[Irreps], pos: Tensor, atomic_numbers: Tensor | None = None, batch: Tensor | None = None, return_lframes: bool = False) → Tensor | tuple[Tensor, Tensor]
Calculates the transformation matrix from the standard basis to the local basis.
- Parameters:
irreps_per_atom – Irreps of the basis functions per atom. Shape (n_atom,).
pos – Positions of the atoms. Shape (n_atom, 3).
atomic_numbers (Tensor, optional) – The atomic numbers of the atoms in the molecule. Defaults to None.
batch – Batch vector. Shape (n_atom,). Defaults to None.
return_lframes – If True, the local frames are returned as well. Defaults to False.
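Each atom's coefficients are transformed independently, so the matrix for the whole molecule is block diagonal with one block per atom. The sparse variant stores the same structure in COO form. The sketch below shows the dense assembly step and assumes the per-atom blocks have already been computed from the local frames. `block_diagonal` and `matvec` are hypothetical helpers.

```python
Matrix = list[list[float]]


def block_diagonal(blocks: list[Matrix]) -> Matrix:
    """Hypothetical helper: place square per-atom blocks on the diagonal."""
    n_basis = sum(len(b) for b in blocks)
    out = [[0.0] * n_basis for _ in range(n_basis)]
    offset = 0
    for b in blocks:
        size = len(b)
        for r in range(size):
            if len(b[r]) != size:
                raise ValueError("each per-atom block must be square")
            for c in range(size):
                out[offset + r][offset + c] = b[r][c]
        offset += size
    return out


def matvec(m: Matrix, v: list[float]) -> list[float]:
    """Dense matrix-vector product."""
    return [sum(m_rc * v_c for m_rc, v_c in zip(row, v)) for row in m]


# Atom 0: one l = 0 coefficient. Atom 1: an l = 1 block whose local axes are
# the global axes rotated by 90 degrees about z.
rot_z90 = [[0.0, 1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
T = block_diagonal([[[1.0]], rot_z90])
print(T)  # [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, -1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
print(matvec(T, [2.0, 1.0, 0.0, 0.0]))  # [2.0, 0.0, -1.0, 0.0]
```

The dense matrix stores all n_basis x n_basis entries, even though only the diagonal blocks can be nonzero. A sparse representation avoids that cost.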
- class LocalFramesTransformMatrixSparse
Module to calculate the (sparse) transformation matrix from the standard basis to the local basis.
- forward(n_basis: int, irreps_per_atom: ndarray[Irreps], pos: Tensor, atom_coo_indices: Tensor, atomic_numbers: Tensor | None = None, batch: Tensor | None = None, return_lframes: bool = False) → Tensor | tuple[Tensor, Tensor]
Calculates the transformation matrix from the standard basis to the local basis.
- Parameters:
n_basis – Total number of basis functions in the molecule.
irreps_per_atom – Irreps of the basis functions per atom. Shape (n_atom,).
pos – Positions of the atoms. Shape (n_atom, 3).
atom_coo_indices – Indices that can be used to construct a per-atom block-diagonal sparse coordinate (COO) matrix, as returned by add_atom_coo_indices(). Shape (2, n_basis).
atomic_numbers (Tensor, optional) – The atomic numbers of the atoms in the molecule. Defaults to None.
batch – Batch vector. Shape (n_atom,). Defaults to None.
return_lframes – If True, the local frames are returned as well. Defaults to False.
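In COO format, the block-diagonal matrix is stored as parallel lists of row indices, column indices and values, with one entry for every element of every per-atom block. The sketch below illustrates this layout with the hypothetical helpers `block_diagonal_coo` and `coo_matvec`. Its index lists are not claimed to match what add_atom_coo_indices() returns.

```python
Matrix = list[list[float]]


def block_diagonal_coo(blocks: list[Matrix]) -> tuple[list[int], list[int], list[float]]:
    """Hypothetical helper: COO (row, col, value) lists of a block-diagonal matrix."""
    rows: list[int] = []
    cols: list[int] = []
    vals: list[float] = []
    offset = 0
    for b in blocks:
        size = len(b)
        for r in range(size):
            for c in range(size):
                rows.append(offset + r)
                cols.append(offset + c)
                vals.append(b[r][c])
        offset += size
    return rows, cols, vals


def coo_matvec(n: int, rows: list[int], cols: list[int], vals: list[float], v: list[float]) -> list[float]:
    """Sparse matrix-vector product that touches only the stored entries."""
    out = [0.0] * n
    for r, c, x in zip(rows, cols, vals):
        out[r] += x * v[c]
    return out


rot_z90 = [[0.0, 1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
rows, cols, vals = block_diagonal_coo([[[1.0]], rot_z90])
print(rows)  # [0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
print(cols)  # [0, 1, 2, 3, 1, 2, 3, 1, 2, 3]
print(coo_matvec(4, rows, cols, vals, [2.0, 1.0, 0.0, 0.0]))  # [2.0, 0.0, -1.0, 0.0]
```

In PyTorch, such lists can be turned into a sparse tensor with torch.sparse_coo_tensor(indices, values, size). Here indices stacks the row and column lists into shape (2, nnz).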
- sample_forward(sample: OFData, return_lframes: bool = False) → Tensor | tuple[Tensor, Tensor]
Wrapper around forward that takes an OFData object instead of the individual arguments.
- Parameters:
sample – The OFData sample from which the arguments of forward are taken.
return_lframes – If True, the local frames are returned as well. Defaults to False.