Source code for gammagl.data.batch

import inspect
from collections.abc import Sequence
from typing import Any, List, Optional, Union

import numpy as np
import tensorlayerx as tlx
from gammagl.data import Graph
from gammagl.data.collate import collate
from gammagl.data.dataset import IndexType
from gammagl.data.separate import separate


class DynamicInheritance(type):
    """Metaclass that picks the concrete base class of a batch at call time.

    Instantiating a class that uses this metaclass accepts an extra keyword
    argument ``_base_cls`` (defaulting to ``Graph``).  The object that is
    actually created is an instance of a class deriving from both the batch
    class and ``_base_cls``, e.g.:

    * ``Batch(Graph)`` when ``Graph`` objects are batched together
    * ``Batch(HeteroGraph)`` when ``HeteroGraph`` objects are batched together
    """

    def __call__(cls, *args, **kwargs):
        # `_base_cls` is consumed here and never reaches `__init__`.
        base_cls = kwargs.pop('_base_cls', Graph)

        if issubclass(base_cls, BatchGraph):
            # The requested base already is a batch class: use it directly.
            new_cls = base_cls
        else:
            # Build the combined class once and cache it in the module
            # namespace under `<BaseName><BatchName>`, so it is reachable by
            # name at module level.
            name = f'{base_cls.__name__}{cls.__name__}'
            module_namespace = globals()
            if name not in module_namespace:
                module_namespace[name] = type(name, (cls, base_cls), {})
            new_cls = module_namespace[name]

        # Inspect the base class constructor (see the `inspect.signature`
        # documentation) and pass `None` for every required parameter the
        # caller did not supply, so that an "empty" batch can be constructed.
        signature = inspect.signature(base_cls.__init__)
        init_params = list(signature.parameters.items())[1:]  # drop `self`
        num_positional = len(args)
        for position, (param_name, param) in enumerate(init_params):
            if param_name in ('args', 'kwargs'):
                continue
            supplied = position < num_positional or param_name in kwargs
            has_default = param.default is not inspect.Parameter.empty
            if not (supplied or has_default):
                kwargs[param_name] = None

        return super(DynamicInheritance, new_cls).__call__(*args, **kwargs)


class DynamicInheritanceGetter:
    """Callable that rebuilds an object of ``cls`` for a given base class.

    ``BatchGraph.__reduce__`` returns an instance of this class as the
    reconstruction callable, together with the bases of the batch class as
    its arguments.
    """

    def __call__(self, cls, base_cls):
        # Forward the base class through the `_base_cls` keyword argument.
        rebuilt = cls(_base_cls=base_cls)
        return rebuilt


