Source code for mldft.ml.models.components.atom_ref

"""Atomic reference module from [M-OFDFT]_."""

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.data import Batch
from torch_geometric.nn.pool import global_add_pool

from mldft.ml.data.components.basis_info import BasisInfo
from mldft.ml.data.components.of_data import OFData


class AtomRef(nn.Module):
    """Atomic reference module from [M-OFDFT]_.

    Predicts the energy of a molecule with a linear model made of three parts:
    a bias per atom type summed over all atoms, a global bias, and the dot product of a
    mean gradient vector (looked up per basis function) with the density coefficients.
    """

    def __init__(
        self,
        t_z: torch.Tensor | np.ndarray,
        t_global: float,
        g_z: torch.Tensor | np.ndarray,
    ) -> None:
        """Initialize the atomic reference module.

        Alternatively use :meth:`from_dataset_statistics`.

        Args:
            t_z: Kinetic energy per atomic species. Shape (n_atom_types,).
            t_global: Global bias.
            g_z: Mean gradient vector, concatenated over atomic species. Shape (basis,).
        """
        super().__init__()
        # Registered as buffers so they follow the module across ``.to(device)`` and are
        # part of the state dict; converted to the default dtype (numpy input is accepted).
        self.register_buffer("t_z", torch.as_tensor(t_z, dtype=torch.get_default_dtype()))
        # Plain attribute (not a buffer): a scalar added to every molecule's energy.
        self.t_global = t_global
        self.register_buffer("g_z", torch.as_tensor(g_z, dtype=torch.get_default_dtype()))

    @classmethod
    def from_dataset_statistics(
        cls,
        dataset_statistics,
        basis_info: BasisInfo | None = None,
        weigher_key: str = "has_energy_label",
        scalar_only: bool = False,
    ) -> "AtomRef":
        """Initialize using a :class:`~mldft.ml.preprocess.dataset_statistics.DatasetStatistics` object.

        Args:
            dataset_statistics: Dataset statistics object.
            basis_info: Basis information object. Required if scalar_only is True, to infer which
                features are scalars.
            weigher_key (str): Selects the sample weigher for the dataset statistics.
            scalar_only (bool): If True, only the scalar features are used. Defaults to False.

        Returns:
            A new :class:`AtomRef` built from the stored statistics.

        Raises:
            ValueError: If ``scalar_only`` is True but ``basis_info`` is None.
        """
        if not scalar_only:
            return cls(
                t_z=dataset_statistics.load_statistic(weigher_key, "atom_ref_atom_type_bias"),
                t_global=float(
                    dataset_statistics.load_statistic(weigher_key, "atom_ref_global_bias")
                ),
                g_z=dataset_statistics.load_statistic(weigher_key, "gradient_label/mean"),
            )

        # Explicit check instead of ``assert``: asserts are stripped under ``python -O``,
        # which would turn this into an AttributeError on ``None`` below.
        if basis_info is None:
            raise ValueError("basis_info must be provided if scalar_only is True")

        # Basis functions with l == 0 are treated as the scalar features; the gradient mean
        # of every other basis function is zeroed out.
        scalar_mask = basis_info.l_per_basis_func == 0
        return cls(
            t_z=dataset_statistics.load_statistic(weigher_key, "scalar_atom_ref_atom_type_bias"),
            t_global=float(
                dataset_statistics.load_statistic(weigher_key, "scalar_atom_ref_global_bias")
            ),
            g_z=dataset_statistics.load_statistic(weigher_key, "gradient_label/mean")
            * scalar_mask,
        )

    def forward(
        self,
        atom_ind: torch.Tensor,
        coeffs: torch.Tensor,
        basis_function_ind: torch.Tensor,
        atom_batch: torch.Tensor | None,  # do not set default to avoid passing data from a batch on accident
        coeffs_batch: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Calculates the energy for one molecule, or a batch of molecules, according to a linear
        fit, e.g. computed by :class:`~mldft.ml.preprocess.dataset_statistics.DatasetStatistics`.

        Args:
            atom_ind: Atom types in the molecule (indexing the atomic numbers present in the basis).
                Shape (n_atoms,).
            coeffs: Coefficients of the molecule. Shape (n_basis,).
            basis_function_ind: Indices of the basis functions in the molecule. See
                :attr:`mldft.ml.data.components.of_data.OFData.basis_function_ind`.
                Shape (n_basis,).
            atom_batch: Batch index for the atoms. Shape (n_atoms,), or None if not batched.
            coeffs_batch: Batch index for the coefficients. Shape (n_basis,), or None if not
                batched.

        Returns:
            The reference energy per molecule, shape (n_batch,).
        """
        # compute the gradient times the coefficients for the molecule, i.e. the part
        # proportional to the coefficients
        g_m = self.g_z[basis_function_ind]  # shape (n_basis)
        g_times_p = global_add_pool(g_m * coeffs, coeffs_batch)  # shape (n_batch)
        # compute the bias per molecule, by summing the bias per atom
        t_m = global_add_pool(self.t_z[atom_ind], atom_batch)
        # return the sum of all contributions
        return g_times_p + t_m + self.t_global

    def sample_forward(self, sample: OFData) -> torch.Tensor:
        """Calculate the expected energy for a sample (or batch) with a linear fit, e.g. computed
        by :class:`~mldft.ml.preprocess.dataset_statistics.DatasetStatistics`.

        Args:
            sample: OFData object containing the molecule.

        Returns:
            The reference energy per molecule, as returned by :meth:`forward`.
        """
        # Batch indices only exist on collated batches; a single sample is pooled without them.
        if isinstance(sample, Batch):
            coeffs_batch = sample.coeffs_batch
            atom_batch = sample.batch
        else:
            coeffs_batch = None
            atom_batch = None
        return self(
            atom_ind=sample.atom_ind,
            coeffs=sample.coeffs,
            atom_batch=atom_batch,
            coeffs_batch=coeffs_batch,
            basis_function_ind=sample.basis_function_ind,
        )
class SimpleQuadraticAtomRef(nn.Module):
    """Isotropic quadratic energy around the mean ground-state coefficients.

    For each molecule the energy is ``factor * sum((coeffs - mean[basis_function_ind]) ** 2)``,
    where the sum runs over the molecule's basis functions.
    """

    def __init__(self, ground_state_coeff_mean: Tensor, factor: float):
        """
        Args:
            ground_state_coeff_mean: The mean of the ground state coefficients. It is indexed by
                ``basis_function_ind`` in :meth:`forward`, so it holds one entry per basis
                function (presumably concatenated over atom types — confirm with the caller).
            factor: The factor to multiply the quadratic term with.
        """
        super().__init__()
        # Register as a buffer so ``.to(device)`` / ``.cuda()`` move it together with the module.
        # A plain tensor attribute is ignored by ``nn.Module`` and would cause device-mismatch
        # errors when indexed with GPU indices. ``persistent=False`` keeps it out of the state
        # dict, exactly as before, so existing checkpoints keep loading.
        self.register_buffer(
            "ground_state_coeff_mean",
            torch.as_tensor(ground_state_coeff_mean),
            persistent=False,
        )
        self.factor = factor

    def forward(
        self, coeffs: Tensor, basis_function_ind: Tensor, coeffs_batch: Tensor | None
    ) -> Tensor:
        """Compute an isotropic quadratic energy around the mean ground state coeffs.

        Args:
            coeffs: The coefficients of the molecule. Shape (n_basis,).
            basis_function_ind: Indices of the basis functions in the molecule. See
                :attr:`mldft.ml.data.components.of_data.OFData.basis_function_ind`.
                Shape (n_basis,).
            coeffs_batch: Batch index for the coefficients. Shape (n_basis,), or None if not
                batched.

        Returns:
            The quadratic energy per molecule, summed over each molecule's basis functions.
        """
        coeff_delta = coeffs - self.ground_state_coeff_mean[basis_function_ind]
        return self.factor * global_add_pool(coeff_delta**2, coeffs_batch)

    def sample_forward(self, sample: OFData) -> Tensor:
        """Compute the energy for a sample (or batch).

        Args:
            sample: The sample to compute the energy for.

        Returns:
            The quadratic energy per molecule, as returned by :meth:`forward`.
        """
        # ``coeffs_batch`` is only read from collated batches; a single sample is pooled without
        # a batch index (same distinction as AtomRef.sample_forward), instead of failing when
        # the attribute is absent.
        coeffs_batch = sample.coeffs_batch if isinstance(sample, Batch) else None
        return self(sample.coeffs, sample.basis_function_ind, coeffs_batch)