Source code for mldft.ml.models.components.local_frames_module
"""Calculates the local frames of the atoms in the molecule."""
import numpy as np
import torch
from e3nn.o3 import Irreps
from torch import Tensor, nn
from torch_geometric.nn import MessagePassing, knn_graph
from mldft.ml.data.components.of_data import OFData
from mldft.utils.local_frames import (
local_frames_from_rel_positions,
pyscf_to_e3nn_local_frames_matrix,
transform_coeffs_to_local,
)
[docs]
class LocalBasisModule(MessagePassing):
    """MessagePassing module which calculates the local basis (frame) on each node of the graph.

    Every node receives exactly ``self.k`` messages: the positions of its ``k`` nearest
    neighbors relative to itself. These are aggregated per node into a local frame via
    :func:`local_frames_from_rel_positions`.
    """

    def __init__(
        self, ignore_hydrogen: bool = True, use_three_atoms_for_basis: bool = False
    ) -> None:
        """Initialize the module.

        Args:
            ignore_hydrogen (bool, optional): If true, only heavy atoms (atomic number != 1) are
                used as neighbors, i.e. the closest heavy atoms are used to construct the local
                bases. Requires ``atomic_numbers`` in :meth:`forward` and does not support
                batching. Defaults to True.
            use_three_atoms_for_basis (bool, optional): If true, the three closest atoms are used
                to construct the basis, otherwise the two closest. Defaults to False.
        """
        super().__init__()
        self.ignore_hydrogen = ignore_hydrogen
        # Scale of the random positions of dummy atoms. Dummy atoms are appended when a molecule
        # has too few (heavy) atoms to provide k neighbors for every atom.
        self.dummy_distance = 1e6
        # Number of neighbors used to construct each local frame (3 or 2).
        self.k = 3 if use_three_atoms_for_basis else 2

    def aggregate(self, inputs: Tensor) -> Tensor:
        """Aggregates the messages from the neighboring atoms into one local frame per node.

        This relies on the edges being sorted by their center (target) node and on every node
        having exactly ``self.k`` incoming edges, so that the messages can simply be reshaped
        to ``(n_nodes, k, 3)``. :meth:`forward` checks both.

        Args:
            inputs (Tensor): The relative positions of the neighboring atoms. Shape: (n_edges=k*n_atom, 3)

        Returns:
            Tensor: The local frames of the atoms, as returned by
            :func:`local_frames_from_rel_positions`.
        """
        # Reshape to [num_nodes, k, 3]
        inputs = inputs.view(-1, self.k, 3)
        length = torch.norm(inputs, dim=-1)
        if self.k == 3:
            # Rank every neighbor by its distance (rank 0 = closest). Each row contains every rank
            # exactly once, so the boolean masks select exactly one neighbor per node, in node order.
            sorted_indices = torch.argsort(length, dim=-1)
            sorted_indices_inv = torch.argsort(sorted_indices, dim=-1)
            pos_1 = inputs[sorted_indices_inv == 0, :]
            pos_2 = inputs[sorted_indices_inv == 1, :]
            pos_3 = inputs[sorted_indices_inv == 2, :]
            local_basis = local_frames_from_rel_positions(pos_1, pos_2, pos_3)
        else:
            assert self.k == 2
            # Sort the two neighbors by distance.
            # Use <= for the edge case of two neighbors with equal distance.
            diff = length[:, 0] - length[:, 1]
            pos_1 = torch.where(diff[:, None] <= 0, inputs[:, 0], inputs[:, 1])
            pos_2 = torch.where(diff[:, None] <= 0, inputs[:, 1], inputs[:, 0])
            local_basis = local_frames_from_rel_positions(pos_1, pos_2)
        return local_basis

    def _append_dummy_atoms(self, pos: Tensor, n_dummy_atoms: int) -> Tensor:
        """Return ``pos`` with ``n_dummy_atoms`` dummy atoms appended.

        Dummy positions are drawn as ``self.dummy_distance * randn``, i.e. random and (almost
        surely) far away from the molecule. If ``n_dummy_atoms`` is 0, ``pos`` is returned as is.
        """
        if n_dummy_atoms <= 0:
            return pos
        assert (
            2 * torch.linalg.norm(pos, dim=-1).max() < self.dummy_distance
        ), "The dummy distance might be too small for the molecule"
        dummy_pos = self.dummy_distance * torch.randn(
            [n_dummy_atoms] + list(pos.shape[1:]),
            dtype=pos.dtype,
            layout=pos.layout,
            device=pos.device,
        )
        return torch.cat([pos, dummy_pos])

    def _heavy_atom_graph(
        self, pos: Tensor, atomic_numbers: Tensor
    ) -> tuple[Tensor, Tensor, int]:
        """Build the graph connecting every atom to its k closest heavy atoms (excluding itself).

        Args:
            pos (Tensor): The positions of the atoms.
            atomic_numbers (Tensor): The atomic numbers of the atoms.

        Returns:
            tuple: The positions with dummy atoms appended, the edge index sorted by center atom
            (row 0: neighbors, row 1: centers), and the number of appended dummy atoms.
        """
        # Keep the mask on the device of pos. All index tensors below derive from it, and
        # mixing devices breaks the indexing / stacking on GPU.
        heavy_atom_mask = (atomic_numbers != 1).to(pos.device)
        n_heavy_atoms = int(heavy_atom_mask.sum())
        # Each heavy atom needs k heavy neighbors besides itself, i.e. at least k + 1 heavy atoms.
        n_dummy_atoms = max(self.k + 1 - n_heavy_atoms, 0)
        # Dummy atoms count as heavy atoms.
        heavy_atom_mask = torch.cat(
            [
                heavy_atom_mask,
                torch.ones(
                    n_dummy_atoms,
                    dtype=heavy_atom_mask.dtype,
                    device=heavy_atom_mask.device,
                ),
            ]
        )
        pos = self._append_dummy_atoms(pos, n_dummy_atoms)
        # Indices of the heavy atoms, shape (n_heavy, 1).
        heavy_atom_ind = torch.argwhere(heavy_atom_mask)
        # Distance matrix between all atoms and the heavy atoms, shape (n_atoms, n_heavy).
        dist_mat = torch.cdist(pos, pos[heavy_atom_mask])
        # For every atom, the heavy atom indices sorted by distance, shape (n_atoms, n_heavy, 1).
        dist_order = heavy_atom_ind[torch.argsort(dist_mat, dim=1, descending=False)]
        assert torch.all(dist_order[heavy_atom_mask, 0].eq(heavy_atom_ind)), (
            f"the closest heavy atom is not itself: "
            f"{dist_order[heavy_atom_mask, 0]} != {heavy_atom_ind}"
        )
        edge_index_neighbors = torch.cat(
            [
                # take the k closest heavy atoms for each hydrogen
                dist_order[~heavy_atom_mask, : self.k].flatten(),
                # take the second to k+1 closest heavy atoms for each heavy atom, as the closest one is itself
                dist_order[heavy_atom_mask, 1 : self.k + 1].flatten(),
            ],
            dim=0,
        )
        atom_ind = torch.arange(pos.shape[0], device=pos.device)
        edge_index_centers = torch.cat(
            [
                # the hydrogen atoms, each repeated k times
                atom_ind[~heavy_atom_mask].repeat_interleave(self.k),
                # the heavy atoms, each repeated k times
                atom_ind[heavy_atom_mask].repeat_interleave(self.k),
            ],
            dim=0,
        )
        edge_index = torch.stack([edge_index_neighbors, edge_index_centers], dim=0)
        # Sort the edge index by the center atoms; this is necessary for the aggregation!
        edge_index = edge_index[:, torch.argsort(edge_index[1], stable=True)]
        assert torch.all(edge_index[1, 1:] >= edge_index[1, :-1])
        return pos, edge_index, n_dummy_atoms

    def forward(
        self,
        pos: Tensor,
        atomic_numbers: Tensor | None = None,
        batch: Tensor | None = None,
    ) -> Tensor:
        """Calculates the forward pass of the module.

        Args:
            pos (Tensor): The positions of the atoms.
            atomic_numbers (Tensor, optional): The atomic numbers of the atoms in the molecule.
                Required if ``ignore_hydrogen`` is True.
            batch (Tensor, optional): The pytorch geometric batch. Not supported if
                ``ignore_hydrogen`` is True. Defaults to None.

        Returns:
            Tensor: The local frames of the atoms in the molecule (dummy atoms removed).
        """
        if not self.ignore_hydrogen:
            # If there are not enough atoms (one more than the number of neighbors), add dummy atoms.
            # NOTE(review): this only looks at the total number of atoms, and ``batch`` is not
            # extended for the dummy atoms, so small molecules inside a batch are not handled here.
            n_dummy_atoms = max(self.k + 1 - len(pos), 0)
            pos = self._append_dummy_atoms(pos, n_dummy_atoms)
            # construct a simple knn graph
            edge_index = knn_graph(pos, self.k, batch, loop=False, flow=self.flow)
        else:
            # These checks must happen before atomic_numbers is used.
            assert (
                atomic_numbers is not None
            ), "atomic_numbers must be provided if ignore_hydrogen is True"
            assert batch is None, "batching is not supported if ignore_hydrogen is True"
            pos, edge_index, n_dummy_atoms = self._heavy_atom_graph(pos, atomic_numbers)
        # aggregate() reshapes the messages to (n_nodes, k, 3). Without exactly k edges per node,
        # neighbors of different atoms would be silently mixed up.
        assert edge_index.shape[1] == self.k * pos.shape[0], (
            f"expected {self.k} neighbors per atom, got {edge_index.shape[1]} edges "
            f"for {pos.shape[0]} atoms"
        )
        # propagate the messages through the graph
        result = self.propagate(edge_index, pos=pos)
        if n_dummy_atoms > 0:
            # remove the dummy atoms from the result
            result = result[:-n_dummy_atoms, :, :]
        return result

    def message(self, pos_i, pos_j) -> Tensor:
        """Return the relative position of the neighboring atoms as a message.

        Args:
            pos_i (Tensor): The position of the atom, from the second row of the edge index.
            pos_j (Tensor): The position of the neighboring atom, from the first row of the edge index.

        Returns:
            Tensor: The relative position of the neighboring atom.
        """
        return pos_j - pos_i
