# Source code for mldft.datagen.generate_labels_dataset

"""Compute labels for molecules in the dataset and save them as zarr.zip files.

Can be run with the `mldft_labelgen` command.
"""

import multiprocessing
from functools import partial
from pathlib import Path

import hydra
from loguru import logger
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm

from mldft.datagen.datasets.dataset import DataGenDataset
from mldft.datagen.kohn_sham_dataset import check_config
from mldft.datagen.methods.label_generation import str_to_calculation_fct
from mldft.datagen.methods.save_labels_in_zarr_file import save_density_fitted_data
from mldft.utils import extras
from mldft.utils.molecules import build_molecule_np, construct_aux_mol, load_scf
from mldft.utils.multiprocess import (
    configure_max_memory_per_process,
    configure_processes_and_threads,
    unpack_args_for_imap,
)
from mldft.utils.pyscf_pretty_print import mole_to_sum_formula


[docs] def get_id_and_sample_id_from_chk_file(file: Path) -> tuple[int, int | None]: """Get the molecule id and sample id from the .chk file. Args: file: The .chk file. Returns: Tuple of molecule id and sample id. """ split_filename = file.stem.split("_")[-1].split(".") if len(split_filename) == 2: molecule_id, sample_id = int(split_filename[0]), int(split_filename[1]) else: molecule_id = int(split_filename[0]) sample_id = None return molecule_id, sample_id
def get_zarr_file_path(label_dir: Path, molecule_id: int, sample_id: int | None = None) -> Path: if sample_id is not None: return label_dir / f"{molecule_id:07}.{sample_id:07}.zarr.zip" else: return label_dir / f"{molecule_id:07}.zarr.zip"
@unpack_args_for_imap
def run_labelgen_task(
    dataset: DataGenDataset,
    chk_file: Path,
    calculation_fct: callable,
    orbital_basis: str,
    kohn_sham_basis: str,
    of_basis_set: str,
):
    """Generate and store the labels that belong to one checkpoint file.

    Args:
        dataset: Dataset that provides ``load_charges_and_positions`` and ``label_dir``.
        chk_file: The .chk file to process; its name encodes the molecule id and,
            optionally, the sample id.
        calculation_fct: Callable that receives the data loaded from ``chk_file``
            together with both molecule objects; its return value is what gets saved.
        orbital_basis: Basis passed to ``build_molecule_np``.
        kohn_sham_basis: Forwarded unchanged to ``save_density_fitted_data``.
        of_basis_set: Auxiliary basis name passed to ``construct_aux_mol``.
    """
    molecule_id, sample_id = get_id_and_sample_id_from_chk_file(chk_file)

    # Rebuild the molecule from the dataset's charges and positions, then derive the
    # auxiliary-basis molecule from it.
    charges, positions = dataset.load_charges_and_positions(molecule_id)
    orbital_mol = build_molecule_np(charges, positions, basis=orbital_basis, unit="Angstrom")
    density_mol = construct_aux_mol(orbital_mol, aux_basis_name=of_basis_set, unit="Bohr")

    sum_formula = mole_to_sum_formula(orbital_mol, True)
    num_atoms = len(orbital_mol.atom_charges())
    logger.info(f"Computing molecule {molecule_id} {sum_formula} with {num_atoms} atoms.")

    output_path = get_zarr_file_path(dataset.label_dir, molecule_id, sample_id)

    results, initialization, data_of_iteration = load_scf(chk_file)
    molecular_data = calculation_fct(
        results=results,
        initialization=initialization,
        data_of_iteration=data_of_iteration,
        mol_orbital_basis=orbital_mol,
        mol_density_basis=density_mol,
    )

    # The basis info file is looked up next to the checkpoint file.
    basis_info_path = chk_file.parent / "basis_info.npz"
    save_density_fitted_data(
        output_path,
        density_mol,
        kohn_sham_basis=kohn_sham_basis,
        path_to_basis_info=basis_info_path.as_posix(),
        molecular_data=molecular_data,
        mol_id=chk_file.stem,
    )
def save_dataset_info(cfg: DictConfig, label_dir: Path):
    """Save relevant config information for the dataset in a yaml file.

    Only the ``kohn_sham``, ``density_fitting_method``, ``of_basis_set`` and
    ``dataset`` entries of the config are written.

    Args:
        cfg: The config to take the entries from.
        label_dir: Directory in which ``dataset_info.yaml`` is written.
    """
    dataset_info = {}
    for key in ("kohn_sham", "density_fitting_method", "of_basis_set", "dataset"):
        dataset_info[key] = cfg.get(key)

    info_file = label_dir / "dataset_info.yaml"
    OmegaConf.save(config=dataset_info, f=info_file)
