Source code for mldft.ml.data.components.convert_transforms

from collections.abc import Callable
from functools import partial

import numpy as np
import torch
from loguru import logger
from pyscf import dft
from torch_geometric.nn import radius_graph

from mldft.ml.data.components.basis_info import BasisInfo
from mldft.ml.data.components.of_data import OFData, Representation
from mldft.ml.models.components.loss_function import project_gradient
from mldft.ofdft.basis_integrals import (
    get_coulomb_matrix,
    get_nuclear_attraction_vector,
    get_overlap_matrix,
)
from mldft.utils.grids import grid_setup
from mldft.utils.molecules import build_molecule_ofdata
from mldft.utils.sparse import construct_block_diag_coo_indices_and_shape


def apply_to_attributes(
    f: Callable, sample: OFData, attributes: tuple[str, ...] | list[str]
) -> OFData:
    """Apply ``f`` to the named attributes of ``sample``, failing loudly on errors.

    Simplified version of :meth:`Data.apply`, with the crucial difference that errors are not
    quietly ignored.

    .. warning:: This function modifies the object in-place.

    Args:
        f: The function to apply.
        sample: The object to apply the function to.
        attributes: The names of the attributes to apply the function to.

    Returns:
        The same object, with each listed attribute replaced by ``f(attribute)``.

    Raises:
        RuntimeError: If reading, transforming or writing any attribute fails. The original
            exception is attached as ``__cause__``.
    """
    for attr in attributes:
        try:
            setattr(sample, attr, f(getattr(sample, attr)))
        except Exception as err:
            # Chain explicitly so the traceback marks the underlying failure as the direct cause.
            raise RuntimeError(
                f"Error applying {f} to attribute {attr} of {sample}."
            ) from err
    return sample
def split_by_atom(
    sample: OFData, fields=("coeffs", "ground_state_coeffs", "gradient_label")
) -> OFData:
    """Replace each listed field of ``sample`` by its per-atom pieces.

    Every field is cut with :func:`numpy.split`, using all entries of ``sample.atom_ptr``
    except the last one as split indices. Currently not planned to be used, but demonstrates
    how fields can be split when needed.

    Args:
        sample: The sample; it is modified in-place.
        fields: The names of the fields to split.

    Returns:
        The same sample, with each listed field replaced by the list of sub-arrays.
    """
    for field_name in fields:
        flat_values = getattr(sample, field_name)
        per_atom_chunks = np.split(flat_values, sample.atom_ptr[:-1])
        setattr(sample, field_name, per_atom_chunks)
    return sample
class SplitByAtom:
    """Callable transform that splits basis-function wise fields by atom."""

    def __call__(self, sample: OFData) -> OFData:
        """Run :func:`split_by_atom` on ``sample`` with its default field list.

        Args:
            sample: The sample to split.

        Returns:
            Whatever :func:`split_by_atom` returns for ``sample``.
        """
        split_sample = split_by_atom(sample)
        return split_sample
class ProjectGradient(torch.nn.Module):
    """Project gradients stored in a sample onto the tangent space of the manifold of constant
    electron number."""

    def forward(self, sample: OFData) -> OFData:
        """Project every gradient-typed field that is present on the sample.

        Iterates over ``sample.representations`` and, for each key registered as
        ``Representation.GRADIENT`` that the sample actually has, overwrites the field with
        the result of ``project_gradient(field, sample)``.

        Args:
            sample (OFData): The sample; it is modified in-place.

        Returns:
            OFData: The sample with the projected gradients.
        """
        for field_name, field_repr in sample.representations.items():
            is_gradient = field_repr == Representation.GRADIENT
            # Registered keys are not guaranteed to exist on the sample, hence the hasattr.
            if not (is_gradient and hasattr(sample, field_name)):
                continue
            current_value = getattr(sample, field_name)
            sample[field_name] = project_gradient(current_value, sample)
        return sample
