Source code for mldft.ofdft.functional_factory

"""Provides energy functionals."""

from collections.abc import Callable
from functools import partial

import numpy as np
import torch
from loguru import logger
from pyscf import dft, gto

from mldft.ml.data.components.basis_info import BasisInfo
from mldft.ml.data.components.convert_transforms import to_torch
from mldft.ml.data.components.of_data import OFData
from mldft.ofdft import libxc_functionals, torch_functionals
from mldft.ofdft.energies import Energies
from mldft.utils.coeffs_to_grid import coeffs_to_rho
from mldft.utils.molecules import build_molecule_ofdata

LIBXC_PREFIX = "libxc_"


def requires_grid(
    target_key: str, negative_integrated_density_penalty_weight: float | None = None
) -> bool:
    """Check if the functional requires a grid for the evaluation.

    Args:
        target_key: The target key of the functional.
        negative_integrated_density_penalty_weight: Weight of the penalty on the negative
            integrated density, if any.

    Returns:
        True if the functional requires a grid, False otherwise.
    """
    if target_key in ["kin_plus_xc", "tot"] and (
        negative_integrated_density_penalty_weight is None
        or negative_integrated_density_penalty_weight == 0
    ):
        return False
    else:
        return True

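# Illustrative usage (not part of the module): "kin_plus_xc" and "tot" targets can skip
# the grid unless a negative-density penalty is requested, all other targets need one.
assert requires_grid("tot") is False
assert requires_grid("tot", negative_integrated_density_penalty_weight=1000.0) is True
assert requires_grid("kin") is True
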
def hartree_functional(
    coeffs: torch.Tensor, coulomb_matrix: torch.Tensor
) -> tuple[float, torch.Tensor]:
    """Get the Hartree energy for the given coefficients.

    Args:
        coeffs: The coefficients p.
        coulomb_matrix: Coulomb matrix of the system.

    Returns:
        The Hartree energy and its gradient.
    """
    hartree_potential_vector = coulomb_matrix @ coeffs
    hartree_energy = 0.5 * coeffs @ hartree_potential_vector
    return hartree_energy.item(), hartree_potential_vector

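# Illustrative sketch (not part of the module): the Hartree term is the quadratic form
# 0.5 * p @ J @ p, so its gradient with respect to p is simply J @ p. J and p below are
# random stand-ins for a real Coulomb matrix and density coefficients.
J = torch.rand(5, 5, dtype=torch.float64)
J = 0.5 * (J + J.T)  # symmetrize the stand-in Coulomb matrix
p = torch.rand(5, dtype=torch.float64)
energy, grad = hartree_functional(p, J)
assert torch.allclose(grad, J @ p)
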
def nuclear_attraction_functional(
    coeffs: torch.Tensor, nuclear_attraction_vector: torch.Tensor
) -> tuple[float, torch.Tensor]:
    """Get the nuclear attraction energy for the given coefficients.

    Args:
        coeffs: The coefficients p.
        nuclear_attraction_vector: The nuclear attraction vector.

    Returns:
        The nuclear attraction energy and its gradient.
    """
    nuclear_attraction_energy = coeffs @ nuclear_attraction_vector
    return nuclear_attraction_energy.item(), nuclear_attraction_vector

class NegativeIntegratedDensity(torch.nn.Module):
    """Class to compute a penalty term on the negative integrated density."""

    def __init__(self, grid_weights: torch.Tensor, ao: torch.Tensor, gamma: float = 1000.0):
        """Initialize the negative integrated density penalty.

        Args:
            grid_weights: The grid weights.
            ao: The atomic orbital values on the grid.
            gamma: The penalty factor.
        """
        super().__init__()
        self.gamma = gamma
        self.__name__ = "integrated_negative_density"
        self.weights = grid_weights
        self.ao = ao

    def forward(self, sample: OFData) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the negative integrated density penalty.

        Args:
            sample: The OFData object.

        Returns:
            The energy (0.0) and the gradient of the penalty term.
        """
        coeffs = sample.coeffs
        coeffs.requires_grad = True
        density = coeffs_to_rho(coeffs, self.ao)
        density[density > 0] = 0
        integrated_negative_density = self.gamma * (density * density) @ self.weights
        # get the gradient
        grad = torch.autograd.grad(integrated_negative_density, coeffs, create_graph=False)[0]
        sample.coeffs = coeffs.detach()
        return (
            torch.tensor([0.0], dtype=torch.float32, device=sample.coeffs.device),
            grad.detach(),
        )

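# Illustrative sketch (not part of the module): the penalty integrates gamma * rho^2 over
# the grid points where rho is negative only. The toy grid below is a stand-in; in the
# module the density comes from coeffs_to_rho(coeffs, ao).
gamma = 1000.0
weights = torch.full((4,), 0.25, dtype=torch.float64)
rho = torch.tensor([0.3, -0.1, 0.2, -0.05], dtype=torch.float64, requires_grad=True)
rho_neg = torch.clamp(rho, max=0.0)        # keep only the negative part of the density
penalty = gamma * (rho_neg * rho_neg) @ weights
penalty.backward()                         # gradient w.r.t. rho, zero where rho is positive
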
class FunctionalFactory:
    """Class to construct energy functionals.

    In the init, only contributions to the energy functional need to be passed. One instance
    can be used to construct multiple functionals that are specialized to multiple molecules.
    """

    def __init__(
        self,
        *contributions: str | torch.nn.Module | Callable[[OFData], tuple[float, torch.Tensor]],
    ):
        """Initialize the energy functional with contributions that will be summed up.

        Args:
            contributions: The contributions (addends) to the energy functional. Allowed
                strings are libxc functionals, torch functionals, "hartree" or
                "nuclear_attraction". A contribution can also be a torch.nn.Module or a
                callable.

        Raises:
            ValueError: If there are duplicate contributions. This is important since
                energies are stored in a dict, see :class:`~mldft.ofdft.energies.Energies`.
        """
        libxc_functionals_list = []  # Used to collect all libxc functionals to compute at once
        torch_functionals_list = []  # Used to collect all torch functionals to compute at once
        for contribution in contributions:
            if isinstance(contribution, str):
                if (
                    contribution == "hartree"
                    or contribution == "nuclear_attraction"
                    or "integrated_negative_density" in contribution
                ):
                    continue
                elif contribution.startswith(LIBXC_PREFIX):
                    functional_name = contribution[len(LIBXC_PREFIX) :]
                    libxc_functionals.check_kxc_implementation(functional_name)
                    libxc_functionals_list.append(functional_name)
                else:
                    assert contribution in torch_functionals.str_to_torch_functionals, (
                        f"{contribution} not in supported torch functionals. Supported functionals"
                        f" are: {list(torch_functionals.str_to_torch_functionals.keys())}"
                    )
                    torch_functionals_list.append(contribution)
            elif isinstance(contribution, torch.nn.Module):
                contribution.__name__ = contribution.target_key
                if contribution.training:
                    contribution.eval()
                    logger.warning(f"Setting {contribution.__class__.__name__} to eval mode.")
            elif callable(contribution):
                pass
            else:
                raise ValueError(f"Invalid functional contribution: {contribution}")

        # Throw an error if there are duplicate names.
        # This is important since energies are stored in a dict.
        names = []
        for c in contributions:
            if isinstance(c, str):
                names.append(c)
            else:
                names.append(c.__name__)
        if len(set(names)) != len(names):
            raise ValueError(f"Duplicate contributions: {names}")

        self.contributions = list(contributions)
        self.torch_functionals = torch_functionals_list
        self.libxc_functionals = libxc_functionals_list
        classical_functionals = self.libxc_functionals + self.torch_functionals
        self.max_derivative = libxc_functionals.required_derivative(classical_functionals)

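    # Illustrative usage (a sketch; it assumes "PBE" is among the supported torch
    # functionals, which is not confirmed by this file): combine the von Weizsäcker
    # kinetic functional with a PBE xc term and the classical contributions.
    #
    #     factory = FunctionalFactory("libxc_GGA_K_VW", "PBE", "hartree", "nuclear_attraction")
    #     print(factory)
    #     # functional = libxc_GGA_K_VW + PBE + hartree + nuclear_attraction
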
    @classmethod
    def get_vw_functional(cls, xc_functional: str):
        """Construct a functional with the von Weizsäcker kinetic energy functional.

        The von Weizsäcker kinetic energy functional is exact for two-electron molecules.
        The remaining erroneous contribution is the exchange part of the xc functional
        (there should be no correlation part for two-electron molecules).

        Args:
            xc_functional: The xc functional.
        """
        return cls("libxc_GGA_K_VW", xc_functional, "hartree", "nuclear_attraction")

    @classmethod
    def from_module(
        cls,
        module: torch.nn.Module,
        xc_functional: str | None = None,
        negative_integrated_density_penalty_weight: float = 0.0,
    ):
        """Construct a functional from a torch.nn.Module.

        Args:
            module: The torch.nn.Module.
            xc_functional: The xc functional. Can be omitted if the module learned the xc
                functional. Note that torch_PBE will be converted to the PBE X and PBE C
                functionals. Other torch functionals will be used as is (but can thus only
                be a single functional).
            negative_integrated_density_penalty_weight: Weight of the penalty on the
                negative integrated density. A penalty contribution is appended if the
                weight is greater than zero.

        Returns:
            The functional factory.
        """
        if not hasattr(module, "target_key"):
            raise ValueError("Module has no target key.")
        target_key = module.target_key
        if target_key == "kin":
            assert xc_functional is not None, f"xc_functional must be given for {target_key=}"
            contributions = [module, xc_functional, "hartree", "nuclear_attraction"]
        elif target_key == "kin_minus_apbe":
            assert xc_functional is not None, f"xc_functional must be given for {target_key=}"
            contributions = [
                module,
                "APBE",
                xc_functional,
                "hartree",
                "nuclear_attraction",
            ]
        elif target_key == "kin_plus_xc":
            contributions = [module, "hartree", "nuclear_attraction"]
        elif target_key == "tot":
            contributions = [module]
        else:
            raise ValueError(f"Invalid target key: {target_key}")
        if negative_integrated_density_penalty_weight > 0:
            contributions.append(
                f"{negative_integrated_density_penalty_weight}*integrated_negative_density"
            )
        return cls(*contributions)

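    # Illustrative sketch (the model variable is hypothetical; only the factory call and
    # the resulting contribution list follow from this module):
    #
    #     model = ...  # a torch.nn.Module exposing target_key, dtype and sample_forward()
    #     factory = FunctionalFactory.from_module(
    #         model, negative_integrated_density_penalty_weight=1000.0
    #     )
    #     print(factory)  # for target_key "kin_plus_xc":
    #     # functional = kin_plus_xc + hartree + nuclear_attraction + 1000.0*integrated_negative_density
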
    def evaluate_functional(
        self,
        sample: OFData,
        mol: gto.Mole,
        coulomb_matrix: torch.Tensor,
        nuclear_attraction_vector: torch.Tensor,
        grid: dft.Grids | None = None,
        grid_weights: torch.Tensor | None = None,
        ao: np.ndarray | None = None,
        max_xc_memory: int = 4000,
    ) -> tuple[Energies, torch.Tensor]:
        """Evaluate the energy functional for the given coefficients.

        The main use of this function is to be partially applied by :func:`construct`.

        Args:
            sample: The OFData object for a torch.nn.Module. The coeffs attribute is updated.
            mol: The molecule of the current optimization.
            coulomb_matrix: The Coulomb matrix.
            nuclear_attraction_vector: The nuclear attraction vector.
            grid: The grid for the evaluation of the xc functional.
            grid_weights: The weights of the grid points.
            ao: The atomic orbital values on the grid.
            max_xc_memory: The maximum memory in MB (in addition to the ao memory usage) for
                the evaluation of the xc functional.

        Returns:
            The energies and the gradient.
        """
        energies = Energies(mol)
        gradient = torch.zeros_like(sample.coeffs)

        # Run torch and libxc functionals all at once to not have to recompute the densities
        if len(self.torch_functionals) > 0:
            torch_functional_outs = torch_functionals.eval_torch_functionals_blocked_fast(
                ao,
                grid_weights,
                sample.coeffs,
                self.torch_functionals,
                max_memory=max_xc_memory,
            )
            for name, (energy, grad) in torch_functional_outs.items():
                gradient += grad
                energies[name] = energy.item()
            sample.coeffs = sample.coeffs.detach()
        if len(self.libxc_functionals) > 0:
            libxc_functionals_outs = libxc_functionals.eval_libxc_functionals(
                sample.coeffs.numpy(force=True),
                self.libxc_functionals,
                grid,
                ao.cpu().numpy(),
                self.max_derivative,
            )
            grad = libxc_functionals_outs[1]
            gradient += torch.as_tensor(grad, dtype=torch.float64, device=sample.coeffs.device)
            for name, energy in zip(self.libxc_functionals, libxc_functionals_outs[0]):
                energies[LIBXC_PREFIX + name] = energy

        for contribution in self.contributions:
            # Hartree energy
            if contribution == "hartree":
                energy, grad = hartree_functional(sample.coeffs, coulomb_matrix=coulomb_matrix)
            # Nuclear attraction energy
            elif contribution == "nuclear_attraction":
                energy, grad = nuclear_attraction_functional(
                    sample.coeffs, nuclear_attraction_vector=nuclear_attraction_vector
                )
            # Negative integrated density
            elif isinstance(contribution, str):
                if "integrated_negative_density" in contribution:
                    negative_integrated_density = NegativeIntegratedDensity(
                        grid_weights,
                        ao,
                        gamma=float(contribution[: -len("integrated_negative_density") - 1]),
                    )
                    energy, grad = negative_integrated_density(sample.detach())
                    energy = energy.item()
                    sample.coeffs = sample.coeffs.detach()
                else:
                    # Torch and libxc functionals are already evaluated
                    continue
            # torch.nn.Module
            elif isinstance(contribution, torch.nn.Module):
                # If we do a float32 model forward, copy the sample, then overwrite the new
                # coeffs at the end
                if contribution.dtype != sample.coeffs.dtype:
                    converted_sample = to_torch(
                        sample.detach().clone(),
                        float_dtype=contribution.dtype,
                        device=sample.coeffs.device,
                    )
                else:
                    converted_sample = sample
                converted_sample = contribution.sample_forward(converted_sample)
                energy = converted_sample.pred_energy.item()
                grad = converted_sample.pred_gradient.to(torch.float64)
                sample.coeffs = converted_sample.coeffs.detach().to(sample.coeffs.dtype)
            # other callable
            else:
                energy, grad = contribution(sample)
            if isinstance(contribution, str):
                energies[contribution] = energy
            else:
                energies[contribution.__name__] = energy
            gradient += grad
        assert not sample.coeffs.requires_grad, "Coeffs should not require grad"
        assert not gradient.requires_grad, "Gradient should not require grad"
        return energies, gradient

    def construct(
        self,
        mol: gto.Mole,
        coulomb_matrix: torch.Tensor,
        nuclear_attraction_vector: torch.Tensor,
        grid: dft.Grids | None = None,
        ao: torch.Tensor | None = None,
        max_xc_memory: int = 4000,
    ) -> Callable[[OFData], tuple[Energies, torch.Tensor]]:
        """Construct the energy functional.

        Args:
            mol: The molecule of the current optimization.
            coulomb_matrix: The Coulomb matrix.
            nuclear_attraction_vector: The nuclear attraction vector.
            grid: The grid for the evaluation of the xc functional.
            ao: The atomic orbital values on the grid.
            max_xc_memory: The maximum memory in MB (in addition to the ao memory usage) for
                the evaluation of the xc functional.

        Returns:
            The energy functional.
        """
        if grid is not None:
            grid_weights = torch.as_tensor(
                grid.weights, dtype=torch.float64, device=coulomb_matrix.device
            )
        else:
            grid_weights = None
        return partial(
            self.evaluate_functional,
            mol=mol,
            coulomb_matrix=coulomb_matrix,
            nuclear_attraction_vector=nuclear_attraction_vector,
            grid=grid,
            grid_weights=grid_weights,
            ao=ao,
            max_xc_memory=max_xc_memory,
        )

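    # Illustrative sketch (the molecule/grid setup and the plain gradient step are
    # assumptions; only the construct() call and the (Energies, gradient) return
    # contract come from this module):
    #
    #     functional = factory.construct(
    #         mol, coulomb_matrix, nuclear_attraction_vector, grid=grid, ao=ao
    #     )
    #     energies, gradient = functional(sample)            # sample is an OFData object
    #     sample.coeffs = sample.coeffs - 0.01 * gradient     # e.g. one plain gradient step
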
    def get_energies_label(self, sample: OFData, basis_info: BasisInfo) -> Energies:
        """Get the energies label from an OFData object.

        Args:
            sample: OFData object with the ground state energy labels.
            basis_info: The BasisInfo object used to build the molecule.

        Returns:
            Energies object initialized with the energies from the OFData object
            corresponding to the functional contributions.
        """
        mol = build_molecule_ofdata(sample, basis_info.basis_dict)
        energies = Energies(mol)
        for contribution in self.contributions:
            if contribution == "hartree":
                energies[contribution] = sample["of_labels/energies/ground_state_e_hartree"]
            elif contribution == "nuclear_attraction":
                energies[contribution] = sample["of_labels/energies/ground_state_e_ext"]
            elif isinstance(contribution, str) and (
                "integrated_negative_density" in contribution
            ):
                pass
            elif isinstance(contribution, str) and contribution.startswith(LIBXC_PREFIX):
                functional_name = contribution[len(LIBXC_PREFIX) :]
                libxc_functionals.check_kxc_implementation(functional_name)
                if "_K_" in contribution:
                    energies[contribution] = sample["of_labels/energies/ground_state_e_kin"]
                elif "_XC_" in contribution or functional_name in dft.libxc.XC_ALIAS:
                    energies[contribution] = sample["of_labels/energies/ground_state_e_xc"]
                else:
                    raise ValueError(f"No label for {contribution} available.")
            elif isinstance(contribution, str):
                if contribution == "APBE":
                    energies[contribution] = sample["of_labels/energies/ground_state_e_kinapbe"]
                elif contribution == "PBE":
                    energies[contribution] = sample["of_labels/energies/ground_state_e_xc"]
                else:
                    raise ValueError(f"No label for {contribution} available.")
            elif isinstance(contribution, torch.nn.Module):
                energies[contribution.target_key] = sample[
                    f"of_labels/energies/ground_state_e_{contribution.target_key}"
                ]
            elif isinstance(contribution, Callable):
                energies[contribution.__name__] = sample[
                    f"of_labels/energies/ground_state_e_{contribution.__name__}"
                ]
            else:
                raise ValueError(f"No label for {contribution} available.")
        for key, value in energies.energies_dict.items():
            if isinstance(value, torch.Tensor):
                energies.energies_dict[key] = value.item()
        return energies

    def __str__(self):
        """Return a string that contains the contribution names."""
        names = [c if isinstance(c, str) else c.__name__ for c in self.contributions]
        return "functional = " + " + ".join(names)