class BatchGraph(metaclass=DynamicInheritance):
    r"""A data object describing a batch of graphs as one big (disconnected)
    graph.
    Inherits from :class:`gammagl.data.Graph` or
    :class:`gammagl.data.HeteroGraph`.
    In addition, single graphs can be identified via the assignment vector
    :obj:`batch`, which maps each node to its respective graph identifier.
    """

    @classmethod
    def from_data_list(cls, data_list: List[Graph],
                       follow_batch: Optional[List[str]] = None,
                       exclude_keys: Optional[List[str]] = None):
        r"""Constructs a :class:`~gammagl.data.BatchGraph` object from a
        Python list of :class:`~gammagl.data.Graph` or
        :class:`~gammagl.data.HeteroGraph` objects.
        The assignment vector :obj:`batch` is created on the fly.
        In addition, creates assignment vectors for each key in
        :obj:`follow_batch`.
        Will exclude any keys given in :obj:`exclude_keys`.

        Args:
            data_list: The objects to batch together. Must be non-empty,
                because its first element is inspected.
            follow_batch: Keys forwarded to ``collate`` as ``follow_batch``.
            exclude_keys: Keys forwarded to ``collate`` as ``exclude_keys``.

        Returns:
            The batch object produced by ``collate``, with the number of
            graphs and the slice/increment dictionaries stored on it so that
            the individual objects can be reconstructed later.
        """
        # https://github.com/pyg-team/pytorch_geometric/issues/3332
        batch, slice_dict, inc_dict = collate(
            cls,
            data_list=data_list,
            increment=True,
            # No `batch` vector is requested when batching batches.
            add_batch=not isinstance(data_list[0], BatchGraph),
            follow_batch=follow_batch,
            exclude_keys=exclude_keys,
        )

        # Bookkeeping read back by `num_graphs` and `get_example`.
        batch._num_graphs = len(data_list)
        batch._slice_dict = slice_dict
        batch._inc_dict = inc_dict

        return batch

    def get_example(self, idx: int) -> Graph:
        r"""Gets the :class:`~gammagl.data.Graph` or
        :class:`~gammagl.data.HeteroGraph` object at index :obj:`idx`.
        The :class:`~gammagl.data.BatchGraph` object must have been created
        via :meth:`from_data_list` in order to be able to reconstruct the
        initial object.

        Raises:
            RuntimeError: If the batch carries no ``_slice_dict``, i.e. it
                was not created via :meth:`from_data_list`.
        """
        if not hasattr(self, '_slice_dict'):
            raise RuntimeError(
                ("Cannot reconstruct 'Data' object from 'Batch' because "
                 "'Batch' was not created via 'Batch.from_data_list()'"))

        data = separate(
            # The last base of the dynamically created class is the graph
            # type that was batched (see `DynamicInheritance`).
            cls=self.__class__.__bases__[-1],
            batch=self,
            idx=idx,
            slice_dict=self._slice_dict,
            inc_dict=self._inc_dict,
            decrement=True,
        )

        return data

    def index_select(self, idx: IndexType) -> List[Graph]:
        r"""Creates a subset of :class:`~gammagl.data.Graph` or
        :class:`~gammagl.data.HeteroGraph` objects from specified indices
        :obj:`idx`.
        Indices :obj:`idx` can be a slicing object, *e.g.*, :obj:`[2:5]`, a
        list, a tuple, or a tensor or :obj:`np.ndarray` of integer or bool
        type.
        The :class:`~gammagl.data.BatchGraph` object must have been created
        via :meth:`from_data_list` in order to be able to reconstruct the
        initial objects.

        Raises:
            IndexError: If :obj:`idx` is of an unsupported type or dtype.
        """
        # Remember the caller's type for the error message before tensors
        # are normalised to NumPy arrays below.
        idx_type_name = type(idx).__name__

        if isinstance(idx, slice):
            idx = list(range(self.num_graphs)[idx])
        else:
            if tlx.is_tensor(idx):
                # Handle tensors and arrays through a single code path.
                idx = tlx.convert_to_numpy(idx)

            if isinstance(idx, np.ndarray):
                if np.issubdtype(idx.dtype, np.integer):
                    # Any integer dtype is accepted, not only int64.
                    idx = idx.flatten().tolist()
                elif idx.dtype == bool:
                    # Boolean mask: keep the positions that are True.
                    idx = idx.flatten().nonzero()[0].flatten().tolist()
                else:
                    idx = None
            elif not (isinstance(idx, Sequence) and not isinstance(idx, str)):
                idx = None

            if idx is None:
                raise IndexError(
                    f"Only slices (':'), list, tuples, torch.tensor and "
                    f"np.ndarray of dtype long or bool are valid indices (got "
                    f"'{idx_type_name}')")

        return [self.get_example(i) for i in idx]

    def __getitem__(self, idx: Union[int, np.integer, str, IndexType]) -> Any:
        """Return one graph, an attribute/store, or a list of graphs.

        * integer-like scalar -> :meth:`get_example`
        * string, or tuple starting with a string -> delegated to the base
          class ``__getitem__`` (attributes or node/edge types)
        * anything else -> :meth:`index_select`
        """
        if isinstance(idx, (int, np.integer)):
            return self.get_example(idx)
        if tlx.is_tensor(idx) and len(tlx.get_tensor_shape(idx)) == 0:
            # `ndim` is an attribute, not a method, on tensors; use the shape
            # to detect scalars and hand a plain int to `get_example`.
            return self.get_example(int(tlx.convert_to_numpy(idx)))
        if isinstance(idx, np.ndarray) and idx.ndim == 0:
            # `np.isscalar` is always False for ndarrays, including 0-d ones.
            return self.get_example(int(idx))
        if isinstance(idx, str) or (isinstance(idx, tuple)
                                    and isinstance(idx[0], str)):
            # Accessing attributes or node/edge types:
            return super().__getitem__(idx)
        return self.index_select(idx)

    def to_data_list(self) -> List[Graph]:
        r"""Reconstructs the list of :class:`~gammagl.data.Graph` or
        :class:`~gammagl.data.HeteroGraph` objects from the
        :class:`~gammagl.data.BatchGraph` object.
        The :class:`~gammagl.data.BatchGraph` object must have been created
        via :meth:`from_data_list` in order to be able to reconstruct the
        initial objects."""
        return [self.get_example(i) for i in range(self.num_graphs)]

    @property
    def num_graphs(self) -> int:
        """Returns the number of graphs in the batch.

        Uses the count stored by :meth:`from_data_list` when available,
        otherwise falls back to the ``ptr`` or ``batch`` attribute.

        Raises:
            ValueError: If none of these sources is present.
        """
        if hasattr(self, '_num_graphs'):
            return self._num_graphs
        elif hasattr(self, 'ptr'):
            # Element count of `ptr` minus one, computed from the shape so it
            # does not rely on a backend-specific tensor method.
            return int(np.prod(tlx.get_tensor_shape(self.ptr))) - 1
        elif hasattr(self, 'batch'):
            # Largest graph identifier plus one.
            return int(tlx.reduce_max(self.batch)) + 1
        else:
            raise ValueError("Can not infer the number of graphs")

    def __reduce__(self):
        """Support pickling of instances of the dynamically created class.

        Returns the reconstruction callable, the bases of the concrete class
        as its arguments, and a shallow copy of the instance ``__dict__`` as
        state.
        """
        state = self.__dict__.copy()
        return DynamicInheritanceGetter(), self.__class__.__bases__, state