Source code for mldft.ml.data.components.basis_transforms
"""Transforms that act on :class:`OFData` objects.
The idea is to apply a subset of transforms as part of the loader, i.e. execute them per-molecule
on the cpu, before batches of multiple molecules are assembled and moved to the GPU.
"""
from typing import Literal
import numpy as np
import torch
from loguru import logger
from omegaconf import DictConfig
from mldft.ml.data.components.of_data import OFData, Representation
from mldft.ml.models.components.local_frames_module import (
LocalBasisModule,
LocalFramesTransformMatrixDense,
LocalFramesTransformMatrixSparse,
)
from mldft.ml.models.components.natural_reparametrization import (
natural_reparametrization_matrices_torch,
)
from mldft.utils import RankedLogger
# Module-level logger for this file.
# NOTE(review): ``rank_zero_only=True`` presumably restricts output to the rank-0 process in
# distributed runs -- confirm against ``RankedLogger``.
log = RankedLogger(__name__, rank_zero_only=True)
[docs]
def transform_tensor(
    tensor: torch.Tensor,
    transformation_matrix: torch.Tensor,
    inv_transformation_matrix: torch.Tensor,
    representation: Representation,
):
    r"""Apply a change of basis to ``tensor`` following the rule of its ``representation``.

    Writing :math:`T` for ``transformation_matrix`` and :math:`T^{-1}` for
    ``inv_transformation_matrix``, the members of
    :class:`~mldft.ml.data.components.of_data.Representation` are handled as follows:

    - :class:`Representation.NONE` and :class:`Representation.SCALAR`: the tensor is returned
      unchanged.
    - :class:`Representation.VECTOR`, e.g. the density coefficients :math:`p`,

      .. math::

          p' = T p.

      This rule fixes the convention for the transformation matrix.
    - :class:`Representation.DUAL_VECTOR` and :class:`Representation.GRADIENT`, e.g. gradients
      and vectors that are applied to other vectors, such as the nuclear attraction vector, the
      basis integrals, or the inverse transformation matrix :math:`M = \tilde T^{-1}` of an
      earlier transformation when a further transformation :math:`T` is applied,

      .. math::

          M' = M T^{-1}.

      Gradients can additionally be projected, the other dual vectors cannot.
    - :class:`Representation.ENDOMORPHISM`, e.g. the projection matrix,

      .. math::

          P' = T P T^{-1}.

    - :class:`Representation.BILINEAR_FORM`, e.g. the Coulomb matrix,

      .. math::

          C' = T^{-T} C T^{-1}.

    - :class:`Representation.AO`, the atomic orbitals,

      .. math::

          A' = A T^{-1},

      i.e. the same rule as for dual vectors.

    Args:
        tensor: The tensor to be transformed.
        transformation_matrix: The transformation matrix :math:`T`.
        inv_transformation_matrix: The inverse :math:`T^{-1}` of the transformation matrix.
        representation: The representation that selects the transformation rule.

    Returns:
        The transformed tensor (the input object itself for ``NONE`` and ``SCALAR``).

    Raises:
        KeyError: If the representation is not one of the cases above.
    """
    t_mat = transformation_matrix
    t_inv = inv_transformation_matrix

    # Quantities that are invariant under a change of basis.
    if representation == Representation.NONE:
        return tensor
    if representation == Representation.SCALAR:
        return tensor

    # Multiplication from the left by T.
    if representation == Representation.VECTOR:
        return t_mat @ tensor

    # Multiplication from the right by the inverse.
    if representation == Representation.DUAL_VECTOR or representation == Representation.GRADIENT:
        return tensor @ t_inv

    # Two-sided rules.
    if representation == Representation.ENDOMORPHISM:
        return t_mat @ tensor @ t_inv
    if representation == Representation.BILINEAR_FORM:
        return t_inv.T @ tensor @ t_inv

    # Atomic orbitals follow the dual-vector rule.
    if representation == Representation.AO:
        return tensor @ t_inv

    raise KeyError(f"Unknown representation {representation}.")
