Source code for gammagl.layers.conv.simplehgn_conv

import tensorlayerx as tlx
from gammagl.layers.conv import MessagePassing
from gammagl.utils import segment_softmax
from gammagl.mpops import *


[docs] class SimpleHGNConv(MessagePassing): r'''The SimpleHGN layer from the `"Are we really making much progress? Revisiting, benchmarking, and refining heterogeneous graph neural networks" <https://dl.acm.org/doi/pdf/10.1145/3447548.3467350>`_ paper The model extend the original graph attention mechanism in GAT by including edge type information into attention calculation. Calculating the coefficient: .. math:: \alpha_{ij} = \frac{exp(LeakyReLU(a^T[Wh_i||Wh_j||W_r r_{\psi(<i,j>)}]))}{\Sigma_{k\in\mathcal{E}}{exp(LeakyReLU(a^T[Wh_i||Wh_k||W_r r_{\psi(<i,k>)}]))}} (1) Residual connection including Node residual: .. math:: h_i^{(l)} = \sigma(\Sigma_{j\in \mathcal{N}_i} {\alpha_{ij}^{(l)}W^{(l)}h_j^{(l-1)}} + h_i^{(l-1)}) (2) and Edge residual: .. math:: \alpha_{ij}^{(l)} = (1-\beta)\alpha_{ij}^{(l)}+\beta\alpha_{ij}^{(l-1)} (3) Multi-heads: .. math:: h^{(l+1)}_j = \parallel^M_{m = 1}h^{(l + 1, m)}_j (4) Residual: .. math:: h^{(l+1)}_j = h^{(l)}_j + \parallel^M_{m = 1}h^{(l + 1, m)}_j (5) Parameters ---------- in_feats: int the input dimension out_feats: int the output dimension num_etypes: int the number of the edge type edge_feats: int the edge dimension heads: int, optional the number of heads in this layer negative_slope: float, optional the negative slope used in the LeakyReLU feat_drop: float, optional the feature drop rate attn_drop: float, optional the attention score drop rate residual: bool, optional whether we need the residual operation activation:, optional the activation function bias: bool, optional whether we need the bias beta: float, optional the hyperparameter used in edge residual ''' def __init__(self, in_feats, out_feats, num_etypes, edge_feats, heads=1, negative_slope=0.2, feat_drop=0., attn_drop=0., residual=False, activation=None, bias=False, beta=0.,): super().__init__() self.in_feats = in_feats self.out_feats = out_feats self.edge_feats = edge_feats self.heads = heads self.out_feats = out_feats self.edge_embedding = tlx.nn.Embedding(num_etypes, edge_feats) self.fc_node = tlx.nn.Linear(out_feats * heads, in_features=in_feats, b_init=None, W_init=tlx.initializers.XavierNormal(gain=1.414)) self.fc_edge = tlx.nn.Linear(edge_feats * heads, in_features=edge_feats, b_init=None, W_init=tlx.initializers.XavierNormal(gain=1.414)) self.attn_src = self._get_weights('attn_l', shape=(1, heads, out_feats), init=tlx.initializers.XavierNormal(gain=1.414), order=True) self.attn_dst = self._get_weights('attn_r', shape=(1, heads, out_feats), init=tlx.initializers.XavierNormal(gain=1.414), order=True) self.attn_edge = self._get_weights('attn_e', shape=(1, heads, edge_feats), init=tlx.initializers.XavierNormal(gain=1.414), order=True) self.feat_drop = tlx.nn.Dropout(feat_drop) self.attn_drop = tlx.nn.Dropout(attn_drop) self.leaky_relu = tlx.nn.LeakyReLU(negative_slope) self.fc_res = tlx.nn.Linear(heads * out_feats, in_features=in_feats, b_init=None, W_init=tlx.initializers.XavierNormal(gain=1.414)) if residual else None self.activation = activation self.bias = self._get_weights("bias", (1, heads, out_feats)) if bias else None self.beta = beta
[docs] def message(self, x, edge_index, edge_feat, num_nodes, res_alpha=None): x_new = self.fc_node(x) x_new = tlx.ops.reshape(x_new, shape=[-1, self.heads, self.out_feats]) x_new = self.feat_drop(x_new) edge_feat = self.edge_embedding(edge_feat) edge_feat = self.fc_edge(edge_feat) edge_feat = tlx.ops.reshape(edge_feat, [-1, self.heads, self.edge_feats]) #calculate the alpha node_src = edge_index[0, :] node_dst = edge_index[1, :] weight_src = tlx.ops.gather(tlx.reduce_sum(x_new * self.attn_src, -1), node_src) weight_dst = tlx.ops.gather(tlx.reduce_sum(x_new * self.attn_dst, -1), node_dst) weight_edge = tlx.reduce_sum(edge_feat * self.attn_edge, -1) weight = self.leaky_relu(weight_src + weight_dst + weight_edge) alpha = self.attn_drop(segment_softmax(weight, node_dst, num_nodes)) #edge residual if res_alpha is not None: alpha = alpha * (1 - self.beta) + res_alpha * self.beta rst = tlx.ops.gather(x_new, node_src) * tlx.ops.expand_dims(alpha, axis=-1) rst = unsorted_segment_sum(rst, node_dst, num_nodes) #node residual if self.fc_res is not None: res_val = self.fc_res(x) res_val = tlx.ops.reshape(res_val, shape=[x.shape[0], -1, self.out_feats]) rst = rst + res_val if self.bias is not None: rst = rst + self.bias if self.activation is not None: rst = self.activation(rst) x = rst return x, alpha
[docs] def propagate(self, x, edge_index, aggr='sum', **kwargs): """ Function that perform message passing. Parameters ---------- x: input node feature. edge_index: edges from src to dst. aggr: aggregation type, default='sum', optional=['sum', 'mean', 'max']. kwargs: other parameters dict. """ if 'num_nodes' not in kwargs.keys() or kwargs['num_nodes'] is None: kwargs['num_nodes'] = x.shape[0] coll_dict = self.__collect__(x, edge_index, aggr, kwargs) msg_kwargs = self.inspector.distribute('message', coll_dict) x, alpha = self.message(**msg_kwargs) x = self.update(x) return x, alpha
[docs] def forward(self, x, edge_index, edge_feat, res_attn=None): return self.propagate(x, edge_index, edge_feat=edge_feat)