from pathlib import Path
from typing import List
import hydra
import torch
from lightning import Callback
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig, open_dict
from mldft.ml.data.components.of_data import Representation
from mldft.ml.data.datamodule import OFDataModule
from mldft.ml.models.mldft_module import MLDFTLitModule
from mldft.utils.log_utils import pylogger
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiates callbacks from config.

    Args:
        callbacks_cfg: A DictConfig object containing callback configurations.

    Returns:
        A list of instantiated callbacks.
    """
    callbacks: List[Callback] = []

    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping...")
        return callbacks

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    for _, cb_conf in callbacks_cfg.items():
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            log.info(f"Instantiating callback <{cb_conf._target_}>")
            callbacks.append(hydra.utils.instantiate(cb_conf))

    return callbacks
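
# Usage sketch (illustrative only, not part of this module): the "model_checkpoint"
# key and its settings are assumptions about a typical Hydra callbacks config;
# lightning.pytorch.callbacks.ModelCheckpoint is a real Lightning callback.
#
#     from omegaconf import OmegaConf
#
#     callbacks_cfg = OmegaConf.create(
#         {
#             "model_checkpoint": {
#                 "_target_": "lightning.pytorch.callbacks.ModelCheckpoint",
#                 "monitor": "val/loss",
#             }
#         }
#     )
#     callbacks = instantiate_callbacks(callbacks_cfg)  # -> [ModelCheckpoint(...)]
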
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiates loggers from config.

    Args:
        logger_cfg: A DictConfig object containing logger configurations.

    Returns:
        A list of instantiated loggers.
    """
    logger: List[Logger] = []

    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return logger

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    for _, lg_conf in logger_cfg.items():
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf._target_}>")
            logger.append(hydra.utils.instantiate(lg_conf))

    return logger
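
# Usage sketch (illustrative only): the "csv" key and save_dir are assumptions
# about a typical logger config; lightning.pytorch.loggers.CSVLogger is a real
# Lightning logger.
#
#     from omegaconf import OmegaConf
#
#     logger_cfg = OmegaConf.create(
#         {
#             "csv": {
#                 "_target_": "lightning.pytorch.loggers.CSVLogger",
#                 "save_dir": "logs/",
#             }
#         }
#     )
#     loggers = instantiate_loggers(logger_cfg)  # -> [CSVLogger(...)]
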
def instantiate_model(
    checkpoint_path: str | Path,
    device: str | torch.device,
    model_dtype: torch.dtype = torch.float64,
) -> MLDFTLitModule:
    """Instantiate a model from a checkpoint.

    Args:
        checkpoint_path: The path to the checkpoint.
        device: The device to load the model on.
        model_dtype: The dtype of the model.

    Returns:
        The instantiated model, in eval mode.
    """
    lightning_module = MLDFTLitModule.load_from_checkpoint(checkpoint_path, map_location=device)
    lightning_module.eval()
    lightning_module.to(model_dtype)
    return lightning_module
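
# Usage sketch (illustrative only; the checkpoint path is a placeholder):
#
#     model = instantiate_model(
#         "path/to/last.ckpt",
#         device="cuda" if torch.cuda.is_available() else "cpu",
#     )
#     # returned in eval mode, in float64 unless model_dtype says otherwise
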
def instantiate_datamodule(
    cfg: DictConfig, limit_scf_iterations: int | list[int] | None = -1
) -> OFDataModule:
    """Instantiates a datamodule from config.

    On top of the provided train configuration, adds additional label keys and the
    transformation matrix.

    Args:
        cfg: A DictConfig object containing the train configuration.
        limit_scf_iterations: Which SCF iterations to use (see
            :py:class:`mldft.ml.data.components.dataset.OFDataset`). By default, only the ground
            state is loaded.

    Returns:
        An instantiated datamodule.
    """
    # copy the datamodule config to avoid side effects
    cfg = cfg.copy()
    datamodule = cfg.data.datamodule
    with open_dict(datamodule):
        datamodule.batch_size = 1
        datamodule.num_workers = 1
        datamodule.transforms.add_transformation_matrix = True
        datamodule.transforms.use_cached_data = False
        # prepend the overlap-matrix pre-transform so it runs before the existing ones
        datamodule.transforms.pre_transforms = [
            {
                "_target_": "mldft.ml.data.components.convert_transforms.AddOverlapMatrix",
                "basis_info": datamodule.basis_info,
            },
        ] + datamodule.transforms.pre_transforms
        datamodule.transforms.post_transforms = []
        datamodule.dataset_kwargs.limit_scf_iterations = limit_scf_iterations
        datamodule.dataset_kwargs.keep_initial_guess = False
        datamodule.dataset_kwargs.additional_keys_at_scf_iteration = {
            "of_labels/spatial/grad_kin": Representation.GRADIENT,
            "of_labels/spatial/grad_xc": Representation.GRADIENT,
        }
        datamodule.dataset_kwargs.additional_keys_at_ground_state = {
            "of_labels/spatial/grad_kin": Representation.GRADIENT,
            "of_labels/spatial/grad_xc": Representation.GRADIENT,
            "of_labels/energies/e_electron": Representation.SCALAR,
            "of_labels/energies/e_ext": Representation.SCALAR,
            "of_labels/energies/e_hartree": Representation.SCALAR,
            "of_labels/energies/e_kin": Representation.SCALAR,
            "of_labels/energies/e_kin_plus_xc": Representation.SCALAR,
            "of_labels/energies/e_kin_minus_apbe": Representation.SCALAR,
            "of_labels/energies/e_kinapbe": Representation.SCALAR,
            "of_labels/energies/e_xc": Representation.SCALAR,
            "of_labels/energies/e_tot": Representation.SCALAR,
        }
        datamodule.dataset_kwargs.additional_keys_per_geometry = {
            "of_labels/n_scf_steps": Representation.NONE,
        }
    return hydra.utils.instantiate(datamodule)
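
# Usage sketch (illustrative only, assuming `cfg` is a loaded Hydra train config
# with a data.datamodule entry):
#
#     datamodule = instantiate_datamodule(cfg)  # ground state only (default)
#     datamodule = instantiate_datamodule(cfg, limit_scf_iterations=[0, -1])
#     # a list selects specific SCF iterations; see OFDataset for the indexing convention
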