Source code for gammagl.loader.dataloader

from collections.abc import Mapping, Sequence
from typing import List, Optional, Union

# import torch.utils.data
# from torch.utils.data.dataloader import default_collate

from gammagl.data import BatchGraph, Graph, Dataset
import tensorlayerx as tlx


class Collater:
    """Callable that merges a list of samples into a single mini-batch.

    The type of the first sample decides how the whole list is merged:

    * ``Graph`` -- merged with ``BatchGraph.from_data_list`` using the stored
      ``follow_batch`` / ``exclude_keys``.
    * ``float`` / ``int`` -- converted to a tensor via
      ``tlx.convert_to_tensor`` with dtype ``float32`` / ``int64``.
    * ``str`` -- the list is returned unchanged.
    * mapping -- collated per key of the first sample, recursively.
    * namedtuple -- collated per field, rebuilt as the same namedtuple type.
    * any other sequence -- transposed with ``zip`` and collated per
      position, returned as a list.

    Any other type raises ``TypeError``.
    """

    def __init__(self, follow_batch, exclude_keys):
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

    def __call__(self, batch):
        # Only the first sample is inspected to pick the merge strategy.
        first = batch[0]

        if isinstance(first, Graph):
            return BatchGraph.from_data_list(
                batch, self.follow_batch, self.exclude_keys)

        if isinstance(first, float):
            return tlx.convert_to_tensor(batch, dtype=tlx.float32)

        if isinstance(first, int):
            return tlx.convert_to_tensor(batch, dtype=tlx.int64)

        if isinstance(first, str):
            return batch

        if isinstance(first, Mapping):
            merged = {}
            for key in first:
                merged[key] = self([sample[key] for sample in batch])
            return merged

        if isinstance(first, tuple) and hasattr(first, '_fields'):
            # namedtuple: collate each field and rebuild the same type.
            columns = zip(*batch)
            return type(first)(*(self(column) for column in columns))

        if isinstance(first, Sequence):
            # Strings were already returned above, so no extra str check
            # is needed before treating this as a generic sequence.
            return [self(column) for column in zip(*batch)]

        raise TypeError(f'DataLoader found invalid type: {type(first)}')

    def collate(self, batch):
        """Deprecated alias for calling the instance directly."""
        return self(batch)


class DataLoader(tlx.dataflow.DataLoader):
    r"""A data loader which merges data objects from a
    :class:`gammagl.data.Dataset` to a mini-batch.
    Data objects can be either of type :class:`~gammagl.data.Graph` or
    :class:`~gammagl.data.HeteroGraph`.

    Parameters
    ----------
    dataset: dataset
        The dataset from which to load the data.
    batch_size: int, optional
        How many samples per batch to load.
        (default: :obj:`1`)
    shuffle: bool, optional
        If set to :obj:`True`, the data will be
        reshuffled at every epoch. (default: :obj:`False`)
    follow_batch: list[str], optional
        Creates assignment batch
        vectors for each key in the list. (default: :obj:`None`)
    exclude_keys: list[str], optional
        Will exclude each key in the
        list. (default: :obj:`None`)
    **kwargs: optional
        Additional arguments of
        :class:`tensorlayerx.dataflow.DataLoader`. A ``collate_fn`` entry,
        if given, is discarded: batches are always collated by
        :class:`Collater`.

    """

    def __init__(
        self,
        dataset: Union[Dataset, List[Graph]],
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        # A caller-supplied collate_fn would clash with the one passed to
        # the base class below, so drop it if present.
        kwargs.pop('collate_fn', None)

        # Save for PyTorch Lightning:
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=Collater(follow_batch, exclude_keys),
            **kwargs,
        )