# This file includes code from PyTorch Geometric (https://github.com/pyg-team/pytorch_geometric), licensed under the MIT License.
# The file was adapted to include the `list_keys` parameter.
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import (
Any,
Dict,
Iterable,
List,
Optional,
Self,
Tuple,
Type,
TypeVar,
Union,
)
import torch
import torch.utils.data
from torch import Tensor
from torch.utils.data.dataloader import default_collate
from torch_geometric.data import Batch, Dataset
from torch_geometric.data.collate import _batch_and_ptr, _collate, repeat_interleave
from torch_geometric.data.data import BaseData
from torch_geometric.data.datapipes import DatasetAdapter
from torch_geometric.data.storage import NodeStorage
from torch_geometric.typing import TensorFrame, torch_frame
from torch_geometric.utils import cumsum
# Generic placeholder for the batch class produced by `collate` (it is called
# with `cls` and returns an instance of that same class).
T = TypeVar("T")
# Per-attribute slice offsets used to re-split a collated batch back into the
# individual graphs; for heterogeneous data the mapping is nested one level
# deeper (store key -> attribute -> offsets).
SliceDictType = Dict[str, Union[Tensor, Dict[str, Tensor]]]
# Per-attribute increments applied while collating (e.g. node-index offsets
# for `edge_index`); same nesting convention as `SliceDictType`.
IncDictType = Dict[str, Union[Tensor, Dict[str, Tensor]]]
class OFCollater:
    """Collater for OF-DFT data.

    Copy paste from torch_geometric.data.collate.Collater except for using
    our custom OFBatch and the `list_keys` attribute.
    """

    def __init__(
        self,
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        list_keys: Optional[Sequence[str]] = None,
    ):
        """Collater for OF-DFT data.

        Args:
            dataset: The dataset the loader draws samples from (kept for
                API compatibility with the upstream Collater).
            follow_batch: Attribute names that get their own ``*_batch`` /
                ``*_ptr`` assignment vectors.
            exclude_keys: Attribute names to drop from the batch entirely.
            list_keys: Attribute names that are gathered into plain Python
                lists instead of being concatenated.
        """
        self.dataset = dataset
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys
        self.list_keys = list_keys

    def __call__(self, batch: List[Any]) -> Any:
        """Collates a list of data objects into a single object."""
        # Dispatch on the type of the first element; all elements of a batch
        # are assumed to share that type.
        elem = batch[0]
        if isinstance(elem, BaseData):
            # Changed to use our custom OFBatch
            return OFBatch.from_data_list(
                batch,
                follow_batch=self.follow_batch,
                exclude_keys=self.exclude_keys,
                list_keys=self.list_keys,
            )
        if isinstance(elem, torch.Tensor):
            return default_collate(batch)
        if isinstance(elem, TensorFrame):
            return torch_frame.cat(batch, dim=0)
        if isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float)
        if isinstance(elem, int):
            return torch.tensor(batch)
        if isinstance(elem, str):
            # Strings are left as-is (they are also Sequences, so this check
            # must come before the generic Sequence branch below).
            return batch
        if isinstance(elem, Mapping):
            # Collate each dictionary entry independently, recursing on the
            # per-key value lists.
            return {k: self([sample[k] for sample in batch]) for k in elem}
        if isinstance(elem, tuple) and hasattr(elem, "_fields"):
            # Namedtuple: rebuild the same namedtuple type field by field.
            return type(elem)(*(self(fields) for fields in zip(*batch)))
        if isinstance(elem, Sequence) and not isinstance(elem, str):
            # Generic sequence: transpose and recurse element-wise.
            return [self(items) for items in zip(*batch)]
        raise TypeError(f"DataLoader found invalid type: '{type(elem)}'")
class OFBatch(Batch):
    """A batch object for OFData.

    Copied from `torch_geometric.data.Batch` with the addition of the
    `list_keys` attribute, which just appends attributes to a list instead of
    collating them, which is needed for square matrices of varying size.
    """

    @classmethod
    def from_data_list(
        cls,
        data_list: List[BaseData],
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        list_keys: Optional[Sequence[str]] = None,  # Added to the original implementation
    ) -> Self:
        r"""Constructs a :class:`~torch_geometric.data.Batch` object from a list of
        :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` objects.

        The assignment vector :obj:`batch` is created on the fly.
        In addition, creates assignment vectors for each key in
        :obj:`follow_batch`.
        Will exclude any keys given in :obj:`exclude_keys`.
        Keys in :obj:`list_keys` will be appended to a list instead
        of collated (useful for square matrices of varying size).
        """
        # If the inputs are already Batch objects, `collate` must not add yet
        # another top-level `batch` assignment vector.
        needs_batch_vector = not isinstance(data_list[0], Batch)
        collated, slices, increments = collate(
            cls,
            data_list=data_list,
            increment=True,
            add_batch=needs_batch_vector,
            follow_batch=follow_batch,
            exclude_keys=exclude_keys,
            list_keys=list_keys,
        )
        # Stash the bookkeeping needed to reconstruct individual graphs later.
        collated._num_graphs = len(data_list)  # type: ignore
        collated._slice_dict = slices  # type: ignore
        collated._inc_dict = increments  # type: ignore
        return collated
