Source code for gammagl.utils.subgraph

import tensorlayerx as tlx
import numpy as np
from .num_nodes import maybe_num_nodes

def k_hop_subgraph(node_idx, num_hops, edge_index, relabel_nodes=False, num_nodes=None, reverse=False):
    r"""
    Computes the induced subgraph of :obj:`edge_index` around all nodes in
    :attr:`node_idx` reachable within :math:`k` hops.

    The :attr:`reverse` argument selects the direction in which edges are
    followed when collecting :math:`k`-hop neighbors.
    With :obj:`reverse=False` ("source to target"), each hop adds the nodes
    in the first row of :obj:`edge_index` whose edge ends (second row) in the
    current frontier, *i.e.* all neighbors that point to the seed nodes.
    This mimics the natural flow of message passing in Graph Neural Networks.
    With :obj:`reverse=True` the two rows swap roles.

    Parameters
    ----------
    node_idx: int, list, tuple, tensor
        The central seed node(s). A tensor is used as-is and is expected to
        be one-dimensional (it is concatenated with the per-hop node sets).
    num_hops: int
        The number of hops :math:`k`.
    edge_index: tensor
        The edge indices.
    relabel_nodes: bool, optional
        If set to :obj:`True`, the resulting :obj:`edge_index` will be
        relabeled to hold consecutive indices starting from zero.
        (default: :obj:`False`)
    num_nodes: int, optional
        The number of nodes, *i.e.* :obj:`max_val + 1` of
        :attr:`edge_index`. (default: :obj:`None`)
    reverse: bool, optional
        The flow direction of :math:`k`-hop, :obj:`False` for
        "source to target" or vice versa. (default: :obj:`False`)

    Returns
    -------
    subset: tensor
        The sorted, de-duplicated nodes involved in the subgraph.
    edge_index: tensor
        The filtered connectivity (relabeled if :attr:`relabel_nodes`).
    inv: numpy.ndarray
        For every entry of :attr:`node_idx`, its position inside
        :obj:`subset`.
    edge_mask: tensor
        Boolean mask over the input edges indicating which were preserved.

    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    # `row` is the side matched against the current frontier, `col` is the
    # side whose nodes get added to the next frontier.
    if reverse:
        row, col = edge_index
    else:
        col, row = edge_index

    if isinstance(node_idx, (int, list, tuple)):
        node_idx = tlx.convert_to_tensor(np.array([node_idx]).flatten(), dtype=edge_index.dtype)

    subsets = [node_idx]
    for _ in range(num_hops):
        # Mark the latest frontier, then pull in the other endpoint of every
        # edge that touches it. The mask is built as int64 and cast to bool
        # afterwards, exactly as the scatter/gather calls were used before.
        node_mask = tlx.zeros((num_nodes,), dtype=tlx.int64)
        node_mask = tlx.scatter_update(node_mask, subsets[-1], tlx.ones_like(subsets[-1], dtype=tlx.int64))
        edge_mask = tlx.gather(node_mask, row)
        edge_mask = tlx.cast(edge_mask, dtype=tlx.bool)
        subsets.append(tlx.mask_select(col, edge_mask))

    subset, inv = np.unique(tlx.convert_to_numpy(tlx.concat(subsets, axis=0)), return_inverse=True)
    subset = tlx.convert_to_tensor(subset)
    # The seeds are the first chunk of the concatenation, so their positions
    # in `subset` are the leading entries of the inverse mapping.
    inv = inv[:node_idx.shape[0]]

    node_mask = tlx.zeros((num_nodes,), dtype=tlx.int64)
    node_mask = tlx.scatter_update(node_mask, subset, tlx.ones_like(subset, dtype=tlx.int64))
    # The Paddle backend keeps the integer mask here; the other backends use
    # a boolean mask for the logical_and below.
    if tlx.BACKEND != 'paddle':
        node_mask = tlx.cast(node_mask, dtype=tlx.bool)

    # Keep only edges with both endpoints inside the subgraph.
    edge_mask = tlx.logical_and(tlx.gather(node_mask, row), tlx.gather(node_mask, col))
    edge_mask = tlx.cast(edge_mask, dtype=tlx.bool)
    edge_index = tlx.mask_select(edge_index, edge_mask, axis=1)

    if relabel_nodes:
        # Lookup table: old node id -> position in `subset` (-1 if absent).
        node_idx = tlx.constant(-1, dtype=node_idx.dtype, shape=(num_nodes,))
        node_idx = tlx.scatter_update(node_idx, subset, tlx.arange(0, subset.shape[0], dtype=edge_index.dtype))
        if tlx.BACKEND == 'paddle':
            # Paddle path gathers through a flattened index column and then
            # restores the (2, num_edges) layout.
            edge_index = tlx.gather(node_idx, tlx.reshape(edge_index, (-1, 1)))
            edge_index = tlx.reshape(edge_index, (2, -1))
        else:
            edge_index = tlx.gather(node_idx, edge_index)

    return subset, edge_index, inv, edge_mask