from typing import Optional, Tuple, Union
import tensorlayerx as tlx
import numpy as np
from gammagl.utils.check import check_is_numpy
from .num_nodes import maybe_num_nodes
def contains_self_loops(edge_index) -> bool:
    r"""Returns :obj:`True` if the graph given by :attr:`edge_index` contains
    self-loops.

    Parameters
    ----------
    edge_index: tensor, np.ndarray
        The edge indices, of shape :obj:`[2, num_edges]`; a self-loop is a
        column whose two rows hold the same node index.

    Returns
    -------
    bool
        :obj:`True` if at least one edge is a self-loop, :obj:`False`
        otherwise (including for a graph without edges).
    """
    mask = edge_index[0] == edge_index[1]
    if isinstance(mask, np.ndarray):
        # NumPy input: reduce with NumPy, since backend ``any`` ops may not
        # accept ndarrays.
        return bool(np.any(mask))
    # Convert the reduced scalar tensor to a Python bool so the result
    # matches the annotated return type.
    return bool(tlx.any(mask, axis=0))
def remove_self_loops(edge_index, edge_attr=None):
    r"""Removes every self-loop in the graph given by :attr:`edge_index`, so
    that :math:`(i,i) \not\in \mathcal{E}` for every :math:`i \in \mathcal{V}`.

    Parameters
    ----------
    edge_index: tensor, np.ndarray
        The edge indices, of shape :obj:`[2, num_edges]`.
    edge_attr: tensor, np.ndarray, optional
        Edge weights or multi-dimensional edge features, one entry per edge.
        (default: :obj:`None`)

    Returns
    -------
    edge_index, edge_attr
        ``edge_index`` with every self-loop column dropped. A tensor input is
        returned as an :obj:`int64` tensor; an :class:`np.ndarray` input
        stays an :class:`np.ndarray`.
        ``edge_attr`` with the entries of the dropped edges removed (an
        :class:`np.ndarray` stays an :class:`np.ndarray` when
        ``edge_index`` is one too), or :obj:`None` if no ``edge_attr`` was
        given.
    """
    # True for every edge that is NOT a self-loop, i.e. the edges to keep.
    mask = edge_index[0] != edge_index[1]
    if tlx.is_tensor(edge_index):
        edge_index = tlx.mask_select(edge_index, mask, axis=1)
        edge_index = tlx.cast(edge_index, dtype=tlx.int64)
    elif check_is_numpy(edge_index):
        edge_index = edge_index[:, mask]

    if edge_attr is None:
        return edge_index, None

    if isinstance(edge_attr, np.ndarray) and isinstance(mask, np.ndarray):
        # Keep NumPy inputs in NumPy: routing them through a backend op
        # could return a backend tensor (or fail) instead of an ndarray.
        edge_attr = edge_attr[mask]
    else:
        edge_attr = tlx.mask_select(edge_attr, mask)
    return edge_index, edge_attr
def _self_loop_attr(edge_attr, fill_value, num_loops):
    r"""Build the features of ``num_loops`` self-loop edges so that they can
    be concatenated after ``edge_attr`` along axis 0 (same trailing shape
    and same dtype as ``edge_attr``)."""
    if edge_attr.ndim > 1:
        # Paddle is special-cased as in the rest of this module; other
        # backends report the shape through ``tlx.get_tensor_shape``.
        if tlx.BACKEND in ['paddle']:
            trailing = list(edge_attr.shape[1:])
        else:
            trailing = list(tlx.get_tensor_shape(edge_attr)[1:])
        shape = [num_loops] + trailing
    else:
        shape = (num_loops,)

    if fill_value is None:
        return tlx.ones(shape, dtype=edge_attr.dtype)
    if isinstance(fill_value, (int, float)):
        return tlx.constant(value=fill_value, shape=shape, dtype=edge_attr.dtype)
    if tlx.is_tensor(fill_value):
        # Broadcast the fill (scalar, [D], [1, D], ...) to one row per loop.
        # ``.copy()`` turns the read-only broadcast view into a real array
        # before handing it to the backend.
        fill = tlx.convert_to_numpy(fill_value)
        loop_attr = np.broadcast_to(fill, shape).copy()
        # Use edge_attr's dtype, like the other branches, so the following
        # concat never mixes dtypes.
        return tlx.convert_to_tensor(loop_attr, dtype=edge_attr.dtype)
    if isinstance(fill_value, str):
        # TODO: aggregate incoming edge features per node with the given
        # reduce operation ("add", "mean", "min", "max", "mul").
        raise NotImplementedError(
            "fill_value given as a reduce operation string is not supported yet")
    raise AttributeError("No valid 'fill_value' provided")