def collate(
    cls: Type[T],
    data_list: List[BaseData],
    increment: bool = True,
    add_batch: bool = True,
    follow_batch: Optional[Iterable[str]] = None,
    exclude_keys: Optional[Iterable[str]] = None,
    list_keys: Optional[Sequence[str]] = None,  # Added to the original implementation
) -> Tuple[T, SliceDictType, IncDictType]:
    """Collates a list of `data` objects into a single object of type `cls`.

    `collate` can handle both homogeneous and heterogeneous data objects by
    individually collating all their stores.
    In addition, `collate` can handle nested data structures such as
    dictionaries and lists.

    Args:
        cls: The batch class to instantiate (e.g. `OFBatch`).
        data_list: The data objects to merge.
        increment: Whether to increment index-like attributes (e.g.
            `edge_index`) by the cumulated sizes of previous elements.
        add_batch: Whether to add a top-level `batch` assignment vector and
            `ptr` to node-level storages.
        follow_batch: Attributes that get their own `*_batch` / `*_ptr`
            assignment vectors.
        exclude_keys: Attributes to drop from the batch entirely.
        list_keys: Attributes collected into plain Python lists instead of
            being concatenated (useful for square matrices of varying size).

    Returns:
        A tuple `(out, slice_dict, inc_dict)` of the collated object and the
        bookkeeping dictionaries needed to re-split it into single graphs.
    """
    if not isinstance(data_list, (list, tuple)):
        # Materialize `data_list` to keep the `_parent` weakref alive.
        data_list = list(data_list)
    if cls != data_list[0].__class__:  # Dynamic inheritance.
        # `cls` wraps a different base class: build a dynamic subclass so the
        # batch behaves like the original data class.
        out = cls(_base_cls=data_list[0].__class__)  # type: ignore
    else:
        out = cls()
    # Create empty stores:
    out.stores_as(data_list[0])  # type: ignore
    # Normalize the key filters to sets for O(1) membership tests below.
    follow_batch = set(follow_batch or [])
    exclude_keys = set(exclude_keys or [])
    list_keys = set(list_keys or [])
    # Group all storage objects of every data object in the `data_list` by key,
    # i.e. `key_to_stores = { key: [store_1, store_2, ...], ... }`:
    key_to_stores = defaultdict(list)
    for data in data_list:
        for store in data.stores:
            key_to_stores[store._key].append(store)
    # With this, we iterate over each list of storage objects and recursively
    # collate all its attributes into a unified representation:
    # We maintain two additional dictionaries:
    # * `slice_dict` stores a compressed index representation of each attribute
    #   and is needed to re-construct individual elements from mini-batches.
    # * `inc_dict` stores how individual elements need to be incremented, e.g.,
    #   `edge_index` is incremented by the cumulated sum of previous elements.
    # We also need to make use of `inc_dict` when re-constructuing individual
    # elements as attributes that got incremented need to be decremented
    # while separating to obtain original values.
    device: Optional[torch.device] = None
    slice_dict: SliceDictType = {}
    inc_dict: IncDictType = {}
    for out_store in out.stores:  # type: ignore
        key = out_store._key
        stores = key_to_stores[key]
        for attr in stores[0].keys():
            if attr in exclude_keys:  # Do not include top-level attribute.
                continue
            values = [store[attr] for store in stores]
            ###########################################################
            # This is the only change from the original implementation:
            if attr in list_keys:
                # Keep the raw per-graph values as a Python list instead of
                # concatenating them (e.g. square matrices whose size differs
                # per graph cannot be stacked).
                out_store[attr] = values
                if isinstance(values[0], Tensor):
                    if device is None:
                        device = values[0].device
                    # Each list element occupies exactly one "slot", so the
                    # slice vector is [0, 1, ..., len(values)] and no
                    # increment is applied when separating.
                    # NOTE(review): these entries are written under the flat
                    # `attr` key even when `key is not None` (heterogeneous
                    # stores), unlike the `_collate` path below — confirm
                    # OFData is always homogeneous. Also note the zeros are
                    # float dtype, while `_collate` incs are index-typed.
                    slice_dict[attr] = torch.arange(len(values) + 1, device=device)
                    inc_dict[attr] = torch.zeros(len(values), device=device)
                continue
            ###########################################################
            # The `num_nodes` attribute needs special treatment, as we need to
            # sum their values up instead of merging them to a list:
            if attr == "num_nodes":
                out_store._num_nodes = values
                out_store.num_nodes = sum(values)
                continue
            # Skip batching of `ptr` vectors for now (re-created below when
            # `add_batch` is set):
            if attr == "ptr":
                continue
            # Collate attributes into a unified representation:
            value, slices, incs = _collate(attr, values, data_list, stores, increment)
            # If parts of the data are already on GPU, make sure that auxiliary
            # data like `batch` or `ptr` are also created on GPU:
            if isinstance(value, Tensor) and value.is_cuda:
                device = value.device
            out_store[attr] = value
            if key is not None:  # Heterogeneous:
                # Bookkeeping is nested one level deeper under the store key.
                store_slice_dict = slice_dict.get(key, {})
                assert isinstance(store_slice_dict, dict)
                store_slice_dict[attr] = slices
                slice_dict[key] = store_slice_dict
                store_inc_dict = inc_dict.get(key, {})
                assert isinstance(store_inc_dict, dict)
                store_inc_dict[attr] = incs
                inc_dict[key] = store_inc_dict
            else:  # Homogeneous:
                slice_dict[attr] = slices
                inc_dict[attr] = incs
            # Add an additional batch vector for the given attribute:
            if attr in follow_batch:
                batch, ptr = _batch_and_ptr(slices, device)
                out_store[f"{attr}_batch"] = batch
                out_store[f"{attr}_ptr"] = ptr
        # In case of node-level storages, we add a top-level batch vector it:
        if add_batch and isinstance(stores[0], NodeStorage) and stores[0].can_infer_num_nodes:
            repeats = [store.num_nodes or 0 for store in stores]
            out_store.batch = repeat_interleave(repeats, device=device)
            out_store.ptr = cumsum(torch.tensor(repeats, device=device))
    return out, slice_dict, inc_dict