"""Callback classes for the minimization."""
import dataclasses
from pathlib import Path
from typing import List, Optional
import numpy as np
import torch
from pyscf import gto
from mldft.ml.data.components.of_data import StrEnumwithCheck
from mldft.ofdft.energies import Energies
from mldft.ofdft.ofstate import OFState, StoppingCriterion
[docs]
@dataclasses.dataclass
class BasicCallback:
    """Callback class for basic information during the minimization.

    The basic callback is used to store the energy, gradient norm, distance to the
    sample's ground state, learning rate and the coefficients during the minimization.
    It can be used to plot the minimization progress.

    Attributes:
        mol (Optional[gto.Mole]): PySCF Mole object.
        energy (List[Energies]): List of electronic energies, one entry per step.
        gradient_norm (List[float]): List of gradient norms, one entry per step.
        l2_norm (List[float]): List of distances ``sqrt(d @ S @ d)`` per step, where
            ``d`` is the difference between the current coefficients and
            ``sample.ground_state_coeffs`` and ``S`` is ``sample.overlap_matrix``.
        learning_rate (List[float]): List of learning rates, one entry per step.
        coeffs (List[torch.Tensor]): List of detached CPU copies of the coefficients.
        stopping_index (Optional[int]): Index of the step selected as the final
            result; ``None`` until it is set.
    """

    mol: Optional[gto.Mole] = None
    energy: List[Energies] = dataclasses.field(default_factory=list)
    gradient_norm: List[float] = dataclasses.field(default_factory=list)
    l2_norm: List[float] = dataclasses.field(default_factory=list)
    learning_rate: List[float] = dataclasses.field(default_factory=list)
    coeffs: List[torch.Tensor] = dataclasses.field(default_factory=list)
    stopping_index: Optional[int] = None

    def __call__(self, minimization_locals: dict) -> None:
        """Record one minimization step.

        Args:
            minimization_locals (dict): Dictionary containing the local variables of the
                minimization function. Must provide the keys ``"energy"``,
                ``"gradient_norm"``, ``"learning_rate"``, ``"coeffs"`` and
                ``"sample"``; the sample must expose ``ground_state_coeffs`` and
                ``overlap_matrix``.

        Raises:
            KeyError: If a required key is missing. Nothing is recorded in that case,
                so all per-step lists keep the same length.
        """
        # Read and compute everything before appending anything, so that a missing key
        # or a failing computation cannot leave the per-step lists with different lengths.
        energy = minimization_locals["energy"]
        gradient_norm = minimization_locals["gradient_norm"]
        learning_rate = minimization_locals["learning_rate"]
        coeffs = minimization_locals["coeffs"].detach()
        sample = minimization_locals["sample"]
        delta_coeffs = coeffs - sample.ground_state_coeffs
        l2_norm = torch.sqrt(delta_coeffs @ sample.overlap_matrix @ delta_coeffs).item()

        self.energy.append(energy)
        self.gradient_norm.append(gradient_norm)
        self.learning_rate.append(learning_rate)
        self.l2_norm.append(l2_norm)
        # Store a CPU copy so later in-place updates of the optimizer do not alter history.
        self.coeffs.append(coeffs.cpu().clone())

    def convert_to_numpy(self) -> dict:
        """Convert the data to numpy arrays.

        Returns:
            dict: Mapping from each dataclass field name to ``np.asarray`` of its value.
        """
        return {
            field.name: np.asarray(getattr(self, field.name))
            for field in dataclasses.fields(self)
        }

    def save_to_file(self, filename: str | Path) -> None:
        """Save the data to a npz file.

        Args:
            filename (str | Path): Filename.
        """
        data = self.convert_to_numpy()
        np.savez(filename, **data)

    @classmethod
    def from_npz(cls, filename: str | Path, mol: gto.Mole | None = None) -> "BasicCallback":
        """Load the data from a npz file written by :meth:`save_to_file`.

        Args:
            filename (str | Path): Filename.
            mol (Optional[gto.Mole]): PySCF Mole object. If not None, it is used as the
                ``mol`` attribute instead of the value stored in the file.

        Returns:
            BasicCallback: Callback object. Per-step fields are numpy arrays; scalar
                fields saved as 0-d arrays (e.g. ``mol``, ``stopping_index``) are
                unwrapped to plain Python objects. Fields missing from the file keep
                their default values.
        """
        callback = cls(mol=mol)
        # NOTE: allow_pickle=True is required for object arrays (e.g. Energies, mol), but
        # unpickling can execute arbitrary code -- only load files from trusted sources.
        with np.load(filename, allow_pickle=True) as data:
            for field in dataclasses.fields(callback):
                if field.name == "mol" and mol is not None:
                    # The explicitly passed mol takes precedence over the stored one.
                    continue
                if field.name not in data.files:
                    # E.g. files written before the field existed: keep the default.
                    continue
                value = data[field.name]
                if value.ndim == 0:
                    # Scalars were saved as 0-d arrays; restore the original object.
                    value = value.item()
                setattr(callback, field.name, value)
        return callback
 