def add_self_loops(
        edge_index, edge_attr=None, n_loops=1,
        fill_value: Union[float, str] = None,
        num_nodes: Optional[int] = None):
    r"""Adds a self-loop :math:`(i,i) \in \mathcal{E}` to every node
    :math:`i \in \mathcal{V}` in the graph given by :attr:`edge_index`.
    In case the graph is weighted or has multi-dimensional edge features
    (:obj:`edge_attr != None`), edge features of self-loops will be added
    according to :obj:`fill_value`.

    .. code:: python

        >>> import tensorlayerx as tlx
        >>> from gammagl.utils.loop import add_self_loops
        >>> edge_index = tlx.constant([[0, 0, 0], [1, 2, 3]])
        >>> edge_index, _ = add_self_loops(edge_index)
        >>> tlx.convert_to_numpy(edge_index)
        array([[0, 0, 0, 0, 1, 2, 3],
               [1, 2, 3, 0, 1, 2, 3]])

    Parameters
    ----------
    edge_index: tensor
        The edge indices, of shape :obj:`[2, num_edges]`.
    edge_attr: tensor, optional
        Edge weights or multi-dimensional edge
        features. (default: :obj:`None`)
    n_loops: int, optional
        The number of self-loops added to every node. The loops of the same
        node are stored next to each other. (default: :obj:`1`)
    fill_value: float, tensor, str, optional
        The way to generate edge features of self-loops (in case
        :obj:`edge_attr != None`). The features of all
        :obj:`num_nodes * n_loops` self-loops have shape
        :obj:`[num_nodes * n_loops, *edge_attr.shape[1:]]` and the dtype of
        :obj:`edge_attr`.
        If :obj:`None`, they are filled with ones.
        If given as :obj:`int` or :obj:`float`, every entry is set to
        :obj:`fill_value`.
        If given as a tensor, it is broadcast (NumPy rules) to that shape,
        e.g. a tensor of shape :obj:`[D]` becomes the feature of every loop.
        Reduce operations given as :obj:`str` (:obj:`"add"`, :obj:`"mean"`,
        :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`) are not implemented yet.
        (default: :obj:`None`)
    num_nodes: int, optional
        The number of nodes, *i.e.* :obj:`max_val + 1` of
        :attr:`edge_index`. If :obj:`None`, it is obtained from
        :obj:`maybe_num_nodes`. (default: :obj:`None`)

    Returns
    -------
    edge_index, edge_attr
        The edge indices with the self-loop columns appended after the
        original edges, and the edge features with the matching self-loop
        features appended along axis 0 (:obj:`None` if :obj:`edge_attr`
        was :obj:`None`).

    Raises
    ------
    NotImplementedError
        If :obj:`fill_value` is a :obj:`str` and :obj:`edge_attr` is given.
    AttributeError
        If :obj:`fill_value` has an unsupported type and :obj:`edge_attr`
        is given.
    """
    N = int(maybe_num_nodes(edge_index, num_nodes))
    # Every node gets n_loops loops, so the features below must have
    # N * n_loops rows to stay aligned with loop_index.
    num_loops = N * n_loops

    # Node ids 0..N-1, each repeated n_loops times ([0, 0, 1, 1, ...]).
    loop_index = tlx.convert_to_tensor(np.arange(N).repeat(n_loops), dtype=edge_index.dtype)
    loop_index = tlx.stack([loop_index, loop_index])

    if edge_attr is not None:
        loop_attr = _self_loop_attr(edge_attr, fill_value, num_loops)
        edge_attr = tlx.concat([edge_attr, loop_attr], axis=0)

    edge_index = tlx.concat([edge_index, loop_index], axis=1)
    return edge_index, edge_attr