[docs]
def transform_tensor_with_sample(
    sample: OFData,
    tensor: torch.Tensor,
    representation: Representation,
    invert: bool = False,
) -> torch.Tensor:
    """Transform ``tensor`` with the matrices stored on an already transformed ``sample``.

    The sample must carry the attributes ``transformation_matrix`` and
    ``inv_transformation_matrix``; they have to be added to the sample before the sample itself
    is transformed.

    Args:
        sample: The transformed sample that holds the transformation matrices.
        tensor: The tensor to be transformed.
        representation: The representation of ``tensor``, selecting the transformation rule.
        invert: If True, the roles of the matrix and its inverse are exchanged, i.e. the
            inverse transformation is applied.

    Returns:
        The transformed tensor.

    Raises:
        AssertionError: If the sample lacks one of the two matrices.
    """
    assert hasattr(
        sample, "transformation_matrix"
    ), "The sample does not have a transformation matrix."
    assert hasattr(
        sample, "inv_transformation_matrix"
    ), "The sample does not have an inverse transformation matrix."

    matrices = (sample.transformation_matrix, sample.inv_transformation_matrix)
    # Inverting the transformation amounts to swapping the matrix with its inverse.
    forward_matrix, backward_matrix = matrices[::-1] if invert else matrices
    return transform_tensor(tensor, forward_matrix, backward_matrix, representation)
[docs]
class ApplyBasisTransformation(torch.nn.Module):
    """Transform every field of a sample according to the field's representation.

    The transformation is given by a matrix and its inverse. Which rule applies to which field
    is looked up in ``sample.representations``.
    """

    def forward(
        self,
        sample: OFData,
        transformation_matrix: torch.Tensor,
        inv_transformation_matrix: torch.Tensor,
        invert: bool = False,
    ) -> OFData:
        """Transform the fields of ``sample`` in place.

        Args:
            sample: The sample to transform.
            transformation_matrix: The transformation matrix.
            inv_transformation_matrix: The inverse of the transformation matrix.
            invert: If True, the matrix and its inverse swap roles, so that the inverse
                transformation is applied.

        Returns:
            The same sample object, with its fields overwritten by the transformed tensors.
        """
        if invert:
            forward_matrix, backward_matrix = inv_transformation_matrix, transformation_matrix
        else:
            forward_matrix, backward_matrix = transformation_matrix, inv_transformation_matrix

        for field_name, field_rep in sample.representations.items():
            # A representation may be registered for a field that this sample does not carry.
            if not hasattr(sample, field_name):
                continue
            transformed = transform_tensor(
                getattr(sample, field_name), forward_matrix, backward_matrix, field_rep
            )
            setattr(sample, field_name, transformed)
        return sample
[docs]
class ToLocalFrames(ApplyBasisTransformation):
    """Apply the local frames transformation to the sample."""

    def __init__(
        self,
        sparse: bool = True,
    ):
        """Select the module that builds the local-frames transformation matrix.

        Args:
            sparse: Whether to use the sparse (True) or the dense (False) matrix module.
        """
        super().__init__()
        self.sparse = sparse
        matrix_module_cls = (
            LocalFramesTransformMatrixSparse if self.sparse else LocalFramesTransformMatrixDense
        )
        self.local_frames = matrix_module_cls()

    def __call__(self, sample: OFData, invert: bool = False) -> OFData:
        """Transform the sample to local frames in-place.

        The local frames are also stored on the sample under the key ``"lframes"``.

        Args:
            sample: The sample.
            invert: Passed on to :class:`ApplyBasisTransformation`; if True the inverse
                transformation is applied.

        Returns:
            The transformed sample.
        """
        local_frames_matrix, lframes = self.local_frames.sample_forward(
            sample, return_lframes=True
        )
        sample.add_item("lframes", lframes, Representation.NONE)
        # The transpose is handed over as the inverse.
        # NOTE(review): presumably valid because the local-frames matrix is orthogonal -- confirm.
        inverse_matrix = local_frames_matrix.T
        return super().__call__(sample, local_frames_matrix, inverse_matrix, invert=invert)