[docs]
class LocalFramesModule(nn.Module):
    """Transform per-atom coefficients into the local frame of each atom, one atom at a time.

    For a vectorized alternative based on a (sparse) transformation matrix, see
    :class:`LocalFramesTransformMatrixSparse`.
    """

    def __init__(self) -> None:
        """Create the submodule that computes the local frames (default settings)."""
        super().__init__()
        self.local_basis_module = LocalBasisModule()

    def forward(
        self,
        coeffs: list[Tensor],
        irreps: list[Irreps],
        pos: Tensor,
        atomic_numbers: Tensor | None = None,
        batch: Tensor | None = None,
    ) -> list[Tensor]:
        """Transform every atom's coefficients into that atom's local frame.

        Note that ``coeffs`` is modified in place: entry ``i`` is replaced by the transformed
        coefficients of atom ``i``. The same list object is returned.

        Args:
            coeffs (list[Tensor]): Per-atom coefficients of the molecule.
            irreps (list[Irreps]): Per-atom irreps of the molecule.
            pos (Tensor): Positions of the atoms in the molecule.
            atomic_numbers (Tensor, optional): Atomic numbers of the atoms. The default
                :class:`LocalBasisModule` ignores hydrogen and therefore requires them.
            batch (Tensor, optional): The pytorch geometric batch. Defaults to None.

        Returns:
            list[Tensor]: ``coeffs``, now holding the coefficients in the local frames.
        """
        lframes = self.local_basis_module.forward(pos, atomic_numbers, batch)
        n_atoms = pos.shape[0]
        for atom_idx in range(n_atoms):
            atom_coeffs = coeffs[atom_idx]
            coeffs[atom_idx] = transform_coeffs_to_local(
                atom_coeffs, irreps[atom_idx], lframes[atom_idx]
            )
        return coeffs