[docs] def str_to_torch_float_dtype(dtype: str | torch.dtype) -> torch.dtype: """Convert a string to a torch float dtype. Useful to set the dtype in hydra configs. """ if isinstance(dtype, torch.dtype): return dtype if dtype == "torch.float64": return torch.float64 elif dtype == "torch.float32": return torch.float32 else: raise ValueError(f"Unknown dtype {dtype}")
[docs] def dtype_map(dtype: np.dtype, float_dtype: np.dtype | None | str = None) -> None | torch.dtype: """Map a numpy dtype to a torch dtype. Args: dtype: The numpy dtype. float_dtype: The dtype to use for float dtypes. Defaults to the pytorch default dtype. For strings, "torch.float64" and "torch.float32" are supported. Returns: The torch dtype, or None if no mapping exists. """ # This adds hydra support for torch.float64 and torch.float32 if type(float_dtype) == str: float_dtype = str_to_torch_float_dtype(float_dtype) if ( dtype == np.float64 or dtype == np.float32 or dtype == np.float16 or dtype == torch.float64 or dtype == torch.float32 or dtype == torch.float16 or dtype == torch.bfloat16 ): return torch.get_default_dtype() if float_dtype is None else float_dtype elif dtype == np.uint8: return torch.int else: return None
def tensor_or_array_to_torch(
    x: torch.Tensor | np.ndarray,
    device=None,
    float_dtype: np.dtype | torch.dtype | str | None = None,
) -> torch.Tensor:
    """Convert a tensor or array to a torch tensor with the dtype chosen by :func:`dtype_map`.

    Args:
        x: The tensor or array to convert.
        device: Passed through to :func:`torch.as_tensor`.
        float_dtype: Forwarded to :func:`dtype_map` to select the dtype for float data.

    Returns:
        The result of :func:`torch.as_tensor` for ``x`` with the mapped dtype and ``device``.
    """
    target_dtype = dtype_map(x.dtype, float_dtype=float_dtype)
    converted = torch.as_tensor(x, dtype=target_dtype, device=device)
    return converted
def to_torch(
    sample: OFData, device=None, float_dtype: np.dtype | torch.dtype | str | None = None
) -> OFData:
    """Convert all numpy arrays in the sample to torch tensors.

    Fields that already are torch tensors are converted as well; numpy arrays of object dtype
    are left untouched.

    Args:
        sample: The sample; it is modified in-place.
        device: The device to put the tensors on. Defaults to None, i.e. the pytorch default
            device.
        float_dtype: The dtype to use for float dtypes. Defaults to the pytorch default dtype.
            For strings, "torch.float64" and "torch.float32" are supported to enable hydra
            support.

    Returns:
        The sample with all numpy arrays converted to torch tensors.
    """

    def _is_convertible(value) -> bool:
        # Object arrays are skipped; any torch tensor is accepted.
        if isinstance(value, np.ndarray):
            return not value.dtype == np.object_
        return isinstance(value, torch.Tensor)

    convertible_keys = [key for key in sample.keys() if _is_convertible(getattr(sample, key))]
    converter = partial(tensor_or_array_to_torch, device=device, float_dtype=float_dtype)
    apply_to_attributes(converter, sample, convertible_keys)
    return sample
class ToTorch:
    """Convert all numpy arrays in the sample to torch tensors.

    Callable wrapper around :func:`to_torch` that stores the target device and float dtype.
    """

    def __init__(
        self,
        device: torch.device | None = None,
        float_dtype: np.dtype | torch.dtype | str | None = None,
    ):
        """Initialize the transform.

        Args:
            device: The device to put the tensors on. Defaults to None, i.e. the pytorch
                default device.
            float_dtype: The dtype to use for float dtypes. Defaults to the pytorch default
                dtype. For strings, "torch.float64" and "torch.float32" are supported.
        """
        self.device = device
        self.float_dtype = float_dtype

    def __call__(self, sample: OFData) -> OFData:
        """Apply the transform to the sample.

        See :func:`to_torch` for details.

        Args:
            sample: The sample to convert.

        Returns:
            The result of :func:`to_torch` with the stored device and float dtype.
        """
        return to_torch(sample, device=self.device, float_dtype=self.float_dtype)
