Source code for mldft.ml.data.components.of_batch

# This file includes code from PyTorch Geometric (https://github.com/pyg-team/pytorch_geometric), licensed under the MIT License.
# The file was adapted to include the `list_keys` parameter.

from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import (
    Any,
    Dict,
    Iterable,
    List,
    Optional,
    Self,
    Tuple,
    Type,
    TypeVar,
    Union,
)

import torch
import torch.utils.data
from torch import Tensor
from torch.utils.data.dataloader import default_collate
from torch_geometric.data import Batch, Dataset
from torch_geometric.data.collate import _batch_and_ptr, _collate, repeat_interleave
from torch_geometric.data.data import BaseData
from torch_geometric.data.datapipes import DatasetAdapter
from torch_geometric.data.storage import NodeStorage
from torch_geometric.typing import TensorFrame, torch_frame
from torch_geometric.utils import cumsum

# Generic type of the batch class that `collate` instantiates and returns.
T = TypeVar("T")
# Slice bookkeeping returned by `collate`: attribute name -> slice tensor for
# homogeneous data, or store key -> {attribute name -> slice tensor} for
# heterogeneous data.
SliceDictType = Dict[str, Union[Tensor, Dict[str, Tensor]]]
# Increment bookkeeping returned by `collate`; same layout as `SliceDictType`.
IncDictType = Dict[str, Union[Tensor, Dict[str, Tensor]]]


class OFCollater:
    """Collate callable that merges OF-DFT samples into a mini-batch.

    Mirrors ``torch_geometric.data.collate.Collater``; the differences are that graph
    objects are batched with our custom :class:`OFBatch` and that ``list_keys`` is
    forwarded to it.
    """

    def __init__(
        self,
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        list_keys: Optional[Sequence[str]] = None,
    ):
        """Store the collation options.

        Args:
            dataset: The dataset the samples come from. It is only stored on the instance.
            follow_batch: Forwarded to :meth:`OFBatch.from_data_list`.
            exclude_keys: Forwarded to :meth:`OFBatch.from_data_list`.
            list_keys: Forwarded to :meth:`OFBatch.from_data_list`.
        """
        self.dataset = dataset
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys
        self.list_keys = list_keys

    def __call__(self, batch: List[Any]) -> Any:
        """Merge a list of samples into one object, dispatching on the type of the first sample.

        Mappings, named tuples and sequences are collated recursively, element by element.

        Raises:
            TypeError: If the type of the first sample is not supported.
        """
        sample = batch[0]

        if isinstance(sample, BaseData):
            # Graph objects go through our custom OFBatch instead of the stock Batch.
            return OFBatch.from_data_list(
                batch,
                follow_batch=self.follow_batch,
                exclude_keys=self.exclude_keys,
                list_keys=self.list_keys,
            )
        if isinstance(sample, torch.Tensor):
            return default_collate(batch)
        if isinstance(sample, TensorFrame):
            return torch_frame.cat(batch, dim=0)
        if isinstance(sample, float):
            return torch.tensor(batch, dtype=torch.float)
        if isinstance(sample, int):
            return torch.tensor(batch)
        if isinstance(sample, str):
            return batch
        if isinstance(sample, Mapping):
            return {name: self([item[name] for item in batch]) for name in sample}
        if isinstance(sample, tuple) and hasattr(sample, "_fields"):
            # Named tuple: collate field-wise and rebuild the same tuple type.
            return type(sample)(*map(self, zip(*batch)))
        if isinstance(sample, Sequence):
            # Strings were already returned above, so this is a genuine sequence.
            return list(map(self, zip(*batch)))

        raise TypeError(f"DataLoader found invalid type: '{type(sample)}'")
class OFBatch(Batch):
    """Batch container for OFData.

    Adapted from ``torch_geometric.data.Batch``. The addition is the ``list_keys``
    option: attributes named there are gathered into a plain list instead of being
    collated, which is needed for square matrices of varying size.
    """

    @classmethod
    def from_data_list(
        cls,
        data_list: List[BaseData],
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        list_keys: Optional[Sequence[str]] = None,  # Added to the original implementation
    ) -> Self:
        r"""Build a batch from a list of :class:`~torch_geometric.data.Data` or
        :class:`~torch_geometric.data.HeteroData` objects.

        The assignment vector :obj:`batch` is created on the fly, unless the first
        element of :obj:`data_list` is itself a :class:`~torch_geometric.data.Batch`.

        Args:
            data_list: The objects to merge.
            follow_batch: Keys for which an additional assignment vector is created.
            exclude_keys: Keys that are left out of the batch.
            list_keys: Keys whose values are appended to a list instead of collated
                (useful for square matrices of varying size).

        Returns:
            The batch, with ``_num_graphs``, ``_slice_dict`` and ``_inc_dict`` set.
        """
        needs_batch_vector = not isinstance(data_list[0], Batch)

        merged, slices, increments = collate(
            cls,
            data_list=data_list,
            increment=True,
            add_batch=needs_batch_vector,
            follow_batch=follow_batch,
            exclude_keys=exclude_keys,
            list_keys=list_keys,
        )

        # Bookkeeping stored on the batch alongside the collated data.
        merged._num_graphs = len(data_list)  # type: ignore
        merged._slice_dict = slices  # type: ignore
        merged._inc_dict = increments  # type: ignore

        return merged
