"""Callback classes for the minimization."""
import dataclasses
from pathlib import Path
from typing import List, Optional
import numpy as np
import torch
from pyscf import gto
from mldft.ml.data.components.of_data import StrEnumwithCheck
from mldft.ofdft.energies import Energies
from mldft.ofdft.ofstate import OFState, StoppingCriterion
@dataclasses.dataclass
class BasicCallback:
"""Callback class for basic information during the minimization.
    The basic callback stores the energy, the gradient norm, the learning rate, the
    coefficients and the L2 norm of the deviation from the ground-state coefficients at
    every step of the minimization. It can be used to inspect or plot the minimization
    progress.
    Attributes:
        mol (Optional[gto.Mole]): PySCF Mole object
        energy (List[Energies]): List of electronic energies
        gradient_norm (List[float]): List of gradient norms
        l2_norm (List[float]): List of L2 norms of the deviation from the
            ground-state coefficients
        learning_rate (List[float]): List of learning rates
        coeffs (List[torch.Tensor]): List of coefficients
        stopping_index (Optional[int]): Index of the step chosen as the converged
            result, if set
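
    Example:
        A minimal sketch of how a minimization loop might drive the callback; the
        loop and the variable names are hypothetical, and the dictionary must
        contain the local variables read in :meth:`__call__`::

            callback = BasicCallback(mol=mol)
            for step in range(max_steps):
                # ... one step producing energy, gradient_norm, learning_rate,
                # coeffs and sample ...
                callback(locals())
            callback.save_to_file("minimization.npz")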
"""
mol: Optional[gto.Mole] = None
energy: List[Energies] = dataclasses.field(default_factory=list)
gradient_norm: List[float] = dataclasses.field(default_factory=list)
l2_norm: List[float] = dataclasses.field(default_factory=list)
learning_rate: List[float] = dataclasses.field(default_factory=list)
coeffs: List[torch.Tensor] = dataclasses.field(default_factory=list)
stopping_index: Optional[int] = None
def __call__(self, minimization_locals: dict) -> None:
"""Callback function for the minimization.
Args:
            minimization_locals (dict): Dictionary containing the local variables of the
                minimization function. Must provide the keys ``"energy"``,
                ``"gradient_norm"``, ``"learning_rate"``, ``"coeffs"`` and ``"sample"``.
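
        Example:
            For example (the values on the right are hypothetical placeholders for
            the quantities produced by one minimization step)::

                callback(
                    {
                        "energy": energy,
                        "gradient_norm": gradient_norm,
                        "learning_rate": learning_rate,
                        "coeffs": coeffs,
                        "sample": sample,
                    }
                )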
"""
self.energy.append(minimization_locals["energy"])
self.gradient_norm.append(minimization_locals["gradient_norm"])
self.learning_rate.append(minimization_locals["learning_rate"])
coeffs = minimization_locals["coeffs"]
sample = minimization_locals["sample"]
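        # distance of the current coefficients from the ground-state coefficients,
        # measured in the L2 norm induced by the overlap matrix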
delta_coeffs = coeffs.detach() - sample.ground_state_coeffs
l2_norm = torch.sqrt(delta_coeffs @ sample.overlap_matrix @ delta_coeffs).item()
self.l2_norm.append(l2_norm)
self.coeffs.append(coeffs.detach().cpu().clone())
def convert_to_numpy(self) -> dict:
"""Convert the data to a numpy array.
Returns:
dict: Dictionary containing the data as numpy arrays.
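
        Example:
            A usage sketch (``callback`` is a hypothetical, already populated instance)::

                arrays = callback.convert_to_numpy()
                arrays["gradient_norm"].shape  # (n_steps,)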
"""
data = dict()
for field in dataclasses.fields(self):
data[field.name] = np.asarray(getattr(self, field.name))
return data
def save_to_file(self, filename: str | Path) -> None:
"""Save the data to a npz file.
Args:
filename (str | Path): Filename.
"""
data = self.convert_to_numpy()
np.savez(filename, **data)
@classmethod
def from_npz(cls, filename: str | Path, mol: gto.Mole | None = None) -> "BasicCallback":
"""Load the data from a npz file.
Args:
filename (str | Path): Filename.
mol (Optional[gto.Mole]): PySCF Mole object.
Returns:
BasicCallback: Callback object. If mol is not None, the mol attribute is
set to mol.
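
        Example:
            A round-trip sketch; ``"minimization.npz"`` is a hypothetical file written
            by :meth:`save_to_file`::

                callback.save_to_file("minimization.npz")
                restored = BasicCallback.from_npz("minimization.npz", mol=callback.mol)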
"""
callback = cls(mol=mol)
data = np.load(filename, allow_pickle=True)
        for field in dataclasses.fields(callback):
            if field.name == "mol" and mol is not None:
                # keep the Mole object passed in (as documented); the stored value is
                # only a pickled object array
                continue
            setattr(callback, field.name, data[field.name])
return callback
class ConvergenceCriterion(StrEnumwithCheck):
    """Criteria for selecting the final "converged" result of a minimization."""

LAST_ITER = "last_iter"
MIN_E_STEP = "min_e_step"
MOFDFT = "mofdft"
class ConvergenceCallback(BasicCallback):
"""Callback that extends BasisCallback by the stopping criterion."""
def get_convergence_result(
self,
convergence_criterion: str | ConvergenceCriterion = "last_iter",
) -> OFState:
"""Return the final "converged" result.
        If the criterion is ``last_iter``, the last step is used.
        If the criterion is ``min_e_step``, the final result is the step with the
        smallest single-step energy update. If the criterion is ``mofdft``, the final
        result is chosen hierarchically:

        #. The step where the projected gradient norm first stops decreasing, if
           existent. Unless there are constant values, this is the first local minimum
           of the projected gradient norm.
        #. The step where the single-step energy update first stops decreasing, if
           existent. Unless there are constant values, this is the first local minimum
           of the single-step energy update.
        #. The step with the minimal energy update (always exists).

        The energy update is to be understood as the absolute value of the energy
        difference between consecutive steps.
Args:
convergence_criterion: The criterion to choose the final result.
Returns:
            OFState of the chosen final result. It has the additional attributes
            ``stopping_index`` and ``stopping_criterion``, which contain the index of the
            chosen result and the :py:class:`~mldft.ofdft.ofstate.StoppingCriterion` that
            was used.
Raises:
            ValueError: If the lengths of energy, gradient_norm and coeffs do not match,
                or if no data is available.
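
        Example:
            A sketch of selecting the final state after a finished minimization
            (``callback`` is a hypothetical, already populated ``ConvergenceCallback``)::

                state = callback.get_convergence_result(ConvergenceCriterion.MOFDFT)
                print(state.stopping_index, state.stopping_criterion)
                final_coeffs = state.coeffs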
"""
        assert ConvergenceCriterion.check_key(
            convergence_criterion
        ), f"Invalid convergence criterion: {convergence_criterion}"
if convergence_criterion == "last_iter":
state = OFState(
mol=self.mol,
coeffs=self.coeffs[-1],
energy=self.energy[-1],
)
stopping_index = len(self.energy) - 1
state.stopping_index = stopping_index
state.stopping_criterion = StoppingCriterion.LAST_ITERATION
self.stopping_index = stopping_index
return state
if not len(self.energy) == len(self.gradient_norm) == len(self.coeffs):
raise ValueError("Length of energy, gradient_norm and coeffs must be the same.")
if len(self.energy) == 0:
raise ValueError("No data available.")
elif len(self.energy) == 1:
state = OFState(
mol=self.mol,
coeffs=self.coeffs[0],
energy=self.energy[0],
)
            state.stopping_index = 0
            state.stopping_criterion = StoppingCriterion.ENERGY_UPDATE_GLOBAL_MINIMUM
            self.stopping_index = 0
return state
energies = np.asarray([e.total_energy for e in self.energy])
energy_diff = np.diff(energies)
energy_updates = np.abs(energy_diff)
        # as a baseline, choose the step with the smallest absolute energy update;
        # since diff[i] = energy[i+1] - energy[i], index i refers to the step before
        # the smallest update and is therefore closer to the actual minimum
stopping_index = np.argmin(energy_updates)
stopping_criterion = StoppingCriterion.ENERGY_UPDATE_GLOBAL_MINIMUM
if convergence_criterion == ConvergenceCriterion.MOFDFT:
gradient_norms = np.asarray(self.gradient_norm)
gradient_norm_diff = np.diff(gradient_norms)
gradient_decreasing = gradient_norm_diff < 0
energy_update_decreasing = np.diff(energy_updates) < 0
if not np.all(gradient_decreasing):
# choose index of the first local minimum of the gradient norm,
# i.e. the first index where gradient_decreasing is False
stopping_index = np.argmin(gradient_decreasing)
stopping_criterion = StoppingCriterion.GRADIENT_STOPS_DECREASING
elif not np.all(energy_update_decreasing):
# choose index of the first local minimum of the energy difference,
# i.e. the first index where energy_update_decreasing is False
stopping_index = np.argmin(energy_update_decreasing)
stopping_criterion = StoppingCriterion.ENERGY_UPDATE_STOPS_DECREASING
state = OFState(
mol=self.mol,
coeffs=self.coeffs[stopping_index],
energy=self.energy[stopping_index],
)
state.stopping_index = stopping_index
self.stopping_index = stopping_index
state.stopping_criterion = stopping_criterion
return state