def to_numpy(sample: OFData) -> OFData:
    """Convert all torch tensors in the sample to numpy arrays.

    Each tensor field is detached, moved to the CPU and converted with ``.numpy()``.

    Args:
        sample: The sample; it is modified in-place.

    Returns:
        The sample with all torch tensors converted to numpy arrays.
    """
    tensor_keys = tuple(
        name for name in sample.keys() if isinstance(getattr(sample, name), torch.Tensor)
    )
    return apply_to_attributes(
        lambda x: x.detach().cpu().numpy(),
        sample,
        tensor_keys,
    )
class ToNumpy:
    """Callable transform converting all torch tensors in the sample to numpy arrays."""

    def __call__(self, sample: OFData) -> OFData:
        """Run :func:`to_numpy` on ``sample``.

        Args:
            sample: The sample to convert.

        Returns:
            Whatever :func:`to_numpy` returns for ``sample``.
        """
        converted_sample = to_numpy(sample)
        return converted_sample
def add_atom_coo_indices(sample: OFData) -> OFData:
    """Add a field "atom_coo_indices" holding block-diagonal COO indices, one block per atom.

    For every entry ``n`` of ``sample.n_basis_per_atom`` a square block of shape ``(n, n)`` is
    requested from ``construct_block_diag_coo_indices_and_shape``; the first element of its
    result is stored on the sample with ``Representation.NONE``. The indices can be used to
    construct a block diagonal sparse matrix with each block corresponding to the basis
    functions of one atom.

    NOTE(review): this was previously documented as an index tensor of shape (2, n_basis);
    confirm against ``construct_block_diag_coo_indices_and_shape``.

    Args:
        sample: The sample; it is modified in-place.

    Returns:
        The sample with the field "atom_coo_indices" added.
    """
    per_atom_block_shapes = [(size, size) for size in sample.n_basis_per_atom]
    indices_and_shape = construct_block_diag_coo_indices_and_shape(*per_atom_block_shapes)
    sample.add_item(
        "atom_coo_indices",
        indices_and_shape[0],
        representation=Representation.NONE,
    )
    return sample
class AddAtomCooIndices:
    """Callable transform adding the field "atom_coo_indices" to the sample.

    The field holds the indices for a block diagonal sparse matrix with each block
    corresponding to the basis functions of one atom; see :func:`add_atom_coo_indices`.
    """

    def __call__(self, sample: OFData) -> OFData:
        """Run :func:`add_atom_coo_indices` on ``sample``.

        Args:
            sample: The sample to extend.

        Returns:
            Whatever :func:`add_atom_coo_indices` returns for ``sample``.
        """
        extended_sample = add_atom_coo_indices(sample)
        return extended_sample
class AddOverlapMatrix:
    """Transform that stores the overlap matrix on the sample as "overlap_matrix"."""

    def __init__(self, basis_info: BasisInfo):
        """Store the basis information used to build the molecule.

        Args:
            basis_info: The basis information in the BasisInfo format.
        """
        super().__init__()
        self.basis_info = basis_info

    def __call__(self, sample: OFData) -> OFData:
        """Build the molecule for ``sample`` and attach its overlap matrix.

        Args:
            sample: The molecule in the OFData format; it is modified in-place.

        Returns:
            The sample with the field "overlap_matrix" added as a bilinear form.
        """
        molecule = build_molecule_ofdata(sample, self.basis_info.basis_dict)
        # The overlap matrix is only ever needed to compute the natrep transformation matrix.
        overlap = get_overlap_matrix(molecule)
        sample.add_item("overlap_matrix", overlap, Representation.BILINEAR_FORM)
        return sample