[docs]
class LocalFramesTransformMatrixSparse(nn.Module):
    """Build the sparse (COO) matrix mapping coefficients from the standard basis to the local
    basis of every atom."""

    def __init__(self) -> None:
        """Create the submodule that computes the local frames (default settings)."""
        super().__init__()
        self.local_basis_module = LocalBasisModule()

    def forward(
        self,
        n_basis: int,
        irreps_per_atom: np.ndarray[Irreps],
        pos: Tensor,
        atom_coo_indices: Tensor,
        atomic_numbers: Tensor | None = None,
        batch: Tensor | None = None,
        return_lframes: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Assemble the block-diagonal sparse transformation matrix to the local basis.

        Args:
            n_basis: Total number of basis functions in the molecule.
            irreps_per_atom: Irreps of the basis functions per atom. Shape (n_atom,).
            pos: Positions of the atoms. Shape (n_atom, 3).
            atom_coo_indices: Indices that can be used to construct a per-atom block-diagonal sparse COOrdinate matrix,
                as returned by :func:`~mldft.ml.data.components.convert_transforms.add_atom_coo_indices`. Shape (2, n_basis).
            atomic_numbers (Tensor): The atomic numbers of the atoms in the molecule. The default
                :class:`LocalBasisModule` ignores hydrogen and therefore requires them.
            batch: Batch vector. Shape (n_atom,). Defaults to None.
            return_lframes: If True, the local frames are returned as well. Defaults to False.

        Returns:
            The sparse ``(n_basis, n_basis)`` matrix, or ``(matrix, local_frames)`` if
            ``return_lframes`` is True.
        """
        lframes = self.local_basis_module.forward(pos, atomic_numbers, batch)
        # One dense block per atom, flattened and concatenated in the order of atom_coo_indices.
        # Smaller per-tensor-field blocks would be possible and more efficient.
        values = torch.cat(
            [
                pyscf_to_e3nn_local_frames_matrix(basis, irreps).to(pos.device).flatten()
                for irreps, basis in zip(irreps_per_atom, lframes)
            ]
        )
        # The per-atom blocks are larger than necessary, so many entries are exactly zero;
        # those are dropped from the sparse matrix.
        keep = values != 0
        matrix = torch.sparse_coo_tensor(
            indices=atom_coo_indices[:, keep],
            values=values[keep],
            size=(n_basis, n_basis),
            is_coalesced=True,
        )
        return (matrix, lframes) if return_lframes else matrix

    def sample_forward(
        self, sample: OFData, return_lframes: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Call :meth:`forward` with the fields taken from an ``OFData`` sample.

        Args:
            sample: The sample.
            return_lframes: If True, the local frames are returned as well. Defaults to False.
        """
        return self.forward(
            sample.n_basis,
            sample.irreps_per_atom,
            sample.pos,
            sample.atom_coo_indices,
            sample.atomic_numbers,
            sample.batch,
            return_lframes=return_lframes,
        )
[docs]
class LocalFramesTransformMatrixDense(nn.Module):
    """Build the dense block-diagonal matrix mapping coefficients from the standard basis to
    the local basis of every atom."""

    def __init__(self) -> None:
        """Create the submodule that computes the local frames (default settings)."""
        super().__init__()
        self.local_basis_module = LocalBasisModule()

    def forward(
        self,
        irreps_per_atom: np.ndarray[Irreps],
        pos: Tensor,
        atomic_numbers: Tensor | None = None,
        batch: Tensor | None = None,
        return_lframes: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Assemble the dense block-diagonal transformation matrix to the local basis.

        Args:
            irreps_per_atom: Irreps of the basis functions per atom. Shape (n_atom,).
            pos: Positions of the atoms. Shape (n_atom, 3).
            atomic_numbers (Tensor): The atomic numbers of the atoms in the molecule. The default
                :class:`LocalBasisModule` ignores hydrogen and therefore requires them.
            batch: Batch vector. Shape (n_atom,). Defaults to None.
            return_lframes: If True, the local frames are returned as well. Defaults to False.

        Returns:
            The dense matrix, or ``(matrix, local_frames)`` if ``return_lframes`` is True.
        """
        lframes = self.local_basis_module.forward(pos, atomic_numbers, batch)
        # One block per atom; smaller per-tensor-field blocks would be possible here as well.
        per_atom_blocks = [
            pyscf_to_e3nn_local_frames_matrix(basis, irreps).to(pos.device)
            for irreps, basis in zip(irreps_per_atom, lframes)
        ]
        matrix = torch.block_diag(*per_atom_blocks)
        return (matrix, lframes) if return_lframes else matrix

    def sample_forward(
        self, sample: OFData, return_lframes: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Call :meth:`forward` with the fields taken from an ``OFData`` sample.

        Args:
            sample: The sample.
            return_lframes: If True, the local frames are returned as well. Defaults to False.
        """
        return self.forward(
            sample.irreps_per_atom,
            sample.pos,
            sample.atomic_numbers,
            sample.batch,
            return_lframes=return_lframes,
        )