Source code for gammagl.utils.sort_edge_index

import tensorlayerx as tlx
from .num_nodes import maybe_num_nodes


def sort_edge_index(edge_index, edge_attr=None, num_nodes=None, sort_by_row: bool = True):
    r"""Sort :obj:`edge_index` row-wise (or column-wise).

    Edges are ordered lexicographically by ``(row, col)`` when
    :obj:`sort_by_row` is :obj:`True`, and by ``(col, row)`` otherwise,
    where ``row = edge_index[0]`` and ``col = edge_index[1]``. Any edge
    features passed via :obj:`edge_attr` are permuted with the same
    permutation so they stay aligned with their edges. Duplicate edges are
    kept; this function only reorders.

    Parameters
    ----------
    edge_index: tensor
        The edge indices.
    edge_attr: tensor, list[tensor], optional
        Edge weights or multi-dimensional edge features. If given as a
        list, every entry is permuted along its first axis.
        (default: :obj:`None`)
    num_nodes: int, optional
        The number of nodes, *i.e.* :obj:`max_val + 1` of
        :attr:`edge_index`. (default: :obj:`None`)
    sort_by_row: bool, optional
        If set to :obj:`False`, will sort :obj:`edge_index` column-wise.
        (default: :obj:`True`)

    Returns
    -------
    :class:`LongTensor` if :attr:`edge_attr` is :obj:`None`, else
    (:class:`LongTensor`, :obj:`Tensor` or :obj:`List[Tensor]`)
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    # Encode every edge as one integer key ``primary * num_nodes + secondary``.
    # Because ``secondary < num_nodes``, sorting this key is equivalent to a
    # lexicographic sort on (primary, secondary).
    primary = edge_index[1 - int(sort_by_row)]
    secondary = edge_index[int(sort_by_row)]
    # NOTE(review): the key can reach num_nodes**2 - 1; with an int32
    # edge_index this overflows once num_nodes exceeds 46340 -- confirm that
    # callers pass int64 indices for large graphs.
    idx = primary * num_nodes + secondary

    perm = tlx.ops.argsort(idx, descending=False)

    edge_index = tlx.gather(edge_index, indices=perm, axis=1)

    if edge_attr is None:
        return edge_index
    if tlx.is_tensor(edge_attr):
        return edge_index, tlx.gather(edge_attr, indices=perm, axis=0)
    # Otherwise treat edge_attr as an iterable of per-edge tensors.
    return edge_index, [tlx.gather(e, indices=perm, axis=0) for e in edge_attr]