class AddFullEdgeIndex(torch.nn.Module):
    """Add a full edge index to the sample."""

    def forward(self, sample: OFData) -> OFData:
        """Attach an edge index that pairs every node with every node (itself included).

        The number of nodes is taken from the first dimension of ``sample.pos``; the resulting
        index has two rows and one column per ordered node pair.

        Args:
            sample (OFData): Sample data object; it is modified in-place.

        Returns:
            OFData: The same sample data object with the additional edge_index attribute.
        """
        n_nodes = sample.pos.shape[0]
        sources, targets = torch.meshgrid(
            torch.arange(n_nodes), torch.arange(n_nodes), indexing="ij"
        )
        # Flatten the two (n, n) grids into one (2, n * n) list of pairs.
        edge_index = torch.stack((sources, targets)).reshape(2, -1)
        sample.add_item("edge_index", edge_index, Representation.NONE)
        return sample
class AddRadiusEdgeIndex(torch.nn.Module):
    """Add a radius edge index to the sample."""

    def __init__(self, radius: float):
        """Store the cutoff radius.

        Args:
            radius: The radius to use for the edge index.
        """
        super().__init__()
        self.radius = radius

    def forward(self, sample: OFData) -> OFData:
        """Attach the edge index computed by ``radius_graph`` on ``sample.pos``.

        ``radius_graph`` is called with ``r=self.radius`` and ``loop=True``.

        Args:
            sample (OFData): Sample data object; it is modified in-place.

        Returns:
            OFData: The same sample data object with the additional edge_index attribute.
        """
        positions = sample.pos
        neighbour_edges = radius_graph(positions, r=self.radius, loop=True)
        sample.add_item("edge_index", neighbour_edges, Representation.NONE)
        return sample
class AddBasisInfo:
    """Adds the basis information to the sample as the field "basis_info"."""

    def __init__(self, basis_info: BasisInfo):
        """Store the basis information to attach.

        Args:
            basis_info: The basis information in the BasisInfo format.
        """
        super().__init__()
        self.basis_info = basis_info

    def __call__(self, sample: OFData) -> OFData:
        """Attach the stored basis information to ``sample``.

        Args:
            sample: The molecule in the OFData format; it is modified in-place.

        Returns:
            The sample with the field "basis_info" added with ``Representation.NONE``.
        """
        sample.add_item("basis_info", self.basis_info, Representation.NONE)
        return sample
class AddNuclearAttractionVector:
    """Adds the nuclear attraction vector to the sample."""

    def __init__(self, basis_info: BasisInfo):
        """Store the basis information used to build the molecule.

        Args:
            basis_info: The basis information in the BasisInfo format.
        """
        super().__init__()
        self.basis_info = basis_info

    def __call__(self, sample: OFData) -> OFData:
        """Build the molecule for ``sample`` and attach its nuclear attraction vector.

        The vector is converted to a ``torch.float64`` tensor before it is stored.

        Args:
            sample: The molecule in the OFData format; it is modified in-place.

        Returns:
            The sample with the field "nuclear_attraction_vector" added as a dual vector.
        """
        molecule = build_molecule_ofdata(sample, self.basis_info.basis_dict)
        raw_vector = get_nuclear_attraction_vector(molecule)
        vector_tensor = torch.as_tensor(raw_vector, dtype=torch.float64)
        sample.add_item("nuclear_attraction_vector", vector_tensor, Representation.DUAL_VECTOR)
        return sample
class AddCoulombMatrix:
    """Adds the coulomb matrix to the sample."""

    def __init__(self, basis_info: BasisInfo):
        """Store the basis information used to build the molecule.

        Args:
            basis_info: The basis information in the BasisInfo format.
        """
        super().__init__()
        self.basis_info = basis_info

    def __call__(self, sample: OFData) -> OFData:
        """Build the molecule for ``sample`` and attach its coulomb matrix.

        The matrix is converted to a ``torch.float64`` tensor before it is stored.

        Args:
            sample: The molecule in the OFData format; it is modified in-place.

        Returns:
            The sample with the field "coulomb_matrix" added as a bilinear form.
        """
        molecule = build_molecule_ofdata(sample, self.basis_info.basis_dict)
        raw_matrix = get_coulomb_matrix(molecule)
        matrix_tensor = torch.as_tensor(raw_matrix, dtype=torch.float64)
        sample.add_item("coulomb_matrix", matrix_tensor, Representation.BILINEAR_FORM)
        return sample
