Source code for gammagl.layers.conv.edgeconv

import tensorlayerx as tlx

from gammagl.layers.conv import MessagePassing


class EdgeConv(MessagePassing):
    r"""Edge convolution operator introduced in the `"Dynamic Graph CNN for
    Learning on Point Clouds" <https://arxiv.org/pdf/1801.07829.pdf>`_ paper

    .. math::
        \mathbf{x}^{(k)}_i = \max_{j \in N(i)} h_\Theta(\mathbf{x}_i^{(k-1)},
        \mathbf{x}_j^{(k-1)} - \mathbf{x}_i^{(k-1)})

    where :math:`\mathbf{x}^{(k)}_i` is the k-th layer's feature vector of
    node i, and :math:`h_\Theta` is a multilayer perceptron.

    Parameters
    ----------
    nn: tlx.nn.Module
        A neural network :math:`h_{\mathbf{\Theta}}` that maps pair-wise
        concatenated node features :obj:`x` of shape
        :obj:`[-1, 2 * in_channels]` to shape :obj:`[-1, out_channels]`,
        *e.g.*, defined by :class:`tlx.nn.Sequential`.
    aggr: str, optional
        The aggregation scheme to use
        (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
        Stored on the layer as ``self.aggr``.
        (default: :obj:`"max"`)
    **kwargs: optional
        Additional arguments of :class:`gammagl.layers.conv.MessagePassing`.

    """

    def __init__(self, nn, aggr='max', **kwargs):
        super().__init__(**kwargs)
        # Keep the aggregation name and the edge network on the instance.
        self.aggr = aggr
        self.nn = nn

    def forward(self, x, edge_index):
        """Run one edge-convolution step.

        Parameters
        ----------
        x: tensor or tuple
            Node features. Anything that is not already a tuple is
            duplicated into the pair ``(x, x)``.
        edge_index: tensor
            Graph connectivity, forwarded unchanged to ``propagate``.

        Returns
        -------
        The value returned by ``self.propagate``.

        """
        pair = x if isinstance(x, tuple) else (x, x)
        # The node count is the leading dimension of the first pair element.
        node_count = int(tlx.get_tensor_shape(pair[0])[0])
        # propagate_type: (x: PairTensor)
        # NOTE(review): ``aggr`` is not passed explicitly here; presumably the
        # base class reads ``self.aggr`` -- confirm against MessagePassing.
        return self.propagate(
            x=pair,
            edge_index=edge_index,
            num_nodes=node_count,
        )

    def message(self, x):
        """Build the message ``nn([x_i, x_j - x_i])``.

        Parameters
        ----------
        x: sequence
            Indexed as a pair: ``x[0]`` is used as :math:`x_i` and ``x[1]``
            as :math:`x_j`.
            NOTE(review): assumes the base class supplies features already
            aligned per edge -- confirm against MessagePassing.propagate.

        Returns
        -------
        Output of ``self.nn`` applied to the concatenation, along the last
        axis, of :math:`x_i` and :math:`x_j - x_i`.

        """
        center = x[0]
        neighbor = x[1]
        edge_features = tlx.concat([center, neighbor - center], axis=-1)
        return self.nn(edge_features)