"""PyTorch Lightning DataModule for all molecular datasets."""
import pickle
import platform
from functools import partial
from pathlib import Path
import lightning
import torch
from mldft.ml.data.components.basis_info import BasisInfo
from mldft.ml.data.components.basis_transforms import MasterTransformation
from mldft.ml.data.components.dataset import OFDataset
from mldft.ml.data.components.loader import OFLoader
from mldft.utils import RankedLogger
log = RankedLogger(__name__, rank_zero_only=True)
def worker_init_fn(worker_id: int, dtype: torch.dtype):
"""Helper function to set the default torch dtype inside the worker.
Args:
worker_id: The worker id, is not used here.
dtype: The torch dtype to be used inside the data loader.
"""
torch.set_default_dtype(dtype)
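# torch's DataLoader calls worker_init_fn with the worker id only, so the desired dtype
# is bound beforehand via functools.partial (see OFDataModule.__init__ below).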
# Optionally, prepare_data and teardown could also be defined on the DataModule.
class OFDataModule(lightning.LightningDataModule):
"""DataModule wrapping :class:`OFDataset`, providing train, validation, test and predict dataloaders."""
dataset_class = OFDataset
def __init__(
self,
split_file: Path | str,
data_dir: Path | str,
transforms: MasterTransformation,
basis_info: BasisInfo,
batch_size: int,
num_workers: int = 0,
pin_memory: bool = False,
shuffle_train: bool = True,
shuffle_val: bool = True,
shuffle_test: bool = False,
use_cached_iterations: bool = True,
dataset_kwargs: dict | None = None,
dataloader_kwargs: dict | None = None,
):
"""Initialize the DataModule.
Set up the parameters directory, basis info and transforms that will later be used in setup().
Args:
split_file: Name of the yaml file containing the train, val, test split
data_dir: Path to the directory containing the data
transforms: transforms for data augmentation
basis_info: BasisInfo object describing which basis functions are used in the dataset
batch_size: batch size
num_workers: number of workers for the dataloader
pin_memory: whether to pin memory
shuffle_train: whether to shuffle the training set
shuffle_val: whether to shuffle the validation set
shuffle_test: whether to shuffle the test set
use_cached_iterations: whether to load the number of SCF iterations from the split file instead of recomputing them
dataset_kwargs: Keyword arguments passed to :class:`OFDataset`.
dataloader_kwargs: Keyword arguments passed to :class:`OFLoader`.
"""
super().__init__()
self.split_file = Path(split_file)
self.data_dir = Path(data_dir)
# Set transforms and configure them
self.transforms = transforms
self.label_subdir = transforms.label_subdir
self.batch_size = batch_size
self.num_workers = num_workers
self.pin_memory = pin_memory
self.shuffle_train = shuffle_train
self.shuffle_val = shuffle_val
self.shuffle_test = shuffle_test
self.basis_info = basis_info
self.use_cached_iterations = use_cached_iterations
self.train_set = None
self.val_set = None
self.test_set = None
self.predict_set = None
self.dataset_kwargs = dataset_kwargs if dataset_kwargs is not None else {}
self.dataloader_kwargs = dataloader_kwargs if dataloader_kwargs is not None else {}
# Get the torch default dtype from the main process and pass it to the workers.
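# Note: on macOS the default multiprocessing start method is "spawn", so dataloader
# workers do not inherit the parent process's default dtype (unlike "fork" on Linux).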
if platform.system() == "Darwin":
self.worker_init_fn = partial(worker_init_fn, dtype=torch.get_default_dtype())
log.warning(
"Setting worker_init_fn for MacOS. This will break the seeding of the workers."
)
else:
self.worker_init_fn = None
def setup(self, stage: str):
"""Load the necessary datasets, split into train, validation and test sets.
Args:
stage: "fit" (train + validate), "validate" or "test"
"""
if stage not in ["fit", "validate", "test"]:
raise ValueError(
f"Stage '{stage}' not supported. Must be 'fit', 'validate' or 'test'."
)
log.info(f"Using labels from {self.label_subdir}.")
assert self.split_file.exists(), f"Split file {self.split_file} does not exist."
with self.split_file.open("rb") as f:
split_dict = pickle.load(f)
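# The split file maps each split name ("train", "val", "test") to a list of
# (dataset, label_path, scf_iterations) tuples, as unpacked below.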
# Load paths and iterations from the split file; handle each stage separately to avoid unnecessary loading
if stage == "fit":
train_paths = [
self.data_dir / dataset / self.label_subdir / label_path
for dataset, label_path, _ in split_dict["train"]
]
if self.use_cached_iterations:
train_iterations = [scf_iterations for _, _, scf_iterations in split_dict["train"]]
else:
train_iterations = None
self.train_set = self.dataset_class(
paths=train_paths,
num_scf_iterations_per_path=train_iterations,
basis_info=self.basis_info,
transforms=self.transforms,
**self.dataset_kwargs,
)
if stage == "fit" or (stage == "validate" and self.val_set is None):
val_paths = [
self.data_dir / dataset / self.label_subdir / label_path
for dataset, label_path, _ in split_dict["val"]
]
if self.use_cached_iterations:
val_iterations = [scf_iterations for _, _, scf_iterations in split_dict["val"]]
else:
val_iterations = None
self.val_set = self.dataset_class(
paths=val_paths,
num_scf_iterations_per_path=val_iterations,
basis_info=self.basis_info,
transforms=self.transforms,
**self.dataset_kwargs,
)
elif stage == "test":
test_paths = [
self.data_dir / dataset / self.label_subdir / label_path
for dataset, label_path, _ in split_dict["test"]
]
if self.use_cached_iterations:
test_iterations = [scf_iterations for _, _, scf_iterations in split_dict["test"]]
else:
test_iterations = None
self.test_set = self.dataset_class(
paths=test_paths,
num_scf_iterations_per_path=test_iterations,
basis_info=self.basis_info,
transforms=self.transforms,
**self.dataset_kwargs,
)
def train_dataloader(self) -> OFLoader:
"""Return the training dataloader."""
return OFLoader(
self.train_set,
batch_size=self.batch_size,
shuffle=self.shuffle_train,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
worker_init_fn=self.worker_init_fn,
drop_last=True,  # Drop the last incomplete batch; very small batches can cause unstable gradients during training
**self.dataloader_kwargs,
)
def val_dataloader(self) -> OFLoader:
"""Return the validation dataloader."""
return OFLoader(
self.val_set,
batch_size=self.batch_size,
shuffle=self.shuffle_val,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
worker_init_fn=self.worker_init_fn,
**self.dataloader_kwargs,
)
def test_dataloader(self) -> OFLoader:
"""Return the test dataloader."""
return OFLoader(
self.test_set,
batch_size=self.batch_size,
shuffle=self.shuffle_test,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
worker_init_fn=self.worker_init_fn,
**self.dataloader_kwargs,
)
def predict_dataloader(self) -> OFLoader:
"""Return the prediction dataloader."""
return OFLoader(
self.predict_set,
batch_size=self.batch_size,
shuffle=self.shuffle_test,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
worker_init_fn=self.worker_init_fn,
**self.dataloader_kwargs,
)
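
# Example usage: a minimal sketch of wiring the DataModule into a training run. The
# paths, ``my_transforms`` and ``my_basis_info`` below are placeholders; constructing
# a MasterTransformation or BasisInfo is project-specific and not shown here.
#
#     datamodule = OFDataModule(
#         split_file="splits/train_val_test.pkl",  # hypothetical split file
#         data_dir="data/",                        # hypothetical data directory
#         transforms=my_transforms,                # a configured MasterTransformation
#         basis_info=my_basis_info,                # a BasisInfo matching the dataset's basis
#         batch_size=32,
#         num_workers=4,
#     )
#     datamodule.setup("fit")
#     train_loader = datamodule.train_dataloader()
#     # or pass the module directly to a lightning.Trainer:
#     # trainer.fit(model, datamodule=datamodule)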