# Source code for gammagl.loader.link_neighbor_loader

# -*- coding: utf-8 -*-
# @author WuJing
# @created 2023/3/2

from typing import List
from gammagl.loader.link_loader import LinkLoader
from gammagl.sampler.neighbor_sampler import NeighborSampler


class LinkNeighborLoader(LinkLoader):
    r"""A link-based graph loader derived as an extension of the node-based
    :class:`gammagl.loader.NeighborLoader`.

    This loader allows for mini-batch training of GNNs on large-scale graphs
    where full-batch training is not feasible.

    More specifically, this loader first selects a sample of edges from the
    set of input edges :obj:`edge_label_index` (which may or may not be edges
    in the original graph) and then constructs a subgraph from all the nodes
    present in this list by sampling :obj:`num_neighbors` neighbors in each
    iteration.

    .. code-block:: python

        loader = LinkNeighborLoader(
            graph,
            # Sample 30 neighbors for each node for 2 iterations
            num_neighbors=[30] * 2,
            # Use a batch size of 128 for sampling training nodes
            batch_size=128,
            edge_label_index=graph.edge_index,
        )

        sampled_graph = next(iter(loader))
        print(sampled_graph)

    It is additionally possible to provide edge labels for sampled edges,
    which are then added to the batch:

    .. code-block:: python

        loader = LinkNeighborLoader(
            graph,
            num_neighbors=[30] * 2,
            batch_size=128,
            edge_label_index=graph.edge_index,
            edge_label=torch.ones(graph.edge_index.size(1))
        )

        sampled_graph = next(iter(loader))
        print(sampled_graph)

    The rest of the functionality mirrors that of
    :class:`~gammagl.loader.NeighborLoader`, including support for
    heterogeneous graphs.

    .. note::

        :obj:`neg_sampling_ratio` is currently implemented in an approximate
        way, *i.e.* negative edges may contain false negatives.

    Parameters
    ----------
    graph: graph, heterograph
        The :class:`~gammagl.data.Graph` or
        :class:`~gammagl.data.HeteroGraph` graph object.
    num_neighbors: list[int], dict[tuple[str, str, str], list[int]]
        The number of neighbors to sample for each node in each iteration.
        In heterogeneous graphs, may also take in a dictionary denoting
        the amount of neighbors to sample for each individual edge type.
        If an entry is set to :obj:`-1`, all neighbors will be included.
    edge_label_index: tensor or str or tuple[str, tensor]
        The edge indices for which neighbors are sampled to create
        mini-batches.
        If set to :obj:`None`, all edges will be considered.
        In heterogeneous graphs, needs to be passed as a tuple that holds
        the edge type and corresponding edge indices.
        The value is handed to :class:`LinkLoader` unchanged, so the
        :obj:`None` handling lives there.
        (default: :obj:`None`)
    edge_label: tensor, optional
        The labels of edge indices for which neighbors are sampled.
        Must be the same length as the :obj:`edge_label_index`.
        If set to :obj:`None` it is set to `torch.zeros(...)` internally.
        The value is handed to :class:`LinkLoader` unchanged.
        (default: :obj:`None`)
    replace: bool, optional
        If set to :obj:`True`, will sample with replacement.
        (default: :obj:`False`)
    directed: bool, optional
        If set to :obj:`False`, will include all edges between all sampled
        nodes. (default: :obj:`True`)
    neg_sampling_ratio: float, optional
        The ratio of sampled negative edges to the number of positive edges.
        If :obj:`neg_sampling_ratio > 0` and in case :obj:`edge_label`
        does not exist, it will be automatically created and represents a
        binary classification task (:obj:`1` = edge, :obj:`0` = no edge).
        If :obj:`neg_sampling_ratio > 0` and in case :obj:`edge_label`
        exists, it has to be a categorical label from :obj:`0` to
        :obj:`num_classes - 1`.
        After negative sampling, label :obj:`0` represents negative edges,
        and labels :obj:`1` to :obj:`num_classes` represent the labels of
        positive edges.
        Note that returned labels are of type :obj:`torch.float` for binary
        classification (to facilitate the ease-of-use of
        :meth:`F.binary_cross_entropy`) and of type
        :obj:`torch.long` for multi-class classification (to facilitate the
        ease-of-use of :meth:`F.cross_entropy`). (default: :obj:`0.0`).
    neighbor_sampler: optional
        A ready-made sampler that is passed to :class:`LinkLoader` as its
        :obj:`link_sampler`. If set to :obj:`None`, a
        :class:`~gammagl.sampler.neighbor_sampler.NeighborSampler` is built
        from :obj:`graph`, :obj:`num_neighbors`, :obj:`replace`,
        :obj:`directed` and :obj:`is_sorted`. When a sampler is supplied,
        this constructor does not use those four options.
        (default: :obj:`None`)
    is_sorted: bool, optional
        If set to :obj:`True`, assumes that :obj:`edge_index` is sorted by
        column. This avoids internal re-sorting of the graph and can improve
        runtime and memory efficiency. (default: :obj:`False`)
    **kwargs: optional
        Extra keyword arguments, forwarded unchanged to
        :class:`LinkLoader`. These are documented as arguments of
        :class:`tensorlayerx.dataflow.DataLoader`, such as
        :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last`.

    Notes
    -----
    The options below are *not* named parameters of this constructor. If
    supplied, they only travel through :obj:`**kwargs` to
    :class:`LinkLoader`.
    NOTE(review): confirm that :class:`LinkLoader` accepts each of them; the
    sampler built here is not given a time attribute.

    edge_label_time: tensor, optional
        The timestamps for edge indices for which neighbors are sampled.
        Must be the same length as :obj:`edge_label_index`. If set, temporal
        sampling will be used such that neighbors are guaranteed to fulfill
        temporal constraints, *i.e.*, neighbors have an earlier timestamp
        than the output edge. The :obj:`time_attr` needs to be set for this
        to work. (default: :obj:`None`)
    time_attr: str, optional
        The name of the attribute that denotes timestamps for the nodes in
        the graph. Only used if :obj:`edge_label_time` is set.
        (default: :obj:`None`)
    transform: callable, optional
        A function/transform that takes in a sampled mini-batch and returns
        a transformed version. (default: :obj:`None`)
    filter_per_worker: bool, optional
        If set to :obj:`True`, will filter the returning graph in each
        worker's subprocess rather than in the main process.
        Setting this to :obj:`True` is generally not recommended:
        (1) it may result in too many open file handles,
        (2) it may slow down graph loading,
        (3) it requires operating on CPU tensors.
        (default: :obj:`False`)

    """

    def __init__(self, graph, num_neighbors: List[int], edge_label_index=None,
                 edge_label=None, replace=False, directed=True,
                 neg_sampling_ratio=0.0, neighbor_sampler=None,
                 is_sorted=False, **kwargs):
        # ``edge_label_index`` and ``edge_label`` default to ``None`` so the
        # documented defaults and the first usage example (which omits
        # ``edge_label``) work instead of raising ``TypeError``.
        if neighbor_sampler is None:
            # Build the default sampler only when the caller did not inject
            # one; this constructor never selects an edge type, hence
            # ``input_type=None``.
            neighbor_sampler = NeighborSampler(
                graph,
                num_neighbors=num_neighbors,
                replace=replace,
                directed=directed,
                input_type=None,
                is_sorted=is_sorted,
            )

        super().__init__(
            graph=graph,
            link_sampler=neighbor_sampler,
            edge_label_index=edge_label_index,
            edge_label=edge_label,
            neg_sampling_ratio=neg_sampling_ratio,
            **kwargs,
        )