[docs]
class ConvergenceCriterion(StrEnumwithCheck):
    """Criteria for choosing the final step in ``ConvergenceCallback.get_convergence_result``."""

    # Use the last recorded step.
    LAST_ITER = "last_iter"
    # Use the step right before the smallest absolute single-step energy update.
    MIN_E_STEP = "min_e_step"
    # Hierarchical choice: first local minimum of the gradient norm, else first local
    # minimum of the energy update, else the smallest energy update.
    MOFDFT = "mofdft"
[docs]
class ConvergenceCallback(BasicCallback):
    """Callback that extends BasicCallback by the stopping criterion.

    After the minimization, :meth:`get_convergence_result` selects one of the recorded
    steps as the final result and stores its index in ``stopping_index``.
    """

    def get_convergence_result(
        self,
        convergence_criterion: str | ConvergenceCriterion = "last_iter",
    ) -> OFState:
        """Return the final "converged" result.

        If the criterion is ``"last_iter"``, the last step is used.
        If the criterion is ``"min_e_step"``, the final result is the step right before
        the smallest energy update. If the criterion is ``"mofdft"``, the final result
        is chosen hierarchically:

            #. The step where the gradient norm first stops decreasing, if existent.
               Unless there are constant values, this is the first local minimum of the
               gradient norm.
            #. The first step where the single-step energy update first stops
               decreasing, if existent. Unless there are constant values, this is the
               first local minimum of the single-step energy update.
            #. The minimal energy update (always exists).

        The energy update is to be understood as the absolute value of the difference.
        If only a single step was recorded (and the criterion is not ``"last_iter"``),
        that step is returned.

        Args:
            convergence_criterion: The criterion to choose the final result.

        Returns:
            OFState of the chosen final result. It has the new attribute stopping_index and
            stopping_criterion, which contain the index of the chosen result and the
            :py:class:`~mldft.ofdft.ofstate.StoppingCriterion` that was used. The index
            is also stored in ``self.stopping_index``.

        Raises:
            ValueError: If the criterion is unknown, if no data is available, or (for
                criteria other than ``"last_iter"``) if the lengths of energy,
                gradient_norm and coeffs do not match.
        """
        ConvergenceCriterion.check_key(convergence_criterion)
        # check_key may only report validity rather than raise; validate explicitly so an
        # unknown criterion cannot silently fall through to the "min_e_step" logic.
        known_criteria = (
            ConvergenceCriterion.LAST_ITER,
            ConvergenceCriterion.MIN_E_STEP,
            ConvergenceCriterion.MOFDFT,
        )
        if convergence_criterion not in known_criteria:
            raise ValueError(f"Invalid convergence criterion: {convergence_criterion}")

        if convergence_criterion == ConvergenceCriterion.LAST_ITER:
            if len(self.energy) == 0 or len(self.coeffs) == 0:
                raise ValueError("No data available.")
            return self._select_step(len(self.energy) - 1, StoppingCriterion.LAST_ITERATION)

        if not len(self.energy) == len(self.gradient_norm) == len(self.coeffs):
            raise ValueError("Length of energy, gradient_norm and coeffs must be the same.")
        if len(self.energy) == 0:
            raise ValueError("No data available.")
        if len(self.energy) == 1:
            # No energy update exists; the single recorded step is the only candidate.
            return self._select_step(0, StoppingCriterion.ENERGY_UPDATE_GLOBAL_MINIMUM)

        energies = np.asarray([e.total_energy for e in self.energy])
        energy_updates = np.abs(np.diff(energies))
        # As a baseline, choose the smallest energy difference. Since
        # diff[i] = energy[i+1] - energy[i], the index is before the smallest energy
        # difference, so we are closer to the actual minimum.
        stopping_index = np.argmin(energy_updates)
        stopping_criterion = StoppingCriterion.ENERGY_UPDATE_GLOBAL_MINIMUM

        if convergence_criterion == ConvergenceCriterion.MOFDFT:
            gradient_decreasing = np.diff(np.asarray(self.gradient_norm)) < 0
            energy_update_decreasing = np.diff(energy_updates) < 0
            if not np.all(gradient_decreasing):
                # First index where gradient_decreasing is False (argmin of a bool array),
                # i.e. the first local minimum of the gradient norm.
                stopping_index = np.argmin(gradient_decreasing)
                stopping_criterion = StoppingCriterion.GRADIENT_STOPS_DECREASING
            elif not np.all(energy_update_decreasing):
                # First index where energy_update_decreasing is False, i.e. the first
                # local minimum of the energy update.
                stopping_index = np.argmin(energy_update_decreasing)
                stopping_criterion = StoppingCriterion.ENERGY_UPDATE_STOPS_DECREASING

        return self._select_step(stopping_index, stopping_criterion)

    def _select_step(self, index: int, criterion: StoppingCriterion) -> OFState:
        """Build the OFState of step ``index`` and record the choice on ``self``."""
        # Normalize numpy integers (from argmin) to a plain int for a consistent type.
        index = int(index)
        state = OFState(
            mol=self.mol,
            coeffs=self.coeffs[index],
            energy=self.energy[index],
        )
        state.stopping_index = index
        state.stopping_criterion = criterion
        self.stopping_index = index
        return state