Source code for mldft.datagen.datasets.small_dataset

import numpy as np
from pyscf import gto

from mldft.datagen.datasets.dataset import DataGenDataset


[docs] class SmallDataset(DataGenDataset): """Class for small dataset. Attributes: name: Name of the dataset. raw_data_dir: Path to the directory containing the raw data. kohn_sham_data_dir: Path to the directory containing the Kohn-Sham data. num_processes: Number of processes to use for the computation. num_molecules: Number of molecules in the dataset. molecules: List of molecules in the dataset. """
[docs] def __init__( self, raw_data_dir: str, kohn_sham_data_dir: str, label_dir: str, filename: str, name: str, molecules: list[gto.M], num_processes: int = 1, ): """Define molecules and initialize using parent class. Args: raw_data_dir: Path to the directory containing the raw data. kohn_sham_data_dir: Path to the directory containing the Kohn-Sham data. label_dir: Path to the directory containing the labels. filename: The filename to use for the output files. name: Name of the dataset, used as the folder name. molecules: List of molecules in the dataset. num_processes: Number of processes to use for dataset verifying or loading. """ self.molecules = molecules super().__init__( raw_data_dir=raw_data_dir, kohn_sham_data_dir=kohn_sham_data_dir, label_dir=label_dir, filename=filename, name=name, num_processes=num_processes, ) self.num_molecules = self.get_num_molecules()
[docs] def download(self) -> None: """Create the raw data directory and save the molecules as npz files.""" self.raw_data_dir.mkdir(parents=True, exist_ok=True) self.raw_data_dir.chmod(0o770) for i, mol in enumerate(self.molecules): positions = mol.atom_coords(unit="angstrom") atomic_numbers = mol.atom_charges() # Save data into separate npz files np.savez( self.raw_data_dir / f"mol_{int(i):07}.npz", positions=positions, atomic_numbers=atomic_numbers, )
[docs] def get_num_molecules(self) -> int: """Get the number of molecules in the dataset.""" return len(self.molecules)
[docs] def load_charges_and_positions(self, id: int) -> tuple[np.ndarray, np.ndarray]: """Load nuclear charges and positions for the given molecule indices. Args: id: Index of the molecules to compute. """ data = np.load(self.raw_data_dir / f"mol_{id:07}.npz") charges = data["atomic_numbers"] positions = data["positions"] return charges, positions
[docs] def get_ids(self) -> np.ndarray: """Get the indices of the molecules in the dataset.""" return np.arange(self.get_num_molecules())
[docs] def get_all_atomic_numbers(self) -> np.ndarray: """Get all atomic numbers in the dataset.""" return np.unique(np.concatenate([mol.atom_charges() for mol in self.molecules]))