[docs]
class AddLocalFrames:
    """Add the local frames to the sample."""

    def __init__(self, local_basis_module=None) -> None:
        """Set up the module that computes the local frames.

        Without this initialisation ``self.local_basis_module`` would be undefined and every
        call of the transform would fail with an ``AttributeError``.

        Args:
            local_basis_module: Optional callable that is invoked as
                ``local_basis_module(pos=..., atomic_numbers=...)`` and returns the local
                frames. If None (default), a :class:`LocalBasisModule` is created.
        """
        # NOTE(review): assumes ``LocalBasisModule`` can be constructed without arguments --
        # confirm against its definition; otherwise pass a configured instance explicitly.
        self.local_basis_module = (
            LocalBasisModule() if local_basis_module is None else local_basis_module
        )

    def __call__(self, sample: OFData, invert: bool = False) -> OFData:
        """Add the local frames to the sample.

        Args:
            sample: The sample; its ``pos`` and ``atomic_numbers`` are passed to the local
                basis module.
            invert: Unused; accepted so the call signature matches the other transforms.

        Returns:
            The same sample, with the local frames stored under the key ``"lframes"``.
        """
        lframes = self.local_basis_module(pos=sample.pos, atomic_numbers=sample.atomic_numbers)
        sample.add_item("lframes", lframes, Representation.NONE)
        return sample
[docs]
class ToGlobalNatRep(ApplyBasisTransformation):
    """Apply the global natural reparametrization to the sample."""

    def __init__(self, orthogonalization: Literal["symmetric", "canonical"] = "symmetric"):
        """Store the orthogonalization scheme.

        Args:
            orthogonalization: The type of orthogonalization to use, either ``"symmetric"``
                (default) or ``"canonical"``.
        """
        super().__init__()
        self.orthogonalization = orthogonalization

    def __call__(self, sample: OFData, invert: bool = False) -> OFData:
        """Naturally reparametrize the sample in-place.

        The transformation matrices are computed from ``sample.overlap_matrix`` and passed on
        transposed to :class:`ApplyBasisTransformation`.

        Args:
            sample: The sample to be transformed; it must have an ``overlap_matrix``.
            invert: Passed on to :class:`ApplyBasisTransformation`; if True the inverse
                transformation is applied.

        Returns:
            The transformed sample.

        Raises:
            AssertionError: If the sample has no ``overlap_matrix``.
        """
        assert hasattr(sample, "overlap_matrix"), (
            "OFdata sample does not have an overlap matrix to which natural reparametrization needs to be applied."
        )
        nat_rep_matrix, inv_nat_rep_matrix = natural_reparametrization_matrices_torch(
            sample.overlap_matrix, orthogonalization=self.orthogonalization
        )
        return super().__call__(
            sample, nat_rep_matrix.T, inv_nat_rep_matrix.T, invert=invert
        )
