Source code for mldft.datagen.datasets.misc

"""Arbitrary dataset of molecules in .xyz format.

This dataset is intended for small test cases with molecules that are not part of specific datasets.
"""

from functools import lru_cache

import numpy as np
from loguru import logger

from mldft.datagen.datasets.dataset import DataGenDataset
from mldft.utils.molecules import read_xyz_file


class MiscXYZ(DataGenDataset):
    """Dataset of arbitrary molecules stored as individual ``.xyz`` files.

    Molecule files are discovered by globbing ``raw_data_dir`` for ``*.xyz``. The
    integer id of a molecule is the number after the last underscore of the file
    stem, zero-padded to six digits when files are looked up (e.g.
    ``water_000003.xyz`` has id ``3``).

    Attributes:
        name: Name of the dataset.
        raw_data_dir: Path to the raw data directory.
        kohn_sham_data_dir: Path to the kohn-sham data directory.
        filename: The ``filename`` constructor argument, truncated at the first ``"."``.
        num_molecules: Number of ``.xyz`` files found in ``raw_data_dir`` at construction.
    """

    def __init__(
        self,
        raw_data_dir: str,
        kohn_sham_data_dir: str,
        label_dir: str,
        filename: str,
        name: str = "MiscXYZ",
        num_processes: int = 1,
    ) -> None:
        """Initialize the MiscXYZ dataset.

        Args:
            raw_data_dir: Path to the raw data directory.
            kohn_sham_data_dir: Path to the kohn-sham data directory.
            label_dir: Path to the directory containing the labels.
            filename: The filename to use for the output files.
            name: Name of the dataset.
            num_processes: Number of processes to use for dataset verifying or loading.
        """
        super().__init__(
            raw_data_dir=raw_data_dir,
            kohn_sham_data_dir=kohn_sham_data_dir,
            label_dir=label_dir,
            filename=filename,
            name=name,
            num_processes=num_processes,
        )
        # Keep only the part of the filename before the first dot.
        self.filename = filename.split(".")[0]
        self.num_molecules = self.get_num_molecules()

    def download(self) -> None:
        """This is just a stub as this kind of dataset can not be downloaded."""
        logger.info("No download required/possible for MiscXYZ dataset.")

    def get_num_molecules(self) -> int:
        """Get the number of molecules in the dataset.

        Only regular files are counted, matching the filter used by ``get_ids``.

        Returns:
            int: Number of ``.xyz`` files in the raw data directory.
        """
        return sum(1 for path in self.raw_data_dir.glob("*.xyz") if path.is_file())

    # NOTE: ``lru_cache`` on a method keys on ``self``: the instance must be hashable
    # and the most recently used instance is kept alive by the cache.
    @lru_cache(maxsize=1)
    def get_all_atomic_numbers(self) -> np.ndarray:
        """Get all atomic numbers present in the dataset.

        Iterates over all molecules in the dataset and collects all atomic numbers.

        Returns:
            np.ndarray: Sorted array of the distinct atomic numbers present in the dataset.
        """
        all_atomic_numbers = set()
        for mol_id in self.get_ids():
            charges, _ = self.load_charges_and_positions(mol_id)
            all_atomic_numbers.update(charges)
        return np.array(sorted(all_atomic_numbers))

    def load_charges_and_positions(self, id: int) -> tuple[list, list]:
        """Load nuclear charges and positions of one molecule from its ``.xyz`` file.

        Args:
            id: Index of the molecule to load. The file is looked up with the
                pattern ``*_{id:06}.xyz`` inside the raw data directory.

        Returns:
            The ``(charges, positions)`` pair returned by ``read_xyz_file``:
            atomic numbers (A) and atomic positions (A, 3).

        Raises:
            IndexError: If no file in the raw data directory matches the id.
        """
        # We iterate over this list of files often, but it's still negligible compared to the
        # kohn-sham time. Sorting makes the choice deterministic if several files match.
        pattern = f"*_{id:06}.xyz"
        matches = sorted(self.raw_data_dir.glob(pattern))
        if not matches:
            raise IndexError(f"No file matching '{pattern}' found in {self.raw_data_dir}.")
        charges, positions = read_xyz_file(matches[0])
        return charges, positions

    def get_ids(self) -> np.ndarray:
        """Get the indices of the molecules in the dataset.

        The index is parsed from the part of the file stem after the last
        underscore, which is the same part ``load_charges_and_positions`` matches on.

        Returns:
            np.ndarray: Sorted integer array of the indices of the molecules in the dataset.
        """
        ids = [
            int(path.stem.rsplit("_", 1)[1])
            for path in self.raw_data_dir.glob("*.xyz")
            if path.is_file()
        ]
        # Explicit dtype keeps the result integral even when no files are found.
        return np.sort(np.array(ids, dtype=int))