def run_label_generation(cfg: DictConfig):
    """Run the label generation for the dataset.

    Sets up logging, instantiates the dataset, determines the checkpoint files
    that still need labels and processes each of them with ``run_labelgen_task``,
    either sequentially or in a ``multiprocessing.Pool``.

    Args:
        cfg: The config. The entries read here are ``log_level`` and ``log_file``
            (both optional), ``num_processes``, ``num_threads_per_process``,
            ``dataset``, ``kohn_sham``, ``density_fitting_method``, ``of_basis_set``,
            ``start_idx``, ``n_molecules`` and ``label_src.calculation_fct``.
    """
    # Set up the logger to log via tqdm and to the configured log file.
    # ``logger.remove()`` without an argument drops *every* registered sink, so a file
    # sink that was added before this function was called is gone afterwards. It is
    # therefore registered again below, right after the tqdm sink.
    logger.remove()
    logger_format = (
        "<green>{time:HH:mm:ss}</green>|<level>{level: <8}</level>|<level>{message}</level>"
    )
    logger.add(
        lambda msg: tqdm.write(msg, end=""),
        format=logger_format,
        level=cfg.get("log_level", "INFO"),
        colorize=True,
        enqueue=True,
    )
    log_file = cfg.get("log_file")
    if log_file is not None:
        logger.add(log_file, rotation="10 MB", enqueue=True, backtrace=True, diagnose=True)

    num_processes = cfg.get("num_processes")
    num_threads_per_process = cfg.get("num_threads_per_process")

    # Instantiate the dataset; it is told how many processes will be used.
    dataset_settings = OmegaConf.to_container(cfg.dataset, resolve=True)
    dataset_settings["num_processes"] = num_processes
    dataset = hydra.utils.instantiate(dataset_settings)

    kohn_sham_settings = cfg.kohn_sham
    check_config(kohn_sham_settings, dataset.kohn_sham_data_dir / "config.yaml")
    orbital_basis = kohn_sham_settings.basis

    # Save dataset info (config) in the directory above the labels directory.
    save_dataset_info(cfg, dataset.label_dir.parent)

    density_fitting_method = cfg.get("density_fitting_method")
    of_basis_set = cfg.get("of_basis_set")

    ids_this_run = dataset.get_ids_todo_labelgen(
        start_idx=cfg.start_idx, max_num_molecules=cfg.n_molecules
    )
    chk_files = dataset.get_all_chk_files_from_ids(ids_this_run)

    calculation_fct_name = cfg.label_src.calculation_fct
    calculation_fct = str_to_calculation_fct[calculation_fct_name]
    calculation_fct = partial(calculation_fct, density_fitting_method=density_fitting_method)

    # Never start more worker processes than there are files to process.
    num_processes = min(num_processes, len(chk_files))

    if num_processes <= 1:
        if num_threads_per_process == 1:
            logger.warning("Running in single process mode using 1 thread, this may take a while")
        else:
            logger.info(f"Running in single process mode using {num_threads_per_process} threads.")
        for chk_file in tqdm(chk_files, desc="Label Generation", dynamic_ncols=True):
            run_labelgen_task(
                dataset,
                chk_file,
                calculation_fct,
                orbital_basis,
                kohn_sham_settings.basis,
                of_basis_set,
            )
    else:
        logger.info(
            f"Using {num_processes} processes with {num_threads_per_process} threads each."
        )
        # One argument tuple per checkpoint file; each tuple is handed to
        # ``run_labelgen_task`` as a single argument by ``imap_unordered``.
        args_list = [
            (
                dataset,
                chk_file,
                calculation_fct,
                orbital_basis,
                kohn_sham_settings.basis,
                of_basis_set,
            )
            for chk_file in chk_files
        ]
        with multiprocessing.Pool(processes=num_processes) as pool:
            imap = pool.imap_unordered(run_labelgen_task, args_list)
            # The results are not needed; the iterator is only drained so that tqdm
            # can show the progress.
            for _ in tqdm(
                imap,
                total=len(chk_files),
                desc="Label Generation",
                dynamic_ncols=True,
                smoothing=0,
            ):
                pass

    logger.info("Label generation and saving done.")
@hydra.main(version_base="1.3", config_path="../../configs/datagen", config_name="config.yaml")
def main(cfg: DictConfig):
    """Hydra entry point for label generation.

    Registers the log file sink, writes the process, thread and memory settings
    returned by the ``configure_*`` helpers back into the config, applies
    ``extras`` and then starts the label generation.

    Args:
        cfg: The config composed by hydra.
    """
    logger.add(cfg.log_file, rotation="10 MB", enqueue=True, backtrace=True, diagnose=True)

    num_processes, num_threads_per_process = configure_processes_and_threads(
        cfg.get("num_processes"), cfg.get("num_threads_per_process")
    )
    cfg.num_processes = num_processes
    cfg.num_threads_per_process = num_threads_per_process

    max_memory_per_process = configure_max_memory_per_process(cfg.get("max_memory_per_process"))
    cfg.max_memory_per_process = max_memory_per_process

    extras(cfg)
    run_label_generation(cfg)
# Run the hydra entry point only when this module is executed as a script.
if __name__ == "__main__":
    main()