Source code for gammagl.layers.conv.sage_conv

import tensorlayerx as tlx
from gammagl.layers.conv import MessagePassing
from gammagl.utils import add_self_loops, calc_gcn_norm


class SAGEConv(MessagePassing):
    r"""The GraphSAGE operator from the `"Inductive Representation Learning on
    Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot
        \mathrm{mean}_{j \in \mathcal{N}(i)} \mathbf{x}_j

    Parameters
    ----------
    in_channels: int
        Size of each input sample. The value is forwarded unchanged as
        ``in_features`` of the internal linear layers and as the input/hidden
        size of the LSTM aggregator.
        NOTE(review): earlier documentation also mentioned ``-1`` (lazy size)
        and a ``(src, dst)`` tuple; whether those work depends on
        ``tlx.nn.Linear`` / ``tlx.nn.LSTM`` -- confirm before relying on it.
    out_channels: int
        Size of each output sample.
    activation: callable, None, optional
        If not None, applied to the updated node features before they are
        returned. (default: :obj:`None`)
    aggr: str, optional
        Aggregator type to use: ``mean``, ``gcn``, ``pool`` or ``lstm``.
        (default: ``mean``)
    add_bias: bool, optional
        If set to :obj:`False`, the layer will not learn an additive bias.
        (default: :obj:`True`)

    """

    def __init__(self, in_channels, out_channels, activation=None, aggr="mean", add_bias=True):
        super(SAGEConv, self).__init__()

        # forward() dispatches on self.aggr, so the aggregator name has to be
        # stored on the instance.
        self.aggr = aggr
        self.act = activation
        self.in_feat = in_channels

        # relu use he_normal
        initor = tlx.initializers.he_normal()

        # Projection applied to neighbor information (used by every aggregator).
        self.fc_neigh = tlx.nn.Linear(in_features=in_channels, out_features=out_channels,
                                      W_init=initor, b_init=None)
        # The 'gcn' aggregator adds self loops instead of using a separate
        # projection of the destination node's own features.
        if aggr != 'gcn':
            self.fc_self = tlx.nn.Linear(in_features=in_channels, out_features=out_channels,
                                         W_init=initor, b_init=None)
        if aggr == "lstm":
            self.lstm = tlx.nn.LSTM(input_size=in_channels, hidden_size=in_channels, batch_first=True)
        if aggr == "pool":
            self.pool = tlx.nn.Linear(in_features=in_channels, out_features=in_channels,
                                      W_init=initor, b_init=None)

        self.add_bias = add_bias
        if add_bias:
            init = tlx.initializers.zeros()
            self.bias = self._get_weights("bias", shape=(1, out_channels), init=init)

    def forward(self, feat, edge):
        r"""
        Compute GraphSAGE layer.

        Parameters
        ----------
        feat : Tensor or pair of Tensor
            Either a single feature tensor used for both source and
            destination nodes, or a tuple whose first two items are the
            source features of shape :math:`(N_{in}, D_{in_{src}})` and the
            destination features of shape :math:`(N_{out}, D_{in_{dst}})`.
        edge : Tensor
            Edge index handed to ``propagate``. The ``gcn`` aggregator reads
            ``edge[0]`` to derive the node count (presumably the tensor has
            shape :math:`(2, N_{edges})` -- confirm against the caller).

        Returns
        -------
        Tensor
            The output feature of shape :math:`(N_{dst}, D_{out})`
            where :math:`N_{dst}` is the number of destination nodes in the
            input graph, :math:`D_{out}` is size of output feature.

        Raises
        ------
        ValueError
            If the layer was configured with an aggregator other than
            ``mean``, ``gcn``, ``pool`` or ``lstm``.

        """
        if isinstance(feat, tuple):
            src_feat = feat[0]
            dst_feat = feat[1]
        else:
            src_feat = feat
            dst_feat = feat
        num_nodes = int(dst_feat.shape[0])

        if self.aggr == 'mean':
            src_feat = self.fc_neigh(src_feat)
            out = self.propagate(src_feat, edge, edge_weight=None, num_nodes=num_nodes, aggr='mean')
        elif self.aggr == 'gcn':
            src_feat = self.fc_neigh(src_feat)
            edge, _ = add_self_loops(edge)
            # Node count implied by the (self-loop augmented) edge index;
            # computed once and shared by the normalization and propagation.
            num_graph_nodes = int(1 + tlx.reduce_max(edge[0]))
            weight = calc_gcn_norm(edge, num_graph_nodes)
            out = self.propagate(src_feat, edge, edge_weight=weight, num_nodes=num_graph_nodes, aggr='sum')
            # Keep only the rows that belong to the destination nodes.
            out = tlx.gather(out, tlx.arange(0, num_nodes))
        elif self.aggr == 'pool':
            src_feat = tlx.nn.ReLU()(self.pool(src_feat))
            out = self.propagate(src_feat, edge, edge_weight=None, num_nodes=num_nodes, aggr='max')
            out = self.fc_neigh(out)
        elif self.aggr == 'lstm':
            # Regroup source rows as one sequence per destination node.
            # NOTE(review): this assumes the source rows are laid out as equal
            # sized, contiguous neighbor groups per destination -- confirm.
            src_feat = tlx.reshape(src_feat, (dst_feat.shape[0], -1, src_feat.shape[1]))
            size = dst_feat.shape[0]
            h = (tlx.zeros((1, size, self.in_feat)),
                 tlx.zeros((1, size, self.in_feat)))
            _, (rst, _) = self.lstm(src_feat, h)
            out = self.fc_neigh(rst[0])
        else:
            # Without this branch `out` would be unbound below and the failure
            # would surface as a confusing UnboundLocalError.
            raise ValueError(
                "Unsupported aggregator type {!r}; expected one of "
                "'mean', 'gcn', 'pool', 'lstm'.".format(self.aggr)
            )

        if self.aggr != 'gcn':
            out += self.fc_self(dst_feat)
        if self.add_bias:
            out += self.bias

        if self.act is not None:
            out = self.act(out)
        return out