Source code for mldft.ml.data.datamodule

"""PyTorch Lightning DataModule for all molecular datasets."""
import pickle
import platform
from functools import partial
from pathlib import Path

import lightning
import torch

from mldft.ml.data.components.basis_info import BasisInfo
from mldft.ml.data.components.basis_transforms import MasterTransformation
from mldft.ml.data.components.dataset import OFDataset
from mldft.ml.data.components.loader import OFLoader
from mldft.utils import RankedLogger

log = RankedLogger(__name__, rank_zero_only=True)


def worker_init_fn(worker_id: int, dtype: torch.dtype):
    """Helper function to set the default torch dtype inside the worker.

    Args:
        worker_id: The worker id; not used here.
        dtype: The torch dtype to be used inside the data loader.
    """
    torch.set_default_dtype(dtype)
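
# A minimal usage sketch for the helper above (illustrative only; ``some_dataset``
# is a placeholder). It mirrors how ``OFDataModule.__init__`` below wires the
# function into a dataloader via ``functools.partial``:
#
#     loader = torch.utils.data.DataLoader(
#         some_dataset,
#         num_workers=2,
#         worker_init_fn=partial(worker_init_fn, dtype=torch.get_default_dtype()),
#     )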

# Optionally, define prepare_data and teardown here.
class OFDataModule(lightning.LightningDataModule):

    dataset_class = OFDataset
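
    # ``setup`` instantiates datasets through this class attribute rather than
    # referring to ``OFDataset`` directly, presumably so a subclass can swap in a
    # different dataset type. Illustrative sketch (``MyDataset`` is hypothetical):
    #
    #     class MyOFDataModule(OFDataModule):
    #         dataset_class = MyDataset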

    def __init__(
        self,
        split_file: Path | str,
        data_dir: Path | str,
        transforms: MasterTransformation,
        basis_info: BasisInfo,
        batch_size: int,
        num_workers: int = 0,
        pin_memory: bool = False,
        shuffle_train: bool = True,
        shuffle_val: bool = True,
        shuffle_test: bool = False,
        use_cached_iterations: bool = True,
        dataset_kwargs: dict | None = None,
        dataloader_kwargs: dict | None = None,
    ):
        """Initialize the DataModule.

        Set up the paths, basis info and transforms that will later be used in ``setup()``.

        Args:
            split_file: Name of the file containing the train, val, test split
                (loaded with ``pickle`` in ``setup()``).
            data_dir: Path to the directory containing the data.
            transforms: Transforms for data augmentation.
            basis_info: BasisInfo object describing which basis functions are used in the dataset.
            batch_size: Batch size.
            num_workers: Number of workers for the dataloader.
            pin_memory: Whether to pin memory.
            shuffle_train: Whether to shuffle the training set.
            shuffle_val: Whether to shuffle the validation set.
            shuffle_test: Whether to shuffle the test set.
            use_cached_iterations: Whether to load the number of SCF iterations from the split
                file or recompute them.
            dataset_kwargs: Keyword arguments passed to :class:`OFDataset`.
            dataloader_kwargs: Keyword arguments passed to :class:`OFLoader`.
        """
        super().__init__()
        self.split_file = Path(split_file)
        self.data_dir = Path(data_dir)

        # Set transforms and configure them
        self.transforms = transforms
        self.label_subdir = transforms.label_subdir

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.shuffle_train = shuffle_train
        self.shuffle_val = shuffle_val
        self.shuffle_test = shuffle_test
        self.basis_info = basis_info
        self.use_cached_iterations = use_cached_iterations

        self.train_set = None
        self.val_set = None
        self.test_set = None
        self.predict_set = None

        self.dataset_kwargs = dataset_kwargs if dataset_kwargs is not None else {}
        self.dataloader_kwargs = dataloader_kwargs if dataloader_kwargs is not None else {}

        # Get the torch default dtype from the main process and pass it to the workers.
        if platform.system() == "Darwin":
            self.worker_init_fn = partial(worker_init_fn, dtype=torch.get_default_dtype())
            log.warning(
                "Setting worker_init_fn for MacOS. This will break the seeding of the workers."
            )
        else:
            self.worker_init_fn = None

    def setup(self, stage: str):
        """Load the necessary datasets, split into train, validation and test sets.

        Args:
            stage: "fit" (train + validate), "validate" or "test"
        """
        if stage not in ["fit", "validate", "test"]:
            raise ValueError(
                f"Stage '{stage}' not supported. Must be 'fit', 'validate' or 'test'."
            )

        log.info(f"Using labels from {self.label_subdir}.")

        assert self.split_file.exists(), f"Split file {self.split_file} does not exist."
        with self.split_file.open("rb") as f:
            split_dict = pickle.load(f)

        # Load paths and iterations from the split file; handle each stage separately to avoid
        # loading sets that are not needed.
        if stage == "fit":
            train_paths = [
                self.data_dir / dataset / self.label_subdir / label_path
                for dataset, label_path, _ in split_dict["train"]
            ]
            if self.use_cached_iterations:
                train_iterations = [scf_iterations for _, _, scf_iterations in split_dict["train"]]
            else:
                train_iterations = None
            self.train_set = self.dataset_class(
                paths=train_paths,
                num_scf_iterations_per_path=train_iterations,
                basis_info=self.basis_info,
                transforms=self.transforms,
                **self.dataset_kwargs,
            )
        if stage == "fit" or (stage == "validate" and self.val_set is None):
            val_paths = [
                self.data_dir / dataset / self.label_subdir / label_path
                for dataset, label_path, _ in split_dict["val"]
            ]
            if self.use_cached_iterations:
                val_iterations = [scf_iterations for _, _, scf_iterations in split_dict["val"]]
            else:
                val_iterations = None
            self.val_set = self.dataset_class(
                paths=val_paths,
                num_scf_iterations_per_path=val_iterations,
                basis_info=self.basis_info,
                transforms=self.transforms,
                **self.dataset_kwargs,
            )
        elif stage == "test":
            test_paths = [
                self.data_dir / dataset / self.label_subdir / label_path
                for dataset, label_path, _ in split_dict["test"]
            ]
            if self.use_cached_iterations:
                test_iterations = [scf_iterations for _, _, scf_iterations in split_dict["test"]]
            else:
                test_iterations = None
            self.test_set = self.dataset_class(
                paths=test_paths,
                num_scf_iterations_per_path=test_iterations,
                basis_info=self.basis_info,
                transforms=self.transforms,
                **self.dataset_kwargs,
            )
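
    # Illustrative sketch of the split-file layout assumed by ``setup`` above,
    # inferred from the three-tuple unpacking; names and values are hypothetical
    # placeholders:
    #
    #     split_dict = {
    #         "train": [("<dataset>", "<label_path>", <num_scf_iterations>), ...],
    #         "val":   [("<dataset>", "<label_path>", <num_scf_iterations>), ...],
    #         "test":  [("<dataset>", "<label_path>", <num_scf_iterations>), ...],
    #     }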

    def train_dataloader(self) -> OFLoader:
        """Return the training dataloader."""
        return OFLoader(
            self.train_set,
            batch_size=self.batch_size,
            shuffle=self.shuffle_train,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            worker_init_fn=self.worker_init_fn,
            drop_last=True,  # Important during training to avoid large gradients
            **self.dataloader_kwargs,
        )

    def val_dataloader(self) -> OFLoader:
        """Return the validation dataloader."""
        return OFLoader(
            self.val_set,
            batch_size=self.batch_size,
            shuffle=self.shuffle_val,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            worker_init_fn=self.worker_init_fn,
            **self.dataloader_kwargs,
        )

    def test_dataloader(self) -> OFLoader:
        """Return the test dataloader."""
        return OFLoader(
            self.test_set,
            batch_size=self.batch_size,
            shuffle=self.shuffle_test,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            worker_init_fn=self.worker_init_fn,
            **self.dataloader_kwargs,
        )

    def predict_dataloader(self) -> OFLoader:
        """Return the prediction dataloader."""
        return OFLoader(
            self.predict_set,
            batch_size=self.batch_size,
            shuffle=self.shuffle_test,  # Reuses the test-set shuffle setting
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            worker_init_fn=self.worker_init_fn,
            **self.dataloader_kwargs,
        )
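
# Usage sketch (illustrative; not part of the original module). The paths below
# are hypothetical placeholders, and the default construction of
# ``MasterTransformation`` and ``BasisInfo`` is an assumption; consult the project
# configuration for the real arguments.
if __name__ == "__main__":
    datamodule = OFDataModule(
        split_file="path/to/split.pkl",  # hypothetical path
        data_dir="path/to/data",  # hypothetical path
        transforms=MasterTransformation(),  # assumed default construction
        basis_info=BasisInfo(),  # assumed default construction
        batch_size=32,
        num_workers=4,
        pin_memory=True,
    )
    datamodule.setup(stage="fit")
    train_loader = datamodule.train_dataloader()
    first_batch = next(iter(train_loader))  # draw one batch from the training set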