Source code for gammagl.loader.node_neighbor_loader

# -*- coding: utf-8 -*-
# @author WuJing
# @created 2023/3/27

from gammagl.loader.node_loader import NodeLoader
from gammagl.sampler.neighbor_sampler import NeighborSampler
from gammagl.loader.utils import get_input_nodes_index


[docs] class NodeNeighborLoader(NodeLoader): r"""A data loader that performs neighbor sampling as introduced in the `"Inductive Representation Learning on Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper. This loader allows for mini-batch training of GNNs on large-scale graphs where full-batch training is not feasible. More specifically, :obj:`num_neighbors` denotes how much neighbors are sampled for each node in each iteration. :class:`~gammagl.loader.NodeNeighborLoader` takes in this list of :obj:`num_neighbors` and iteratively samples :obj:`num_neighbors[i]` for each node involved in iteration :obj:`i - 1`. Sampled nodes are sorted based on the order in which they were sampled. In particular, the first :obj:`batch_size` nodes represent the set of original mini-batch nodes. .. code-block:: python loader = NodeNeighborLoader( graph, # Sample 30 neighbors for each node for 2 iterations num_neighbors=[30] * 2, # Use a batch size of 128 for sampling training nodes batch_size=128, input_nodes=data.train_mask, ) sampled_data = next(iter(loader)) print(sampled_data.batch_size) By default, the data loader will only include the edges that were originally sampled (:obj:`directed = True`). This option should only be used in case the number of hops is equivalent to the number of GNN layers. In case the number of GNN layers is greater than the number of hops, consider setting :obj:`directed = False`, which will include all edges between all sampled nodes (but is slightly slower as a result). Furthermore, :class:`~gammagl.loader.NodeNeighborLoader` works for both **homogeneous** graphs stored via :class:`~gammagl.data.graph.Graph` as well as **heterogeneous** graphs stored via :class:`~gammagl.data.heterograph.HeteroGraph`. When operating in heterogeneous graphs, up to :obj:`num_neighbors` neighbors will be sampled for each :obj:`edge_type`. However, more fine-grained control over the amount of sampled neighbors of individual edge types is possible: .. code-block:: python loader = NodeNeighborLoader( hetero_graph, # Sample 30 neighbors for each node and edge type for 2 iterations num_neighbors={key: [30] * 2 for key in hetero_graph.edge_types}, # Use a batch size of 128 for sampling training nodes of type paper batch_size=128, input_nodes=('paper', hetero_graph['paper'].train_mask), ) sampled_hetero_graph = next(iter(loader)) print(sampled_hetero_graph['paper'].batch_size) .. note:: The :class:`~gammagl.loader.NeighborLoader` will return subgraphs where global node indices are mapped to local indices corresponding to this specific subgraph. However, often times it is desired to map the nodes of the current subgraph back to the global node indices. A simple trick to achieve this is to include this mapping as part of the :obj:`graph` object: .. code-block:: python # Assign each node its global node index: graph.n_id = tlx.arange(graph.num_nodes) loader = NeighborLoader(graph, ...) sampled_graph = next(iter(loader)) print(sampled_graph.n_id) Parameters ---------- graph: graph, heterograph The :class:`~gammagl.data.Graph` or :class:`~gammagl.data.HeteroGraoh` graph object. num_neighbors: list[int], dict[tuple[str, str, str], list[int]] The number of neighbors to sample for each node in each iteration. In heterogeneous graphs, may also take in a dictionary denoting the amount of neighbors to sample for each individual edge type. If an entry is set to :obj:`-1`, all neighbors will be included. input_nodes: tensor, str, tuple[str, tensor] The indices of nodes for which neighbors are sampled to create mini-batches. If set to :obj:`None`, all nodes will be considered. In heterogeneous graphs, needs to be passed as a tuple that holds the node type and node indices. (default: :obj:`None`) replace: bool, optional If set to :obj:`True`, will sample with replacement. (default: :obj:`False`) directed: bool, optional If set to :obj:`False`, will include all edges between all sampled nodes. (default: :obj:`True`) is_sorted: bool, optional If set to :obj:`True`, assumes that :obj:`edge_index` is sorted by column. This avoids internal re-sorting of the graph and can improve runtime and memory efficiency. (default: :obj:`False`) Setting this to :obj:`True` is generally not recommended: (1) it may result in too many open file handles, (2) it may slown down graph loading, (3) it requires operating on CPU tensors. (default: :obj:`False`) **kwargs: optional Additional arguments of :class:`tensorlayerx.dataflow.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last`. """ def __init__(self, graph, num_neighbors, input_nodes_type=None, replace: bool = False, directed: bool = True, is_sorted: bool = False, neighbor_sampler=None, **kwargs): node_type, _ = get_input_nodes_index(graph, input_nodes_type) if neighbor_sampler is None: neighbor_sampler = NeighborSampler( graph, num_neighbors=num_neighbors, replace=replace, directed=directed, input_type=node_type, is_sorted=is_sorted ) super().__init__( graph=graph, node_sampler=neighbor_sampler, input_nodes_type=input_nodes_type, **kwargs, )