# -*- coding: utf-8 -*-
# @author WuJing
# @created 2023/4/10
from dataclasses import dataclass
from time import time
import numpy as np
import tensorlayerx as tlx
from typing import Union, List, Optional, Callable, Tuple
from gammagl.sparse import SparseGraph
from gammagl.utils.platform_utils import Tensor, all_to_tensor, to_list
@dataclass
class EdgeIndex:
    """One sampled hop in edge-index form.

    Built by ``NeighborSampler.sample`` when the sampler was constructed from
    an edge-index tensor (rather than a ``SparseGraph``).
    """

    # Sampled edges of the hop: ``tlx.stack([col, row], axis=0)`` built from
    # the hop's sampled (transposed) sparse graph.
    edge_index: Tensor
    # Edge ids read from the sampled graph's value slot; may be None when
    # the sampler was built with ``return_e_id=False``.
    e_id: Optional[Tensor]
    # The sampled graph's ``sparse_sizes()`` in reversed order.
    size: Tuple[int, int]
@dataclass
class Adj:
    """One sampled hop in ``SparseGraph`` form.

    Built by ``NeighborSampler.sample`` when the sampler was constructed from
    a ``SparseGraph``.
    """

    # Sampled (transposed) sparse graph for this hop, as returned by
    # ``SparseGraph.sample_adj``; its values are restored to the input
    # graph's original values when edge ids were requested.
    adj_t: SparseGraph
    # Values read from the sampled graph before any restoration: edge ids
    # when ``return_e_id=True``, otherwise whatever values the graph held.
    e_id: Optional[Tensor]
    # The sampled graph's ``sparse_sizes()`` in reversed order.
    size: Tuple[int, int]
class NeighborSampler(tlx.dataflow.DataLoader):
    """Mini-batch loader that samples multi-hop neighborhoods of seed nodes.

    The seed nodes in ``node_idx`` are handed to ``tlx.dataflow.DataLoader``
    as a flat list of ids, and :meth:`sample` is installed as the collate
    function, so every batch of seeds is expanded into one sampled hop per
    entry of ``sample_lists``.

    Args:
        edge_index: Graph connectivity, either an edge-index tensor (row 0 is
            used as ``row``, row 1 as ``col``) or a ``SparseGraph``.
        sample_lists: Per-hop sample sizes passed to
            ``SparseGraph.sample_adj``; its length is the number of hops.
        node_idx: Seed nodes to iterate over, either integer node ids or a
            boolean mask. ``None`` iterates over all nodes.
        num_nodes: Number of nodes; only used for edge-index input and
            inferred from ``edge_index``/``node_idx`` when ``None``.
        return_e_id: If True, each sampled hop reports the ids of its edges.
        transform: Optional callable applied as
            ``transform(batch, n_id, adjs)`` to every sampled result.
        **kwargs: Forwarded unchanged to ``tlx.dataflow.DataLoader``.
    """

    def __init__(self, edge_index: Union[Tensor, SparseGraph],
                 sample_lists: List[int], node_idx: Optional[Tensor] = None,
                 num_nodes: Optional[int] = None, return_e_id: bool = True,
                 transform: Optional[Callable] = None, **kwargs):
        self.edge_index = edge_index
        self.node_idx = node_idx
        self.num_nodes = num_nodes
        self.sizes = sample_lists
        self.return_e_id = return_e_id
        self.transform = transform
        self.is_sparse_graph = isinstance(edge_index, SparseGraph)
        # Original values of a SparseGraph input, stashed while its value slot
        # is overwritten with edge ids; gathered back per hop in `sample`.
        # (Attribute name kept as-is for backward compatibility.)
        self.__val__ = None

        if not self.is_sparse_graph:
            if num_nodes is None:
                if node_idx is None:
                    num_nodes = int(tlx.reduce_max(edge_index)) + 1
                elif node_idx.dtype == tlx.bool:
                    # A boolean mask carries one entry per node.
                    num_nodes = node_idx.shape[0]
                else:
                    # Integer seed ids (of any integer dtype) may exceed every
                    # id in edge_index, e.g. isolated nodes; cover both.
                    num_nodes = max(int(tlx.reduce_max(edge_index)),
                                    int(tlx.reduce_max(node_idx))) + 1

            # Edge ids are the column positions in edge_index. They are stored
            # as the sparse values so `sample` can read them back from each
            # sampled hop via `storage.value()`.
            value = (tlx.arange(start=0, limit=tlx.get_tensor_shape(edge_index)[1])
                     if return_e_id else None)
            # Build from (row, col) and keep the transpose (`.t()`).
            self.adj_t = SparseGraph(row=edge_index[0], col=edge_index[1],
                                     value=value,
                                     sparse_sizes=(num_nodes, num_nodes)).t()
        else:
            adj_t = edge_index
            if return_e_id:
                # Replace the values with edge ids; keep the originals so the
                # sampled hops can be given their real values back.
                self.__val__ = adj_t.storage.value()
                value = tlx.arange(start=0, limit=adj_t.nnz())
                adj_t = adj_t.set_value(value, layout='coo')
            self.adj_t = adj_t

        # Result discarded: called only for its side effect, presumably to
        # compute/cache the CSR row pointer once before sampling. TODO confirm.
        self.adj_t.storage.rowptr()

        if node_idx is None:
            node_idx = tlx.arange(start=0, limit=self.adj_t.sparse_size(0))
        elif node_idx.dtype == tlx.bool:
            # flatnonzero yields flat positions, so a (N, 1) column mask maps
            # to node ids too (reshaping `nonzero()` would append its
            # column coordinates).
            node_idx = tlx.convert_to_tensor(
                np.flatnonzero(tlx.convert_to_numpy(node_idx)))

        super().__init__(to_list(tlx.reshape(node_idx, (-1,))),
                         collate_fn=self.sample, **kwargs)

    def sample(self, batch):
        """Sample the multi-hop neighborhood of one batch of seed nodes.

        Args:
            batch: Seed node ids; presumably the collated list of ids from the
                DataLoader. Non-Tensor input is converted with
                ``all_to_tensor``.

        Returns:
            ``(batch, n_id, adjs)``, or ``transform(batch, n_id, adjs)`` when a
            transform was given. ``n_id`` is the node-id tensor returned by the
            last ``sample_adj`` call. ``adjs`` holds one ``EdgeIndex`` (edge-index
            input) or ``Adj`` (SparseGraph input) per hop. It is a bare object
            when there is exactly one hop, otherwise a list in reverse sampling
            order (last-sampled hop first).
        """
        if not isinstance(batch, Tensor):
            batch = all_to_tensor(batch)

        adjs = []
        n_id = batch
        for num_samples in self.sizes:
            # Each hop starts from the node set produced by the previous hop.
            adj_t, n_id = self.adj_t.sample_adj(n_id, num_samples, replace=False)
            e_id = adj_t.storage.value()
            size = adj_t.sparse_sizes()[::-1]
            if self.__val__ is not None:
                # Values currently hold edge ids; restore the original values.
                adj_t.set_value_(tlx.gather(self.__val__, e_id), layout='coo')

            if self.is_sparse_graph:
                adjs.append(Adj(adj_t, e_id, size))
            else:
                row, col, _ = adj_t.coo()
                # adj_t is the transpose, so emit its edges as [col; row].
                edge_index = tlx.stack([col, row], axis=0)
                adjs.append(EdgeIndex(edge_index, e_id, size))

        adjs = adjs[0] if len(adjs) == 1 else adjs[::-1]
        out = (batch, n_id, adjs)
        return self.transform(*out) if self.transform is not None else out

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(sizes={self.sizes})'