# Source code for mldft.ml.data.components.loader

# This file includes code from PyTorch Geometric (https://github.com/pyg-team/pytorch_geometric), licensed under the MIT License.
# The file was adapted to include the `list_keys` parameter.
"""DataLoader class for machine learning."""
from typing import List, Optional, Sequence, Union

from torch.utils.data import DataLoader
from torch_geometric.data import Dataset
from torch_geometric.data.data import BaseData
from torch_geometric.data.datapipes import DatasetAdapter

from mldft.ml.data.components.of_batch import OFCollater


class OFLoader(DataLoader):
    """Data loader for OF-DFT data.

    Thin wrapper around ``torch_geometric.loader.DataLoader`` that always collates
    with ``OFCollater``. By default, ``follow_batch`` is set to
    ``("coeffs", "atomic_numbers")``. This makes it easy to split all
    basis-function-wise fields by molecule, e.g. "coeffs", "ground_state_coeffs"
    and "gradient_label".
    """

    def __init__(
        self,
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Optional[Sequence[str]] = ("coeffs", "atomic_numbers"),
        exclude_keys: Optional[List[str]] = None,
        list_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        """Data loader for OF-DFT data.

        Args:
            dataset: The dataset to load samples from.
            batch_size: Number of samples per batch.
            shuffle: Whether to shuffle the samples.
            follow_batch: Keys forwarded to ``OFCollater`` as ``follow_batch``.
                Presumably, as in PyG, batch-assignment vectors are created for
                these keys (TODO confirm in ``OFCollater``).
            exclude_keys: Keys forwarded to ``OFCollater`` as ``exclude_keys``.
            list_keys: Keys forwarded to ``OFCollater`` as ``list_keys``.
                NOTE(review): presumably these are collated as plain lists rather
                than concatenated; verify against ``OFCollater``.
            **kwargs: Further arguments forwarded to ``torch.utils.data.DataLoader``.
                Any ``collate_fn`` given here is discarded, because ``OFCollater``
                is always used.
        """
        # PyTorch Lightning may pass the existing ``collate_fn`` back in when it
        # re-creates the loader; drop it so the collater is always rebuilt here.
        kwargs.pop("collate_fn", None)

        # Save for PyTorch Lightning < 1.6: it re-creates dataloaders from
        # attributes named like the ``__init__`` arguments. Every custom argument
        # (including ``list_keys``) must be stored, or it silently falls back to
        # its default on re-instantiation.
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys
        self.list_keys = list_keys

        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=OFCollater(dataset, follow_batch, exclude_keys, list_keys),
            **kwargs,
        )