r"""Wraps the KSDFT calculation and saves results to a .chk file.
This module wraps the Kohn-Sham DFT (KSDFT) calculation as implemented in the pyscf library.
A restricted Kohn-Sham (RKS) object is built, and the molecule and the parameters of the calculation are set on it.
To obtain the intermediate results after each iteration of the self-consistent field (SCF)
method, which are needed later as additional data points for training the model, a callback function is
used. To record the DIIS coefficients, the extrapolate method of pyscf's CDIIS class is additionally replaced
by a patched version. All values are stored in a .chk file while the calculation runs.

The file format is as follows. The "Results" group stores the final results together with the initialization
parameters of the calculation:

 - **"converged"** : bool , if the scf-method converged on an energy.
 - **"total_energy"** : float , the total energy of the molecule after the last iteration.
 - **"occupation_numbers_orbitals"** : np.ndarray , the occupation numbers of each orbital after the last iteration.
 - **"molecular_coeffs_orbitals"** : np.ndarray , the coefficients of the orbitals after the last iteration.
 - **"max_cycle"** : int , the maximal number of iterations the calculation was initialized with.
 - **"name_xc_functional"** : string , the specified exchange correlation functional.
 - **"init_guess"** : string , the specified initial guess method.
 - **"convergence_tolerance"** : float , the specified convergence tolerance.
 - **"diis_start_cycle"** : int , the first cycle when diis is being used.
 - **"diis_space"** : int , the maximal number of Fock matrices to average in the diis scheme.
 - **"diis_method"** : string , the specified diis method.
 - **"grid_level"** : int , A parameter specifying the density of the grid used in the xc functional.
 - **"prune_method"** : string , A parameter specifying the pruning scheme of the grid used in the xc functional.

Additionally, the "Initialization" group stores the initial density matrix and the total energy from before
the first iteration, which are needed for the kinetic energy gradient:

 - **"first_density_matrix"** : np.ndarray , the density matrix before the first iteration.
 - **"first_total_energy"** : float , the total energy before the first iteration.

For each cycle, the intermediate results are saved into a "KS-iteration/{cycle}" group:

 - **"diis_coefficients"** : np.ndarray , the diis-coefficients used to construct the current Fock matrix.
 - **"occupation_numbers_orbitals"** : np.ndarray , the occupation numbers of the orbitals.
 - **"molecular_coeffs_orbitals"** : np.ndarray , the coefficients of the orbitals in the specified basis.
 - **"total_energy"** : float , the total energy of the molecule in this iteration.
 - **"coulomb_energy"** : float , the total coulomb energy after this iteration (equal to the hartree energy).
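 - **"exchange_correlation_energy"** : float , the energy calculated from the exchange correlation functional.
 - **"perturbation_coeffs"** : np.ndarray , optional, the density-basis coefficients of the random Fock matrix
   perturbation. This key is only stored if the DIIS object provides non-None perturbation coefficients, which
   is expected only when the Fock matrix perturbation is enabled.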
"""
import os
from pathlib import Path
from typing import Callable
import numpy as np
import scipy
from omegaconf import DictConfig
from pyscf import dft, gto, scf
from pyscf.lib import logger, misc
from pyscf.lib.diis import BLOCK_SIZE
from mldft.ofdft.basis_integrals import get_overlap_tensor
from mldft.ofdft.libxc_functionals import prune_to_string, string_to_prune
from mldft.utils.molecules import construct_aux_mol
[docs]
def perturb_fock_matrix(
    fock_matrix: np.ndarray,
    mf: dft.rks.RKS,
    cycle: int,
    start_std: float,
    end_std: float,
    start_cycle: int,
    end_cycle: int,
    of_basis_set: str,
) -> tuple[np.ndarray, np.ndarray]:
    """Perturb the Fock matrix by adding a random perturbation which is sampled in the density
    basis.

    Coefficients in the density basis are drawn from a zero-mean normal distribution. Its standard
    deviation is interpolated linearly from ``start_std`` (at ``start_cycle``) to ``end_std`` (at
    ``end_cycle``); cycles outside that window use the nearest end value. The coefficients are
    contracted with the overlap tensor between the density and orbital basis to obtain a matrix
    that is added to the Fock matrix.

    Args:
        fock_matrix: The Fock matrix to perturb.
        mf: The pyscf mf object.
        cycle: The current cycle of the SCF iteration.
        start_std: The standard deviation of the perturbation at the first cycle.
        end_std: The standard deviation of the perturbation at the last cycle.
        start_cycle: The cycle at which the perturbation starts.
        end_cycle: The cycle at which the perturbation ends.
        of_basis_set: String of the density basis set.

    Returns:
        Tuple of the perturbed Fock matrix (same shape as ``fock_matrix``) and the sampled
        perturbation coefficients in the density basis.
    """
    # A fresh generator seeded from OS entropy gives independent draws per call (and per process)
    # without reseeding, and thereby clobbering, numpy's global random state.
    rng = np.random.default_rng()
    mol_orbital = mf.mol
    mol_density = construct_aux_mol(mol_orbital, aux_basis_name=of_basis_set)

    if end_cycle == start_cycle:
        # Degenerate schedule: avoid dividing by zero.
        pert_std = end_std
    else:
        # Fraction of the schedule still ahead: 1 at start_cycle, 0 at end_cycle. Clamped so that
        # cycles outside the window neither extrapolate nor produce a negative standard deviation.
        remaining = (end_cycle - cycle) / (end_cycle - start_cycle)
        remaining = min(max(remaining, 0.0), 1.0)
        pert_std = end_std + (start_std - end_std) * remaining

    # The einsum below fixes this as a rank-3 tensor whose first axis runs over the density basis;
    # presumably (n_density, n_orbital, n_orbital) -- TODO confirm against get_overlap_tensor.
    overlap_tensor = get_overlap_tensor(mol_density, mol_orbital)
    perturbation_coeffs = rng.normal(0.0, pert_std, overlap_tensor.shape[0])
    perturbation = np.einsum("ijk,i->jk", overlap_tensor, perturbation_coeffs, optimize=True)
    perturbed_fock_matrix = fock_matrix + perturbation.reshape(fock_matrix.shape)
    return perturbed_fock_matrix, perturbation_coeffs
[docs]
def _save_scf_iteration_callback(envs: dict) -> None:
    """Write the intermediate results of one SCF iteration to the chkfile.

    Meant to be used as the pyscf SCF ``callback``, where ``envs`` holds the local variables of
    the SCF kernel. The current cycle and the mean-field object are also attached to the DIIS
    object (as ``cycle`` and ``mf_aux``) so that the patched DIIS extrapolation can read them.

    In cycle 0, the density matrix and total energy from before the first iteration are written
    to the ``"Initialization"`` group as well. The per-iteration data goes to
    ``"KS-iteration/{cycle}"``. ``"perturbation_coeffs"`` is only included when the DIIS object
    has a non-None attribute of that name.

    Args:
        envs: Dictionary with the local environment variables of the iteration.

    Returns:
        None
    """
    cycle = envs["cycle"]
    mf = envs["mf"]
    diis = envs["mf_diis"]

    # State consumed by the patched DIIS extrapolation (e.g. for the Fock matrix perturbation).
    diis.cycle = cycle
    diis.mf_aux = mf

    if cycle == 0:
        # Pre-iteration values, required later for the kinetic energy gradient.
        scf.chkfile.save(
            mf.chkfile,
            "Initialization",
            {
                "first_density_matrix": envs["dm_last"],
                "first_total_energy": envs["last_hf_e"],
            },
        )

    iteration_info = dict(
        # only populated if DIIS is active from the very first cycle
        diis_coefficients=diis.diis_coeffs,
        occupation_numbers_orbitals=envs["mo_occ"],
        molecular_coeffs_orbitals=envs["mo_coeff"],
        total_energy=envs["e_tot"],
        coulomb_energy=envs["vhf"].ecoul,
        exchange_correlation_energy=envs["vhf"].exc,
    )

    perturbation_coeffs = getattr(diis, "perturbation_coeffs", None)
    if perturbation_coeffs is not None:
        iteration_info["perturbation_coeffs"] = perturbation_coeffs

    scf.chkfile.save(mf.chkfile, f"KS-iteration/{cycle:d}", iteration_info)
[docs]
class ConvergenceError(Exception):
    """Signals that the SCF iteration ended without reaching convergence."""
[docs]
def _resolve_prune_method(
    prune_method: str | Callable | None,
) -> tuple[Callable | None, str]:
    """Map ``prune_method`` to the pruning callable for pyscf and the name stored in the chkfile.

    Args:
        prune_method: A key of ``string_to_prune``, a pruning callable, or None.

    Returns:
        The pruning callable (or None) and its name. Values not found in ``prune_to_string``
        are recorded as ``"unknown_function"``.

    Raises:
        NotImplementedError: If ``prune_method`` is a string that is not a key of ``string_to_prune``.
    """
    if isinstance(prune_method, str):
        if prune_method not in string_to_prune:
            raise NotImplementedError("This pruning method is currently not implemented")
        prune_method = string_to_prune[prune_method]
    if prune_method in prune_to_string:
        prune_string = prune_to_string[prune_method]
    else:
        prune_string = "unknown_function"
    return prune_method, prune_string


def _patched_diis_class(base: type, extrapolate_fn: Callable) -> type:
    """Create a subclass of the DIIS class ``base`` whose ``extrapolate`` delegates to ``extrapolate_fn``.

    Subclassing, instead of assigning to ``base.extrapolate``, keeps the patch local to the
    mean-field object that receives the subclass. pyscf's shared DIIS class is therefore not
    modified for other calculations in the same process.
    """

    class _PatchedDIIS(base):
        def extrapolate(self, nd=None):
            return extrapolate_fn(self, nd)

    return _PatchedDIIS


def ksdft(
    mol: gto.Mole,
    savefile: Path,
    xc_functional: str = r"PBE",
    init_guess: str = "minao",
    max_cycle: int = 50,
    diis_space: int = 8,
    diis_method: str = "CDIIS",
    grid_level: int = 3,
    prune_method: str | Callable | None = dft.nwchem_prune,
    density_fit_basis: str = "def2-universal-jfit",
    density_fit_threshold: int = 30,
    convergence_tolerance: float = 1e-9,
    extra_callback: Callable | None = None,
    use_perturbation: bool = False,
    perturbation_cfg: DictConfig | None = None,
) -> None:
    """Calculates the non-relativistic restricted spin calculation as in [M-OFDFT]_. XC functional
    should be PBE. Basis set should be 6-31G(2df,p). DIIS is enabled by default. MINAO
    initialization is active by default.

    Intermediate results of every iteration are written to ``savefile`` by
    ``_save_scf_iteration_callback``. After convergence, the final results and the calculation
    parameters are written to the ``"Results"`` group.

    Args:
        mol: Molecule object
        savefile: Path to the savefile
        xc_functional: The xc functional. Default is PBE.
        init_guess: the initial guess method (default is MINAO)
        max_cycle: the maximal number of iterations
        diis_space: the maximal number of Fock matrices used in the DIIS method
        diis_method: Only CDIIS (the default) is currently supported.
        grid_level: The grid level to use for the integration grid used in the xc-functional.
        prune_method: The method to prune the integration grid. If None, no pruning is performed.
        density_fit_basis: The basis set to use for the density fitting.
        density_fit_threshold: Density fitting is used if the molecule has at least this many atoms.
        convergence_tolerance: The convergence tolerance after which the SCF iteration stops. An alternative value can
            be 1meV 0.0000367493, see Appendix C.2 in [M-OFDFT]_.
        extra_callback: Additional callback function to be called after the original callback is called each iteration.
        use_perturbation: If True, the Fock matrix is perturbed each iteration.
        perturbation_cfg: Configuration for the perturbation of the Fock matrix. Required if
            ``use_perturbation`` is True.

    Returns:
        None

    Raises:
        ConvergenceError: If the calculation did not converge.
        NotImplementedError: If the pruning method or the DIIS method is not supported.
        ValueError: If ``use_perturbation`` is True but no ``perturbation_cfg`` is given.
        RuntimeError: If damping or level shifting is active, which the DIIS bookkeeping does not support.
    """
    # our diis averaging only works if DIIS is enabled from the start (which is the pyscf default)
    diis_start_cycle = 0
    mf = dft.RKS(mol, xc=xc_functional)
    if mol.natm >= density_fit_threshold:
        # density_fit() returns a new, density-fitted mean-field object and leaves `mf` unchanged,
        # so the result must be kept.
        mf = mf.density_fit(density_fit_basis)
    mf.chkfile = savefile.as_posix()

    prune_method, prune_string = _resolve_prune_method(prune_method)

    if diis_method != "CDIIS":
        raise NotImplementedError(f"DIIS method {diis_method} not supported.")
    if use_perturbation:
        if perturbation_cfg is None:
            raise ValueError("perturbation_cfg must be provided when use_perturbation is True.")

        def perturbation_function(fock_matrix, mf, cycle):
            return perturb_fock_matrix(
                fock_matrix,
                mf,
                cycle,
                perturbation_cfg.start_std,
                perturbation_cfg.end_std,
                perturbation_cfg.start_cycle,
                perturbation_cfg.end_cycle,
                perturbation_cfg.of_basis_set,
            )

        def extrapolate(diis, n_d=None):
            return patched_extrapolate(
                diis,
                n_d,
                use_perturbation,
                perturbation_function,
                perturbation_cfg.start_cycle,
                perturbation_cfg.end_cycle,
            )

    else:

        def extrapolate(diis, n_d=None):
            return patched_extrapolate(diis, n_d)

    # pyscf instantiates `mf.DIIS` inside its kernel; an instance-level subclass keeps the patch local.
    mf.DIIS = _patched_diis_class(mf.DIIS, extrapolate)

    mf.grids.level = grid_level
    mf.grids.prune = prune_method

    # The recorded DIIS coefficients only describe the Fock matrix if neither damping nor level
    # shifting modifies it, so check this before spending time on the SCF.
    if mf.damp != 0:
        raise RuntimeError(f"Damping must be disabled, got damp={mf.damp}.")
    if mf.level_shift != 0:
        raise RuntimeError(f"Level shift must be disabled, got level_shift={mf.level_shift}.")

    if extra_callback is None:
        callback = _save_scf_iteration_callback
    else:

        def callback(*args, **kwargs):
            _save_scf_iteration_callback(*args, **kwargs)
            extra_callback(*args, **kwargs)

    mf.run(
        callback=callback,
        init_guess=init_guess,
        max_cycle=max_cycle,
        conv_tol=convergence_tolerance,
        diis_start_cycle=diis_start_cycle,
        diis_space=diis_space,
    )
    if not mf.converged:
        raise ConvergenceError("The calculation did not converge.")
    res = {
        "converged": mf.converged,
        "total_energy": mf.e_tot,
        "occupation_numbers_orbitals": mf.mo_occ,
        "molecular_coeffs_orbitals": mf.mo_coeff,
        "max_cycle": mf.max_cycle,
        "name_xc_functional": xc_functional,
        "init_guess": init_guess,
        "convergence_tolerance": convergence_tolerance,
        "diis_start_cycle": diis_start_cycle,
        "diis_space": diis_space,
        "diis_method": diis_method,
        "grid_level": mf.grids.level,
        "prune_method": prune_string,
    }
    scf.chkfile.save(mf.chkfile, "Results", res)