datamodule

PyTorch Lightning DataModule for all molecular datasets.

class OFDataModule(split_file: Path | str, data_dir: Path | str, transforms: MasterTransformation, basis_info: BasisInfo, batch_size: int, num_workers: int = 0, pin_memory: bool = False, shuffle_train: bool = True, shuffle_val: bool = True, shuffle_test: bool = False, use_cached_iterations: bool = True, dataset_kwargs: dict = None, dataloader_kwargs: dict = None)
__init__(split_file: Path | str, data_dir: Path | str, transforms: MasterTransformation, basis_info: BasisInfo, batch_size: int, num_workers: int = 0, pin_memory: bool = False, shuffle_train: bool = True, shuffle_val: bool = True, shuffle_test: bool = False, use_cached_iterations: bool = True, dataset_kwargs: dict = None, dataloader_kwargs: dict = None)

Initialize the DataModule.

Store the parameters, the data directory, the basis info and the transforms, which are used later in setup().

Parameters:
  • split_file – Name of the YAML file containing the train, validation and test split.

  • data_dir – Path to the directory containing the data.

  • transforms – Transforms used for data augmentation.

  • basis_info – BasisInfo object describing which basis functions are used in the dataset.

  • batch_size – Number of samples per batch.

  • num_workers – Number of worker processes used by the dataloaders.

  • pin_memory – Whether the dataloaders use pinned (page-locked) memory.

  • shuffle_train – Whether to shuffle the training set.

  • shuffle_val – Whether to shuffle the validation set.

  • shuffle_test – Whether to shuffle the test set.

  • use_cached_iterations – Whether to load the number of SCF iterations from the split file (True) or recompute them (False).

  • dataset_kwargs – Keyword arguments passed to OFDataset.

  • dataloader_kwargs – Keyword arguments passed to OFLoader.

dataset_class

alias of OFDataset

predict_dataloader() → OFLoader

Return the prediction dataloader.

setup(stage: str)

Load the required datasets and split them into train, validation and test sets.

Parameters:

stage – “fit” (train + validate), “validate” or “test”
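Each stage needs only some of the splits: "fit" needs train and validation, "validate" only validation, and "test" only test. A minimal sketch of that selection, assuming the split file has already been parsed into a dict with the keys `train`, `val` and `test`, each mapping to a list of sample identifiers. That layout, the sample names and the helper `splits_for_stage` are hypothetical.

```python
# Which splits each stage requires, following the stage description above.
STAGE_SPLITS: dict[str, tuple[str, ...]] = {
    "fit": ("train", "val"),
    "validate": ("val",),
    "test": ("test",),
}


def splits_for_stage(stage: str, split: dict[str, list[str]]) -> dict[str, list[str]]:
    """Hypothetical helper: select the parts of the split that a stage needs."""
    try:
        names = STAGE_SPLITS[stage]
    except KeyError:
        raise ValueError(f"unsupported stage: {stage!r}") from None
    return {name: split[name] for name in names}


# Hypothetical parsed split file.
split = {"train": ["mol_0", "mol_1"], "val": ["mol_2"], "test": ["mol_3"]}
print(splits_for_stage("fit", split))
# {'train': ['mol_0', 'mol_1'], 'val': ['mol_2']}
```

Building only the splits a stage needs keeps, for example, a pure test run from loading the training data.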

test_dataloader() → OFLoader

Return the test dataloader.

train_dataloader() → OFLoader

Return the training dataloader.

val_dataloader() → OFLoader

Return the validation dataloader.

worker_init_fn(worker_id: int, dtype: dtype)

Helper function to set the default torch dtype inside the worker.

Parameters:
  • worker_id – The worker ID; unused here.

  • dtype – The torch dtype to set as the default inside the worker.
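PyTorch's DataLoader calls `worker_init_fn` with a single argument, the worker ID, so the `dtype` argument has to be bound in advance, for example with `functools.partial`. The sketch below shows only that binding. To keep it runnable without PyTorch, the function is a hypothetical stand-in that records the dtype (given as a string) each worker would receive instead of calling `torch.set_default_dtype`, which the documented helper presumably does.

```python
from functools import partial

# Records (worker_id, dtype) pairs; stands in for torch.set_default_dtype.
calls: list[tuple[int, str]] = []


def worker_init_fn(worker_id: int, dtype: str) -> None:
    """Hypothetical stand-in mirroring the documented signature.

    worker_id is accepted but unused by the real helper; here it is only
    recorded so the binding can be inspected.
    """
    calls.append((worker_id, dtype))


# The DataLoader passes only worker_id, so bind dtype beforehand.
init_fn = partial(worker_init_fn, dtype="float64")

# Simulate a loader starting three worker processes.
for worker_id in range(3):
    init_fn(worker_id)

print(calls)  # [(0, 'float64'), (1, 'float64'), (2, 'float64')]
```

With PyTorch installed, the same pattern would pass `worker_init_fn=partial(worker_init_fn, dtype=torch.float64)` to `torch.utils.data.DataLoader`. Each worker process then sets its default dtype before producing any samples.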