gammagl.loader.LinkNeighborLoader
- class LinkNeighborLoader(graph, num_neighbors: List[int], edge_label_index, edge_label, replace=False, directed=True, neg_sampling_ratio=0.0, neighbor_sampler=None, is_sorted=False, **kwargs)
A link-based graph loader derived as an extension of the node-based gammagl.loader.NeighborLoader. This loader allows for mini-batch training of GNNs on large-scale graphs where full-batch training is not feasible.

More specifically, this loader first selects a sample of edges from the set of input edges edge_label_index (which may or may not be edges in the original graph) and then constructs a subgraph from all the nodes present in this list by sampling num_neighbors neighbors in each iteration.

    from gammagl.loader import LinkNeighborLoader

    loader = LinkNeighborLoader(
        graph,
        # Sample 30 neighbors for each node for 2 iterations
        num_neighbors=[30] * 2,
        # Use a batch size of 128 for sampling training edges
        batch_size=128,
        edge_label_index=graph.edge_index,
    )

    sampled_graph = next(iter(loader))
    print(sampled_graph)
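The two-stage procedure described above (pick a batch of edges, then expand around their endpoints hop by hop) can be sketched in plain Python. This is only an illustration of the idea, not GammaGL's implementation: the function `sample_link_subgraph`, the adjacency-dict representation, and the toy graph are all assumptions made for this example.

```python
import random


def sample_link_subgraph(adj, edge_batch, num_neighbors, seed=0):
    """Hypothetical helper: collect the nodes reached by hop-wise
    neighbor sampling around a batch of edges."""
    rng = random.Random(seed)
    # Seed nodes are all endpoints of the edges in the batch.
    frontier = {node for edge in edge_batch for node in edge}
    sampled_nodes = set(frontier)
    for fanout in num_neighbors:
        next_frontier = set()
        for node in sorted(frontier):
            neighbors = adj.get(node, [])
            if fanout < 0 or len(neighbors) <= fanout:
                # A fanout of -1 (or a small neighborhood) keeps all neighbors.
                picked = list(neighbors)
            else:
                # Sample without replacement, like replace=False.
                picked = rng.sample(neighbors, fanout)
            next_frontier.update(picked)
        # Only newly discovered nodes are expanded in the next hop.
        frontier = next_frontier - sampled_nodes
        sampled_nodes |= next_frontier
    return sampled_nodes


# Toy undirected graph stored as an adjacency dict (illustrative data).
adj = {0: [1, 2, 3], 1: [0, 4], 2: [0], 3: [0, 5], 4: [1], 5: [3]}
edge_batch = [(0, 1)]

nodes = sample_link_subgraph(adj, edge_batch, num_neighbors=[2, 2])
all_nodes = sample_link_subgraph(adj, edge_batch, num_neighbors=[-1, -1])
print(sorted(nodes), sorted(all_nodes))
```

The endpoints of the batch edges are always part of the result, and each hop adds at most `fanout` neighbors per expanded node.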
It is additionally possible to provide edge labels for sampled edges, which are then added to the batch:

    import torch  # the label tensor below assumes the PyTorch backend

    loader = LinkNeighborLoader(
        graph,
        num_neighbors=[30] * 2,
        batch_size=128,
        edge_label_index=graph.edge_index,
        edge_label=torch.ones(graph.edge_index.size(1)),
    )

    sampled_graph = next(iter(loader))
    print(sampled_graph)
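To make the pairing of edges and labels concrete, the sketch below slices a list of edges and their labels with the same indices, so that every batch carries the labels of exactly the edges it contains. It is a simplified stand-in written with plain Python lists; `iter_edge_batches` is a hypothetical helper, not part of GammaGL, and the real loader additionally samples neighbors for each batch.

```python
import random


def iter_edge_batches(edge_label_index, edge_label, batch_size,
                      shuffle=False, seed=0):
    """Hypothetical helper: yield (edge_index_batch, label_batch) pairs."""
    src, dst = edge_label_index
    order = list(range(len(src)))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        # The same positions select both the edges and their labels.
        batch_index = ([src[i] for i in idx], [dst[i] for i in idx])
        batch_label = [edge_label[i] for i in idx]
        yield batch_index, batch_label


# Four labelled edges given as (source list, destination list).
edge_label_index = ([0, 0, 1, 3], [1, 2, 4, 5])
edge_label = [1.0, 1.0, 0.0, 1.0]

batches = list(iter_edge_batches(edge_label_index, edge_label, batch_size=3))
for batch_index, batch_label in batches:
    print(batch_index, batch_label)
```

Because `edge_label` is indexed with the same positions as `edge_label_index`, the two must have the same length, which matches the requirement stated for the `edge_label` parameter.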
The rest of the functionality mirrors that of NeighborLoader, including support for heterogeneous graphs.

Note: neg_sampling_ratio is currently implemented in an approximate way, i.e. negative edges may contain false negatives.

- Parameters:
  - graph (graph, heterograph) – The Graph or HeteroGraph graph object.
  - num_neighbors (list[int], dict[tuple[str, str, str], list[int]]) – The number of neighbors to sample for each node in each iteration. In heterogeneous graphs, may also take in a dictionary denoting the number of neighbors to sample for each individual edge type. If an entry is set to -1, all neighbors will be included.
  - edge_label_index (tensor or str or tuple[str, tensor]) – The edge indices for which neighbors are sampled to create mini-batches. If set to None, all edges will be considered. In heterogeneous graphs, needs to be passed as a tuple that holds the edge type and corresponding edge indices. (default: None)
  - edge_label (tensor, optional) – The labels of the edge indices for which neighbors are sampled. Must be the same length as edge_label_index. If set to None, it is set to torch.zeros(…) internally. (default: None)
  - edge_label_time (tensor, optional) – The timestamps of the edge indices for which neighbors are sampled. Must be the same length as edge_label_index. If set, temporal sampling will be used such that neighbors are guaranteed to fulfill temporal constraints, i.e., neighbors have an earlier timestamp than the output edge. time_attr needs to be set for this to work. (default: None)
  - replace (bool, optional) – If set to True, will sample with replacement. (default: False)
  - directed (bool, optional) – If set to False, will include all edges between all sampled nodes. (default: True)
  - neg_sampling_ratio (float, optional) – The ratio of sampled negative edges to the number of positive edges. If neg_sampling_ratio > 0 and edge_label does not exist, it will be automatically created and represents a binary classification task (1 = edge, 0 = no edge). If neg_sampling_ratio > 0 and edge_label exists, it has to be a categorical label from 0 to num_classes - 1. After negative sampling, label 0 represents negative edges, and labels 1 to num_classes represent the labels of positive edges. Note that returned labels are of type torch.float for binary classification (to facilitate the ease-of-use of F.binary_cross_entropy()) and of type torch.long for multi-class classification (to facilitate the ease-of-use of F.cross_entropy()). (default: 0.0)
  - time_attr (str, optional) – The name of the attribute that denotes timestamps for the nodes in the graph. Only used if edge_label_time is set. (default: None)
  - transform (callable, optional) – A function/transform that takes in a sampled mini-batch and returns a transformed version. (default: None)
  - is_sorted (bool, optional) – If set to True, assumes that edge_index is sorted by column. This avoids internal re-sorting of the graph and can improve runtime and memory efficiency. (default: False)
  - filter_per_worker (bool, optional) – If set to True, will filter the returned graph in each worker's subprocess rather than in the main process. Setting this to True is generally not recommended: (1) it may result in too many open file handles, (2) it may slow down graph loading, (3) it requires operating on CPU tensors. (default: False)
  - **kwargs (optional) – Additional arguments of tensorlayerx.dataflow.DataLoader, such as batch_size, shuffle, drop_last.
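The label convention described for neg_sampling_ratio can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the library's code: `add_negative_edges` is a hypothetical helper, and drawing negative endpoints uniformly at random is only one possible approximate strategy. Because the drawn pairs are not checked against the existing edges, some "negative" edges may in fact be real edges, which is the false-negative caveat noted earlier.

```python
import random


def add_negative_edges(pos_edges, num_nodes, neg_sampling_ratio,
                       edge_label=None, seed=0):
    """Hypothetical helper: append negatives and build labels."""
    rng = random.Random(seed)
    num_neg = int(round(len(pos_edges) * neg_sampling_ratio))
    # Approximate sampling: endpoints are drawn uniformly and NOT checked
    # against pos_edges, so false negatives are possible.
    neg_edges = [(rng.randrange(num_nodes), rng.randrange(num_nodes))
                 for _ in range(num_neg)]
    if edge_label is None:
        # Binary task: 1.0 = edge, 0.0 = no edge (float labels).
        labels = [1.0] * len(pos_edges) + [0.0] * num_neg
    else:
        # Multi-class task: shift positive labels from 0..num_classes-1
        # to 1..num_classes so that 0 is free for negatives (int labels).
        labels = [int(y) + 1 for y in edge_label] + [0] * num_neg
    return list(pos_edges) + neg_edges, labels


pos_edges = [(0, 1), (1, 2), (2, 3), (3, 0)]

# No edge_label given: binary labels are created automatically.
binary_edges, binary_labels = add_negative_edges(pos_edges, 4, 0.5)

# Categorical edge_label given: positives are shifted by one.
multi_edges, multi_labels = add_negative_edges(
    pos_edges, 4, 0.5, edge_label=[0, 2, 1, 0])

print(binary_labels)
print(multi_labels)
```

With four positive edges and a ratio of 0.5, two negatives are appended in both cases; only the label type and the shift of the positive labels differ.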