[docs]
class MasterTransformation(torch.nn.Module):
    """Chain of pre-, basis- and post-transforms that is applied to :class:`OFData` samples."""

    def __init__(
        self,
        name: str,
        use_cached_data: bool = True,
        pre_transforms: list = None,
        cached_transforms: DictConfig = None,
        basis_transforms: list[ApplyBasisTransformation] = None,
        post_transforms: list = None,
        add_transformation_matrix: bool = False,
        **kwargs,
    ) -> None:
        """Class that handles the basis transformations.

        Args:
            name: Name of this transformation; stored as ``self.name``.
            use_cached_data: Whether cached labels are used. If True, the transforms listed in
                ``cached_transforms`` are skipped in :meth:`forward` and ``label_subdir`` is
                derived from ``cached_transforms.name``.
            pre_transforms: List of transforms to apply before the basis transforms.
            cached_transforms: Configuration of the basis transforms that may be cached. Its
                entries ``"additional_pre_transforms"`` and ``"transforms"`` are read (both
                default to an empty list), as well as its ``name``.
            basis_transforms: List of basis transforms to apply.
            post_transforms: List of transforms to apply after the basis transforms.
            add_transformation_matrix: Whether to add the transformation matrix to the sample.
            **kwargs: Ignored; a warning is logged if any are passed.

        Raises:
            ValueError: If ``add_transformation_matrix`` and ``use_cached_data`` are both set.
        """
        super().__init__()
        if kwargs:
            logger.warning(
                f"Keyword arguments {kwargs} passed, but will not be used anymore. "
                f"You're probably using an old checkpoint."
            )

        self.name = name
        self.use_cached_data = use_cached_data
        self.pre_transforms = pre_transforms if pre_transforms is not None else []
        self.post_transforms = post_transforms if post_transforms is not None else []

        if cached_transforms is None:
            self.additional_pre_transforms = []
            self.cached_basis_transforms = []
        else:
            self.additional_pre_transforms = cached_transforms.get(
                "additional_pre_transforms", []
            )
            self.cached_basis_transforms = cached_transforms.get("transforms", [])

        # Configure which label subdir to use: the plain "labels" directory unless cached data
        # is used together with a cached transform whose name differs from "none".
        label_subdir = "labels"
        if self.use_cached_data and cached_transforms is not None:
            cached_transform_name = cached_transforms.name
            if cached_transform_name != "none":
                label_subdir = f"labels_{cached_transform_name}"
        self.label_subdir = label_subdir

        self.basis_transforms = basis_transforms if basis_transforms is not None else []
        self.add_transformation_matrix = add_transformation_matrix

        if add_transformation_matrix and use_cached_data:
            raise ValueError(
                "Transformation matrix will be wrong when using cached data. Either don't add the transformation matrix or don't use cached data."
            )

    @staticmethod
    def _run_stages(sample: OFData, stages) -> OFData:
        """Apply every transform of every stage to ``sample``, in the given order."""
        for stage in stages:
            for transform in stage:
                sample = transform(sample)
        return sample

    def forward(self, sample: OFData) -> OFData:
        """Apply the configured transformation to a sample.

        Order: (optional) initialisation of the transformation matrices, additional
        pre-transforms, pre-transforms, cached basis transforms, basis transforms and
        post-transforms. The additional pre-transforms and the cached basis transforms are only
        run when ``use_cached_data`` is False.
        """
        if self.add_transformation_matrix:
            sample = self.initialize_transformation_matrices(sample)

        compute_cached_part = not self.use_cached_data
        stages = []
        if compute_cached_part:
            stages.append(self.additional_pre_transforms)
        stages.append(self.pre_transforms)
        if compute_cached_part:
            stages.append(self.cached_basis_transforms)
        stages.append(self.basis_transforms)
        stages.append(self.post_transforms)
        return self._run_stages(sample, stages)

    def basis_transform(self, sample: OFData) -> OFData:
        """Apply the forward basis transforms to a sample."""
        return self._run_stages(sample, (self.basis_transforms,))

    def invert_basis_transform(self, sample: OFData) -> OFData:
        """Invert the basis transforms of a sample.

        Inverse transforms are currently used for visualization during training and for
        obtaining the final density in ofdft.

        NOTE(review): the original notes state that this happens in the torch default float
        dtype; no explicit cast is performed in this method -- confirm.
        """
        has_matrices = hasattr(sample, "transformation_matrix") and hasattr(
            sample, "inv_transformation_matrix"
        )
        if has_matrices:
            # The matrices were saved on the sample and can be applied directly.
            matrix_source = sample
        else:
            # Otherwise recompute them by sending a copy of the sample, equipped with identity
            # matrices, through the complete chain of transforms.
            dummy_sample = self.initialize_transformation_matrices(sample.clone())
            matrix_source = self._run_stages(
                dummy_sample,
                (
                    self.additional_pre_transforms,
                    self.pre_transforms,
                    self.cached_basis_transforms,
                    self.basis_transforms,
                    self.post_transforms,
                ),
            )

        ApplyBasisTransformation()(
            sample,
            matrix_source.transformation_matrix,
            matrix_source.inv_transformation_matrix,
            invert=True,
        )
        return sample

    @staticmethod
    def initialize_transformation_matrices(sample: OFData):
        """Add the transformation matrices to the sample.

        Both ``transformation_matrix`` and ``inv_transformation_matrix`` start out as float64
        identity matrices whose size is the length of the first axis of ``sample.coeffs``.
        """
        n_coeffs = sample.coeffs.shape[0]
        sample.add_item(
            "transformation_matrix",
            np.eye(n_coeffs, dtype=np.float64),
            Representation.VECTOR,
        )
        sample.add_item(
            "inv_transformation_matrix",
            np.eye(n_coeffs, dtype=np.float64),
            Representation.DUAL_VECTOR,
        )
        return sample