class AddMol:
    """Adds the mol to the sample."""

    def __init__(self, basis_info: BasisInfo):
        """Store the basis information used to build the molecule.

        Args:
            basis_info: The basis information in the BasisInfo format.
        """
        super().__init__()
        self.basis_info = basis_info

    def __call__(self, sample: OFData) -> OFData:
        """Build the molecule object for ``sample`` and attach it as the field "mol".

        Args:
            sample: The molecule in the OFData format; it is modified in-place.

        Returns:
            The sample with the field "mol" added with ``Representation.NONE``.
        """
        basis_dict = self.basis_info.basis_dict
        molecule = build_molecule_ofdata(sample, basis_dict)
        sample.add_item("mol", molecule, Representation.NONE)
        return sample
class PrepareForDensityOptimization:
    """Transform adding the necessary information for density optimization."""

    def __init__(
        self,
        basis_info: BasisInfo,
        add_grid: bool = False,
        grid_level: int = 2,
        grid_prune: str = "nwchem_prune",
    ):
        """Store the configuration of the transform.

        Args:
            basis_info: The basis information used to build the molecule.
            add_grid: Whether to also evaluate the atomic orbitals on an integration grid.
            grid_level: Passed to ``grid_setup`` and stored on the sample when ``add_grid``.
            grid_prune: Passed to ``grid_setup`` and stored on the sample when ``add_grid``.
        """
        self.basis_info = basis_info
        self.add_grid = add_grid
        self.grid_level = grid_level
        self.grid_prune = grid_prune

    def __call__(self, sample: OFData) -> OFData:
        """Prepare the sample for density optimization.

        Always (re)builds and stores "mol". The fields "basis_info", "overlap_matrix",
        "coulomb_matrix" and "nuclear_attraction_vector" are only added when the sample does
        not have them yet. With ``add_grid`` set, "grid_level", "grid_prune" and "ao" are
        added as well.

        Args:
            sample: The sample; it is modified in-place.

        Returns:
            OFData: The prepared sample.
        """
        molecule = build_molecule_ofdata(sample, self.basis_info.basis_dict)
        sample.add_item("mol", molecule, representation=Representation.NONE)

        if not hasattr(sample, "basis_info"):
            sample.add_item("basis_info", self.basis_info, representation=Representation.NONE)

        if not hasattr(sample, "overlap_matrix"):
            overlap_transform = AddOverlapMatrix(basis_info=self.basis_info)
            sample = overlap_transform(sample)

        if not hasattr(sample, "coulomb_matrix"):
            sample.add_item(
                "coulomb_matrix",
                get_coulomb_matrix(molecule),
                representation=Representation.BILINEAR_FORM,
            )

        if not hasattr(sample, "nuclear_attraction_vector"):
            sample.add_item(
                "nuclear_attraction_vector",
                get_nuclear_attraction_vector(molecule),
                representation=Representation.DUAL_VECTOR,
            )

        if not self.add_grid:
            return sample

        logger.info("Adding grid to sample.")
        integration_grid = grid_setup(molecule, self.grid_level, self.grid_prune)
        sample.add_item("grid_level", self.grid_level, representation=Representation.NONE)
        sample.add_item("grid_prune", self.grid_prune, representation=Representation.NONE)
        # NOTE(review): the grid object itself is not stored; only its settings and the
        # orbital values evaluated at its coordinates are - confirm this is intended.
        ao_values = np.asarray(dft.numint.eval_ao(molecule, integration_grid.coords, deriv=1))
        sample.add_item("ao", ao_values, representation=Representation.AO)
        return sample