def _record_slices(
    slice_dict: SliceDictType,
    inc_dict: IncDictType,
    key: Any,
    attr: str,
    slices: Any,
    incs: Any,
) -> None:
    """Store the slice/increment bookkeeping of `attr` in `slice_dict` and `inc_dict`.

    For homogeneous stores (`key is None`) the entries live at the top level; for
    heterogeneous stores they are nested under the store `key`.
    """
    if key is None:  # Homogeneous:
        slice_dict[attr] = slices
        inc_dict[attr] = incs
        return

    # Heterogeneous:
    store_slice_dict = slice_dict.setdefault(key, {})
    assert isinstance(store_slice_dict, dict)
    store_slice_dict[attr] = slices

    store_inc_dict = inc_dict.setdefault(key, {})
    assert isinstance(store_inc_dict, dict)
    store_inc_dict[attr] = incs


def collate(
    cls: Type[T],
    data_list: List[BaseData],
    increment: bool = True,
    add_batch: bool = True,
    follow_batch: Optional[Iterable[str]] = None,
    exclude_keys: Optional[Iterable[str]] = None,
    list_keys: Optional[Sequence[str]] = None,  # Added to the original implementation
) -> Tuple[T, SliceDictType, IncDictType]:
    """Collates a list of `data` objects into a single object of type `cls`.

    `collate` can handle both homogeneous and heterogeneous data objects by
    individually collating all their stores. In addition, `collate` can handle
    nested data structures such as dictionaries and lists.

    Args:
        cls: The class of the object to create.
        data_list: The data objects to collate. Must not be empty.
        increment: Forwarded to `_collate`.
        add_batch: Whether to add `batch` and `ptr` vectors to node-level stores.
        follow_batch: Attributes for which `{attr}_batch` and `{attr}_ptr` vectors are
            added.
        exclude_keys: Top-level attributes that are left out of the result.
        list_keys: Attributes whose per-sample values are stored as a plain list instead
            of being collated. They are skipped by the `follow_batch` handling.

    Returns:
        A tuple `(out, slice_dict, inc_dict)` of the collated object and the slice and
        increment bookkeeping.
    """
    if not isinstance(data_list, (list, tuple)):
        # Materialize `data_list` to keep the `_parent` weakref alive.
        data_list = list(data_list)

    if cls != data_list[0].__class__:  # Dynamic inheritance.
        out = cls(_base_cls=data_list[0].__class__)  # type: ignore
    else:
        out = cls()

    # Create empty stores:
    out.stores_as(data_list[0])  # type: ignore

    follow_batch = set(follow_batch or [])
    exclude_keys = set(exclude_keys or [])
    list_keys = set(list_keys or [])

    # Group all storage objects of every data object in the `data_list` by key,
    # i.e. `key_to_stores = { key: [store_1, store_2, ...], ... }`:
    key_to_stores = defaultdict(list)
    for data in data_list:
        for store in data.stores:
            key_to_stores[store._key].append(store)

    # With this, we iterate over each list of storage objects and recursively
    # collate all its attributes into a unified representation:

    # We maintain two additional dictionaries:
    # * `slice_dict` stores a compressed index representation of each attribute
    #    and is needed to re-construct individual elements from mini-batches.
    # * `inc_dict` stores how individual elements need to be incremented, e.g.,
    #   `edge_index` is incremented by the cumulated sum of previous elements.
    #   We also need to make use of `inc_dict` when re-constructing individual
    #   elements as attributes that got incremented need to be decremented
    #   while separating to obtain original values.
    device: Optional[torch.device] = None
    slice_dict: SliceDictType = {}
    inc_dict: IncDictType = {}
    for out_store in out.stores:  # type: ignore
        key = out_store._key
        stores = key_to_stores[key]
        for attr in stores[0].keys():
            if attr in exclude_keys:  # Do not include top-level attribute.
                continue

            values = [store[attr] for store in stores]

            ###########################################################
            # This is the only change from the original implementation:
            if attr in list_keys:
                # Keep the per-sample values as a list instead of collating them.
                out_store[attr] = values
                if isinstance(values[0], Tensor) and device is None:
                    device = values[0].device
                # One list entry per sample and nothing to increment. The entries are
                # recorded with the same nesting as collated attributes, so that
                # heterogeneous stores are handled consistently.
                _record_slices(
                    slice_dict,
                    inc_dict,
                    key,
                    attr,
                    torch.arange(len(values) + 1, device=device),
                    torch.zeros(len(values), device=device),
                )
                continue
            ###########################################################

            # The `num_nodes` attribute needs special treatment, as we need to
            # sum their values up instead of merging them to a list:
            if attr == "num_nodes":
                out_store._num_nodes = values
                out_store.num_nodes = sum(values)
                continue

            # Skip batching of `ptr` vectors for now:
            if attr == "ptr":
                continue

            # Collate attributes into a unified representation:
            value, slices, incs = _collate(attr, values, data_list, stores, increment)

            # If parts of the data are already on GPU, make sure that auxiliary
            # data like `batch` or `ptr` are also created on GPU:
            if isinstance(value, Tensor) and value.is_cuda:
                device = value.device

            out_store[attr] = value
            _record_slices(slice_dict, inc_dict, key, attr, slices, incs)

            # Add an additional batch vector for the given attribute:
            if attr in follow_batch:
                batch, ptr = _batch_and_ptr(slices, device)
                out_store[f"{attr}_batch"] = batch
                out_store[f"{attr}_ptr"] = ptr

        # In case of node-level storages, we add a top-level batch vector to it:
        if add_batch and isinstance(stores[0], NodeStorage) and stores[0].can_infer_num_nodes:
            repeats = [store.num_nodes or 0 for store in stores]
            out_store.batch = repeat_interleave(repeats, device=device)
            out_store.ptr = cumsum(torch.tensor(repeats, device=device))

    return out, slice_dict, inc_dict