Source code for mldft.ofdft.density_optimization

from typing import Callable, Optional

import numpy as np
import torch
from loguru import logger
from pyscf import gto

from mldft.ml.data.components.basis_transforms import transform_tensor_with_sample
from mldft.ml.data.components.of_data import OFData, Representation
from mldft.ml.models.mldft_module import MLDFTLitModule
from mldft.ofdft.callbacks import ConvergenceCallback, ConvergenceCriterion
from mldft.ofdft.energies import Energies
from mldft.ofdft.functional_factory import FunctionalFactory
from mldft.ofdft.initial_guess import (
    dumb_guess,
    hueckel_guess,
    label_guess,
    minao_guess,
    perturbed_label_guess,
    proj_minao_guess,
)
from mldft.ofdft.optimizer import Optimizer
from mldft.utils.grids import grid_setup
from mldft.utils.pyscf_pretty_print import mole_to_sum_formula
from mldft.utils.sad_guesser import SADGuesser


def initial_guess(
    sample: OFData,
    mol: gto.Mole,
    initial_guess_str: str,
    normalize_initial_guess: bool = True,
    ks_basis: str = "6-31G(2df,p)",
    proj_minao_module: MLDFTLitModule | None = None,
    sad_guess_kwargs: dict | None = None,
) -> torch.Tensor:
    """Return the initial guess for the coefficients.

    Args:
        sample: The OFData object.
        mol: The molecule.
        initial_guess_str: Which initial guess to use.
        normalize_initial_guess: Whether to normalize the initial guess to the correct
            number of electrons.
        ks_basis: Basis set to use for the initial guess.
        proj_minao_module: Lightning Module used for the proj_minao initial guess.
            Only required if initial_guess_str is 'proj_minao'.
        sad_guess_kwargs: Dictionary of arguments for the SAD guesser. Only required
            if initial_guess_str is 'sad' or 'sad_transformed'.

    Returns:
        torch.Tensor: The initial guess for the coefficients.

    Raises:
        KeyError: If the requested initial guess is not implemented.
    """
    if initial_guess_str == "dumb":
        return dumb_guess(sample, mol)
    elif initial_guess_str == "label":
        return label_guess(sample)
    elif initial_guess_str.startswith("perturbed_label"):
        scale = float(initial_guess_str.split("_")[-1])
        return perturbed_label_guess(sample, scale)
    elif initial_guess_str == "hueckel":
        return hueckel_guess(sample, mol, ks_basis, normalize=normalize_initial_guess)
    elif initial_guess_str == "minao":
        return minao_guess(sample, mol, ks_basis, normalize=normalize_initial_guess)
    elif initial_guess_str == "proj_minao":
        assert proj_minao_module is not None, "A lightning Module must be provided for proj_minao."
        return proj_minao_guess(
            sample,
            mol,
            proj_minao_module,
            ks_basis,
            normalize=normalize_initial_guess,
        )
    elif initial_guess_str == "sad":
        guesser = SADGuesser.from_dataset_statistics(**sad_guess_kwargs)
        guesser = guesser.to(device=sample.coeffs.device)
        # This is a bit of a hack, but we need the untransformed basis integrals,
        # and this is one way to get them.
        untransformed_sample = sample.clone()
        untransformed_sample.dual_basis_integrals = transform_tensor_with_sample(
            untransformed_sample,
            untransformed_sample.dual_basis_integrals,
            Representation.DUAL_VECTOR,
            invert=True,
        )
        coeffs_guess = guesser(untransformed_sample)
        coeffs_guess = transform_tensor_with_sample(sample, coeffs_guess, Representation.VECTOR)
        return coeffs_guess
    elif initial_guess_str == "sad_transformed":
        # Compute the SAD guess in the (potentially transformed) basis of the sample.
        # This seems to work well, except for global natural reparametrization.
        guesser = SADGuesser.from_dataset_statistics(**sad_guess_kwargs)
        guesser = guesser.to(device=sample.coeffs.device)
        return guesser(sample)
    else:
        raise KeyError(f"Initial guess {initial_guess_str} not recognized.")
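# Illustration only (not part of the original module): a minimal sketch of how the
# dispatcher above is typically consumed. The guess is selected purely by string, and
# perturbed-label guesses encode their noise scale in the string itself (e.g.
# "perturbed_label_0.1"). The helper name below is hypothetical; `sample` and `mol`
# are assumed to come from the project's data pipeline.
def _example_initial_density(sample: OFData, mol: gto.Mole) -> torch.Tensor:
    # MINAO needs no auxiliary module or dataset statistics, making it the simplest
    # normalized starting point.
    coeffs = initial_guess(sample, mol, "minao", normalize_initial_guess=True)
    # Sanity check mirroring density_optimization below: the guess should integrate
    # to the molecule's electron count.
    assert np.isclose((sample.dual_basis_integrals @ coeffs).item(), mol.nelectron, atol=1e-2)
    return coeffs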
def compute_density_optimization_metrics(
    callback: ConvergenceCallback,
    ground_state_energy: Energies,
    convergence_criterion: str | ConvergenceCriterion,
) -> dict:
    """Compute metrics for the density optimization process.

    Args:
        callback: ConvergenceCallback object containing the optimization history.
        ground_state_energy: The ground state energies the errors are measured against.
        convergence_criterion: Criterion used to determine the stopping index of the
            optimization process.

    Returns:
        Dictionary of metrics. Keys prefixed with 'gs_' are evaluated at the stopping
        index; the remaining minima are taken over the whole optimization history.
    """
    metric_dict = {}
    total_energies = np.asarray([e.total_energy for e in callback.energy])
    converged_state = callback.get_convergence_result(convergence_criterion)
    stopping_index = converged_state.stopping_index
    signed_energy_error = total_energies[stopping_index] - ground_state_energy.total_energy
    metric_dict["gs_signed_energy_error"] = signed_energy_error
    metric_dict["gs_abs_energy_error"] = np.abs(signed_energy_error)
    energy_differences = total_energies - ground_state_energy.total_energy
    metric_dict["min_energy_error"] = energy_differences.min()
    metric_dict["gs_l2_norm"] = callback.l2_norm[stopping_index]
    metric_dict["minimum_l2_norm"] = min(callback.l2_norm)
    metric_dict["stopping_index"] = stopping_index
    metric_dict["gs_gradient_norm"] = callback.gradient_norm[stopping_index]
    metric_dict["min_gradient_norm"] = min(callback.gradient_norm)
    return metric_dict
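# Illustration only (not part of the original module): how the metric dictionary
# relates to the callback history. 'gs_*' entries are evaluated at the stopping index
# selected by the convergence criterion, while the minimum entries scan the full
# trajectory. The helper name is hypothetical and assumes a callback already
# populated by a finished optimization run.
def _example_report_metrics(callback: ConvergenceCallback, label: Energies) -> None:
    metrics = compute_density_optimization_metrics(callback, label, "last_iter")
    logger.info(
        f"Stopped at iteration {metrics['stopping_index']} with "
        f"|dE| = {metrics['gs_abs_energy_error']:.3e} Ha "
        f"(best along trajectory: {metrics['min_energy_error']:.3e} Ha)"
    )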
def density_optimization(
    sample: OFData,
    mol: gto.Mole,
    optimizer: Optimizer,
    func_factory: FunctionalFactory,
    callback: ConvergenceCallback | None = None,
    initialization: str | torch.Tensor = "minao",
    max_xc_memory: int = 4000,
    normalize_initial_guess: bool = True,
    ks_basis: str = "6-31G(2df,p)",
    proj_minao_module: Optional["MLDFTLitModule"] = None,
    sad_guess_kwargs: dict | None = None,
    disable_printing: bool = False,
    disable_pbar: bool | None = None,
) -> tuple[Energies, torch.Tensor, bool, Callable]:
    """Perform density optimization for a given sample and molecule.

    Args:
        sample: OFData sample object containing required tensors for the functional.
        mol: Molecule object used for the initial guess and building the grid.
        optimizer: Optimizer used for the density optimization process.
        func_factory: FunctionalFactory object used to construct the energy functional.
        callback: ConvergenceCallback object to store the optimization history.
        initialization: Which initial guess to use, or a tensor of initial
            coefficients to use directly.
        max_xc_memory: Maximum memory for the XC functional.
        normalize_initial_guess: Whether to normalize the initial guess to the correct
            number of electrons.
        ks_basis: Basis set to use for the initial guess.
        proj_minao_module: Lightning Module used for the proj_minao initial guess.
            Only required if initialization is 'proj_minao'.
        sad_guess_kwargs: Dictionary of arguments for the SAD guesser. Only required
            if initialization is 'sad' or 'sad_transformed'.
        disable_printing: Whether to disable printing (useful when running density
            optimization during training with the rich progress bar, as output
            becomes messy).
        disable_pbar: Whether to disable the progress bar. If None, it will use the
            value of disable_printing.

    Returns:
        Energies: Energies object containing the total energy and its components.
        torch.Tensor: Final coefficients in the AO basis.
        bool: Whether the optimization converged.
        Callable: The constructed energy functional.
    """
    if disable_pbar is None:
        disable_pbar = disable_printing

    # Prepare the AOs first, otherwise clones of the sample will be too large.
    if hasattr(sample, "grid_level"):
        grid = grid_setup(mol, sample.grid_level, sample.grid_prune)
    else:
        grid = None
    if hasattr(sample, "ao"):
        ao = sample.ao
        # On older GPUs, having the tensor in this format gives a 3x speed-up for
        # the xc functional.
        if ao.device.type == "cuda" and torch.cuda.get_device_capability()[0] < 8:
            ao = ao.transpose(1, 2).contiguous().transpose(1, 2)
            # This leaves large unused memory blocks, so we free them.
            torch.cuda.empty_cache()
        # Delete from sample to allow for sample cloning.
        sample.delete_item("ao")
    else:
        ao = None

    # Prepare initial guess and functional.
    if isinstance(initialization, torch.Tensor):
        initial_coeffs = initialization
    else:
        initial_coeffs = initial_guess(
            sample,
            mol,
            initialization,
            normalize_initial_guess,
            ks_basis,
            proj_minao_module,
            sad_guess_kwargs,
        )
    sample.coeffs = initial_coeffs
    nelectron_guess = sample.dual_basis_integrals @ initial_coeffs
    if not np.isclose(nelectron_guess.item(), mol.nelectron, atol=1e-2):
        logger.warning(
            f"Guess normalized to {nelectron_guess} electrons, but {mol.nelectron} expected."
        )
    energy_functional = func_factory.construct(
        mol,
        sample.coulomb_matrix,
        sample.nuclear_attraction_vector,
        grid,
        ao,
        max_xc_memory,
    )

    # Optimization
    final_energies, converged = optimizer.optimize(
        sample, energy_functional, callback, disable_pbar=disable_pbar
    )
    final_coeffs = transform_tensor_with_sample(
        sample, sample.coeffs, Representation.VECTOR, invert=True
    )
    if not disable_printing:
        logger.info(
            f"{'Summary':<11}{mole_to_sum_formula(mol, use_subscript=True):<16}"
            f"E={final_energies.total_energy} Ha, "
            f"converged: {converged}"
        )
    return final_energies, final_coeffs, converged, energy_functional
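# Illustration only (not part of the original module): a minimal end-to-end call of
# density_optimization. Note that the returned coefficients are in the untransformed
# AO basis (see the inverse transform_tensor_with_sample call above), while
# sample.coeffs stays in the sample's (possibly transformed) basis. The helper name
# is hypothetical.
def _example_density_optimization(
    sample: OFData,
    mol: gto.Mole,
    optimizer: Optimizer,
    func_factory: FunctionalFactory,
) -> torch.Tensor:
    energies, coeffs, converged, _ = density_optimization(
        sample, mol, optimizer, func_factory, initialization="minao"
    )
    if not converged:
        logger.warning(f"Did not converge, E = {energies.total_energy} Ha")
    return coeffs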
def density_optimization_with_label(
    sample: OFData,
    mol: gto.Mole,
    optimizer: Optimizer,
    func_factory: FunctionalFactory,
    callback: ConvergenceCallback | None,
    max_xc_memory: int = 4000,
    initial_guess_str: str = "minao",
    normalize_initial_guess: bool = True,
    ks_basis: str = "6-31G(2df,p)",
    proj_minao_module: Optional["MLDFTLitModule"] = None,
    sad_guess_kwargs: dict | None = None,
    convergence_criterion: str | ConvergenceCriterion = "last_iter",
    disable_printing: bool = False,
) -> tuple[dict, ConvergenceCallback, Energies, Callable]:
    """Perform density optimization for a given sample and molecule, and compute
    metrics against the stored ground-state label.

    Args:
        sample: OFData sample object containing required tensors for the functional.
        mol: Molecule object used for the initial guess and building the grid.
        optimizer: Optimizer used for the density optimization process.
        func_factory: FunctionalFactory object used to construct the energy functional.
        callback: ConvergenceCallback object to store the optimization history.
        max_xc_memory: Maximum memory for the XC functional.
        initial_guess_str: Which initial guess to use.
        normalize_initial_guess: Whether to normalize the initial guess to the correct
            number of electrons.
        ks_basis: Basis set to use for the initial guess.
        proj_minao_module: Lightning Module used for the proj_minao initial guess.
            Only required if initial_guess_str is 'proj_minao'.
        sad_guess_kwargs: Dictionary of arguments for the SAD guesser. Only required
            if initial_guess_str is 'sad' or 'sad_transformed'.
        convergence_criterion: Criterion to determine the ground state of the
            optimization process.
        disable_printing: Whether to disable printing; also disables the progress bar.

    Returns:
        metric_dict: Dictionary containing metrics of the density optimization.
        callback: ConvergenceCallback object containing the optimization history for
            plotting and further analysis.
        energies_label: The ground state label of the molecule.
        energy_functional: The constructed energy functional.
    """
    _, _, _, energy_functional = density_optimization(
        sample,
        mol,
        optimizer,
        func_factory,
        callback=callback,
        initialization=initial_guess_str,
        normalize_initial_guess=normalize_initial_guess,
        max_xc_memory=max_xc_memory,
        ks_basis=ks_basis,
        proj_minao_module=proj_minao_module,
        sad_guess_kwargs=sad_guess_kwargs,
        disable_printing=True,
        disable_pbar=disable_printing,
    )
    energies_label = func_factory.get_energies_label(sample, basis_info=sample.basis_info)

    # Compute metrics
    metric_dict = compute_density_optimization_metrics(
        callback, energies_label, convergence_criterion
    )
    if not disable_printing:
        logger.info(
            f"{'Summary':<11}{mole_to_sum_formula(mol, use_subscript=True):<16}"
            f"ΔE={metric_dict['gs_signed_energy_error'] * 1000:6.2g} mHa, "
            f"ρNorm={metric_dict['gs_l2_norm']:6.2g}, "
            f"gNorm={metric_dict['gs_gradient_norm']:7.2e}"
        )
    return metric_dict, callback, energies_label, energy_functional
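# Illustration only (not part of the original module): evaluating one sample against
# its stored ground-state label. A populated ConvergenceCallback is required because
# the metrics are computed from its history; its construction is left to the caller,
# since its signature is not shown here. The helper name is hypothetical.
def _example_evaluate_with_label(
    sample: OFData,
    mol: gto.Mole,
    optimizer: Optimizer,
    func_factory: FunctionalFactory,
    callback: ConvergenceCallback,
) -> dict:
    metrics, callback, energies_label, _ = density_optimization_with_label(
        sample,
        mol,
        optimizer,
        func_factory,
        callback,
        convergence_criterion="last_iter",
        disable_printing=True,
    )
    # Energy errors are signed, in Hartree, relative to the label.
    logger.info(f"dE = {metrics['gs_signed_energy_error'] * 1000:.2f} mHa")
    return metrics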