gammagl.loader.NodeNeighborLoader
- class NodeNeighborLoader(graph, num_neighbors, input_nodes_type=None, replace: bool = False, directed: bool = True, is_sorted: bool = False, neighbor_sampler=None, **kwargs)
A data loader that performs neighbor sampling as introduced in the “Inductive Representation Learning on Large Graphs” paper. This loader allows for mini-batch training of GNNs on large-scale graphs where full-batch training is not feasible.
More specifically, `num_neighbors` denotes how many neighbors are sampled for each node in each iteration. `NodeNeighborLoader` takes in this list of `num_neighbors` and iteratively samples `num_neighbors[i]` neighbors for each node involved in iteration `i - 1`. Sampled nodes are sorted based on the order in which they were sampled. In particular, the first `batch_size` nodes represent the set of original mini-batch nodes.

```python
from gammagl.loader import NodeNeighborLoader

loader = NodeNeighborLoader(
    graph,
    # Sample 30 neighbors for each node for 2 iterations
    num_neighbors=[30] * 2,
    # Use a batch size of 128 for sampling training nodes
    batch_size=128,
    input_nodes_type=graph.train_mask,
)

sampled_graph = next(iter(loader))
print(sampled_graph.batch_size)
```
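To make the hop-by-hop procedure concrete, here is a minimal pure-Python sketch on a dict-based adjacency list. It is not GammaGL's implementation: the helper `sample_neighborhood`, the adjacency format and the use of `random.Random` are assumptions for illustration, and it assumes a node reached in an earlier iteration is not expanded again. It shows why the first `batch_size` entries of the sampled node list are the seed nodes: seeds are recorded first, and every other node is appended only the first time it is sampled.

```python
import random


def sample_neighborhood(adj, seeds, num_neighbors, rng):
    """Hypothetical hop-by-hop neighbor sampler (illustration only).

    adj maps each node to the list of its neighbors. Returns the sampled
    node ids (seeds first, then other nodes in the order they were first
    sampled) and the sampled edges as (neighbor, node) pairs.
    """
    n_id = list(seeds)        # the first len(seeds) entries are the seeds
    seen = set(seeds)
    frontier = list(seeds)    # nodes to expand in the current iteration
    edges = []
    for k in num_neighbors:   # iteration i draws up to num_neighbors[i] per node
        next_frontier = []
        for v in frontier:
            nbrs = adj.get(v, [])
            for u in rng.sample(nbrs, min(k, len(nbrs))):
                edges.append((u, v))      # edge from sampled neighbor u to v
                if u not in seen:         # assumption: each node is expanded once
                    seen.add(u)
                    n_id.append(u)
                    next_frontier.append(u)
        frontier = next_frontier
    return n_id, edges


adj = {0: [1, 2, 3], 1: [0, 4], 2: [0, 5], 3: [0], 4: [1], 5: [2]}
n_id, edges = sample_neighborhood(adj, seeds=[0, 1], num_neighbors=[2, 1],
                                  rng=random.Random(0))
batch_size = 2
print(n_id[:batch_size])  # [0, 1]: the original mini-batch nodes come first
```

Because later nodes are only appended, slicing the first `batch_size` rows of any per-node output recovers the predictions for the mini-batch nodes.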
By default, the data loader will only include the edges that were originally sampled (`directed = True`). This option should only be used if the number of hops equals the number of GNN layers. If the number of GNN layers is greater than the number of hops, consider setting `directed = False`, which will include all edges between all sampled nodes (but is slightly slower as a result).

Furthermore, `NodeNeighborLoader` works for both homogeneous graphs stored via `Graph` and heterogeneous graphs stored via `HeteroGraph`. When operating on heterogeneous graphs, up to `num_neighbors` neighbors will be sampled for each `edge_type`. However, more fine-grained control over the number of sampled neighbors of individual edge types is possible:

```python
from gammagl.loader import NodeNeighborLoader

loader = NodeNeighborLoader(
    hetero_graph,
    # Sample 30 neighbors for each node and edge type for 2 iterations
    num_neighbors={key: [30] * 2 for key in hetero_graph.edge_types},
    # Use a batch size of 128 for sampling training nodes of type "paper"
    batch_size=128,
    input_nodes_type=('paper', hetero_graph['paper'].train_mask),
)

sampled_hetero_graph = next(iter(loader))
print(sampled_hetero_graph['paper'].batch_size)
```
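The difference between `directed=True` and `directed=False` described above can be illustrated with a small hypothetical helper, `build_subgraph`, written in plain Python rather than with GammaGL's internals. It relabels global node ids to local ones and then either keeps only the edges that were sampled, or keeps every edge of the full graph whose two endpoints were both sampled (the induced subgraph).

```python
def build_subgraph(sampled_edges, n_id, all_edges, directed=True):
    """Hypothetical helper: return the subgraph's edges in local indices.

    directed=True keeps only the edges that were actually sampled;
    directed=False keeps every edge of the full graph whose endpoints
    were both sampled.
    """
    local = {g: i for i, g in enumerate(n_id)}  # global id -> local id
    if directed:
        chosen = sampled_edges
    else:
        chosen = [(u, v) for u, v in all_edges if u in local and v in local]
    return sorted({(local[u], local[v]) for u, v in chosen})


all_edges = [(7, 5), (9, 5), (9, 7), (5, 9), (3, 5)]  # full graph, global ids
n_id = [5, 7, 9]                                      # seed 5, then its neighbors
sampled = [(7, 5), (9, 5)]                            # one hop around node 5
print(build_subgraph(sampled, n_id, all_edges, directed=True))   # [(1, 0), (2, 0)]
print(build_subgraph(sampled, n_id, all_edges, directed=False))  # [(0, 2), (1, 0), (2, 0), (2, 1)]
```

With `directed=False`, the edges `(9, 7)` and `(5, 9)` between sampled nodes are kept as well, which gives a GNN with more layers than sampled hops additional message paths; edge `(3, 5)` is dropped in both cases because node 3 was never sampled.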
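The per-edge-type dictionary form of `num_neighbors` can be sketched in plain Python as well. The helper `hetero_sample`, the nested-dict adjacency and the toy `author`/`paper` graph are assumptions for illustration only; edge types are `(src_type, relation, dst_type)` tuples, matching the `dict[tuple[str, str, str], list[int]]` type documented for `num_neighbors`, and each edge type gets its own budget per iteration.

```python
import random


def hetero_sample(adj, seeds, num_neighbors, rng):
    """Hypothetical hop-by-hop sampler with a budget per edge type.

    adj[(src_type, rel, dst_type)][dst] lists the src-type neighbors of dst.
    num_neighbors[(src_type, rel, dst_type)][i] is the budget for hop i
    (all lists are assumed to have the same length).
    Returns the sampled node ids per node type, seeds first.
    """
    n_id = {t: list(ids) for t, ids in seeds.items()}
    seen = {t: set(ids) for t, ids in seeds.items()}
    frontier = {t: list(ids) for t, ids in seeds.items()}
    num_hops = len(next(iter(num_neighbors.values())))
    for i in range(num_hops):
        next_frontier = {}
        for edge_type, budgets in num_neighbors.items():
            src_t, _, dst_t = edge_type
            for v in frontier.get(dst_t, []):
                nbrs = adj.get(edge_type, {}).get(v, [])
                for u in rng.sample(nbrs, min(budgets[i], len(nbrs))):
                    if u not in seen.setdefault(src_t, set()):
                        seen[src_t].add(u)
                        n_id.setdefault(src_t, []).append(u)
                        next_frontier.setdefault(src_t, []).append(u)
        frontier = next_frontier
    return n_id


adj = {
    ('author', 'writes', 'paper'): {0: [10, 11], 1: [11, 12]},
    ('paper', 'cites', 'paper'): {0: [1, 2], 1: [0], 2: [0]},
}
# Same budget for every edge type, as in {key: [30] * 2 for key in ...} above ...
num_neighbors = {key: [2] * 2 for key in adj}
# ... then finer control: one cited paper in hop 1, none in hop 2.
num_neighbors[('paper', 'cites', 'paper')] = [1, 0]
out = hetero_sample(adj, {'paper': [0]}, num_neighbors, random.Random(0))
print(out['paper'][0], len(out['paper']))  # 0 2
```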
Note
`NodeNeighborLoader` will return subgraphs where global node indices are mapped to local indices corresponding to this specific subgraph. However, it is often desirable to map the nodes of the current subgraph back to their global node indices. A simple trick to achieve this is to include this mapping as part of the `graph` object:

```python
import tensorlayerx as tlx
from gammagl.loader import NodeNeighborLoader

# Assign each node its global node index:
graph.n_id = tlx.arange(graph.num_nodes)

loader = NodeNeighborLoader(graph, ...)
sampled_graph = next(iter(loader))
print(sampled_graph.n_id)
```
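The trick works because a per-node attribute is sliced with the same node selection as every other node feature, so an attribute holding `0, 1, ..., num_nodes - 1` ends up holding exactly the global id of each local node. A minimal pure-Python sketch, where the helper `relabel` and the list-based "attribute" are hypothetical stand-ins for what the loader does internally:

```python
def relabel(n_id, global_edges):
    """Hypothetical helper: map global edge endpoints to local indices."""
    local = {g: i for i, g in enumerate(n_id)}
    return [(local[u], local[v]) for u, v in global_edges]


global_n_id = list(range(10))   # stands in for tlx.arange(graph.num_nodes)
n_id = [7, 3, 9]                # nodes kept by the sampler, in local order
# The attribute is sliced like any node feature, so it keeps the global ids.
sub_n_id = [global_n_id[g] for g in n_id]

local_edges = relabel(n_id, [(3, 7), (9, 7)])
print(local_edges)                                           # [(1, 0), (2, 0)]
print([(sub_n_id[u], sub_n_id[v]) for u, v in local_edges])  # [(3, 7), (9, 7)]
```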
- Parameters:
  - graph (graph, heterograph) – The `Graph` or `HeteroGraph` object.
  - num_neighbors (list[int], dict[tuple[str, str, str], list[int]]) – The number of neighbors to sample for each node in each iteration. In heterogeneous graphs, may also take in a dictionary denoting the number of neighbors to sample for each individual edge type. If an entry is set to `-1`, all neighbors will be included.
  - input_nodes_type (tensor, str, tuple[str, tensor]) – The indices of nodes for which neighbors are sampled to create mini-batches. If set to `None`, all nodes will be considered. In heterogeneous graphs, needs to be passed as a tuple that holds the node type and node indices. (default: `None`)
  - replace (bool, optional) – If set to `True`, will sample with replacement. (default: `False`)
  - directed (bool, optional) – If set to `False`, will include all edges between all sampled nodes. (default: `True`)
  - is_sorted (bool, optional) – If set to `True`, assumes that `edge_index` is sorted by column. This avoids internal re-sorting of the graph and can improve runtime and memory efficiency. Setting this to `True` is generally not recommended: (1) it may result in too many open file handles, (2) it may slow down graph loading, (3) it requires operating on CPU tensors. (default: `False`)
  - **kwargs (optional) – Additional arguments of `tensorlayerx.dataflow.DataLoader`, such as `batch_size`, `shuffle`, `drop_last`.
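The documented semantics of a single `num_neighbors` entry and of `replace` can be summarized in a few lines. This is a hypothetical sketch of one node's draw, not GammaGL's sampler; in particular, the assumption that sampling with replacement draws exactly `k` times (so duplicates are possible) is mine.

```python
import random


def sample_neighbors(nbrs, k, replace, rng):
    """Hypothetical per-node draw mirroring the documented options.

    k == -1 keeps every neighbor. Otherwise, with replace=True draw k times
    with replacement (duplicates possible); with replace=False draw up to
    k distinct neighbors.
    """
    if k == -1:
        return list(nbrs)
    if not nbrs:
        return []
    if replace:
        return [rng.choice(nbrs) for _ in range(k)]
    return rng.sample(nbrs, min(k, len(nbrs)))


rng = random.Random(0)
nbrs = [4, 8, 15]
print(sample_neighbors(nbrs, -1, False, rng))      # [4, 8, 15]
print(len(sample_neighbors(nbrs, 5, False, rng)))  # 3 (capped by the degree)
print(len(sample_neighbors(nbrs, 5, True, rng)))   # 5 (duplicates allowed)
```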
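Why a column-sorted `edge_index` saves work can be seen from a compressed-sparse-column (CSC) layout, where the incoming neighbors of node `v` occupy one contiguous slice. GammaGL's internal storage is not described on this page; the CSC layout and the helper `to_csc` below are assumptions used to illustrate the idea. With `is_sorted=True` the sort step is skipped, so passing unsorted edges with that flag would silently produce a wrong layout.

```python
def to_csc(edge_index, num_nodes, is_sorted=False):
    """Hypothetical helper: compress edge_index = (row, col) into CSC form.

    Returns (colptr, row) such that the sources of edges into node v are
    row[colptr[v]:colptr[v + 1]]. With is_sorted=True the edges are assumed
    to be sorted by column already, so the sorting step is skipped.
    """
    row, col = edge_index
    if not is_sorted:
        order = sorted(range(len(col)), key=lambda e: col[e])  # the re-sort being avoided
        row = [row[e] for e in order]
        col = [col[e] for e in order]
    colptr = [0] * (num_nodes + 1)
    for c in col:                   # count edges per target column
        colptr[c + 1] += 1
    for v in range(num_nodes):      # prefix sum turns counts into offsets
        colptr[v + 1] += colptr[v]
    return colptr, list(row)


edge_index = ([1, 2, 0, 2], [0, 0, 1, 1])  # already sorted by column
colptr, row = to_csc(edge_index, num_nodes=3, is_sorted=True)
print(colptr)                    # [0, 2, 4, 4]
print(row[colptr[1]:colptr[2]])  # [0, 2], the in-neighbors of node 1
```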
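The `batch_size`, `shuffle` and `drop_last` keyword arguments are forwarded to `tensorlayerx.dataflow.DataLoader`. As a rough mental model (an assumption, not TensorLayerX's code), they decide how the seed nodes are chunked into mini-batches before neighbor sampling; the generator `iter_seed_batches` below is a hypothetical stand-in.

```python
import random


def iter_seed_batches(seeds, batch_size, shuffle=False, drop_last=False, rng=None):
    """Hypothetical sketch of how seed nodes are grouped into mini-batches."""
    seeds = list(seeds)
    if shuffle:
        (rng or random.Random()).shuffle(seeds)  # new order each epoch
    for start in range(0, len(seeds), batch_size):
        batch = seeds[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break                                # skip the short final batch
        yield batch


print(list(iter_seed_batches(range(10), 4)))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(list(iter_seed_batches(range(10), 4, drop_last=True)))
# [[0, 1, 2, 3], [4, 5, 6, 7]]
```

Each yielded batch would then play the role of `seeds` in the hop-by-hop sampling, which is why `batch_size` determines how many leading nodes of a sampled subgraph are mini-batch nodes.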