Source code for gammagl.layers.conv.hgt_conv

import math
import tensorlayerx as tlx
from gammagl.layers.conv import MessagePassing
from tensorlayerx.nn import ModuleDict, Linear, Parameter, ParameterDict
from gammagl.utils import segment_softmax


class HGTConv(MessagePassing):
    r"""The Heterogeneous Graph Transformer (HGT) operator from the
    `"Heterogeneous Graph Transformer" <https://arxiv.org/abs/2003.01332>`_
    paper.

    Parameters
    ----------
    in_channels: int, dict[str, int]
        Size of each input sample of every node type, or :obj:`-1` to derive
        the size from the first input(s) to the forward method. A single int
        is applied to every node type listed in ``metadata[0]``.
        NOTE(review): the value is forwarded unchanged to
        ``Linear(in_features=...)``; confirm that the backend layer really
        supports ``-1`` before relying on lazy size inference.
    out_channels: int
        Size of each output sample. Must be divisible by ``heads``.
    metadata: tuple[list[str], list[tuple[str, str, str]]]
        The metadata of the heterogeneous graph, *i.e.* its node and edge
        types given by a list of strings and a list of string triplets,
        respectively. See :class:`gammagl.data.HeteroGraph.metadata` for
        more information.
    heads: int, optional
        Number of multi-head-attentions.
        (default: :obj:`1`)
    group: str, optional
        The aggregation scheme to use for grouping node embeddings generated
        by different relations.
        (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`).
        (default: :obj:`"sum"`)
    dropout_rate: float, optional
        Value handed to the dropout layer that is applied to the attention
        coefficients.
        (default: :obj:`0.`)

    Raises
    ------
    ValueError
        If ``out_channels`` is not divisible by ``heads`` or if ``group`` is
        not one of the supported aggregation schemes.

    """

    # Aggregation schemes accepted for the ``group`` argument.
    _SUPPORTED_GROUPS = ("sum", "mean", "min", "max")

    def __init__(
        self,
        in_channels,
        out_channels,
        metadata,
        heads: int = 1,
        group: str = "sum",
        dropout_rate=0.,
    ):
        super().__init__()

        # Fail early: forward() reshapes to (-1, heads, out_channels // heads),
        # which cannot work when the division leaves a remainder.
        if out_channels % heads != 0:
            raise ValueError(
                "'out_channels' (got {}) must be divisible by the number of "
                "heads (got {})".format(out_channels, heads)
            )
        if group not in self._SUPPORTED_GROUPS:
            raise ValueError(
                "'group' must be one of {} (got {!r})".format(
                    self._SUPPORTED_GROUPS, group
                )
            )

        if not isinstance(in_channels, dict):
            in_channels = {node_type: in_channels for node_type in metadata[0]}

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.group = group

        # Per-node-type key / query / value / output projections and the
        # learnable gate of the skip connection.
        self.k_lin = ModuleDict()
        self.q_lin = ModuleDict()
        self.v_lin = ModuleDict()
        self.a_lin = ModuleDict()
        self.skip = ParameterDict()
        self.dropout_rate = dropout_rate
        self.dropout = tlx.layers.Dropout(self.dropout_rate)

        for node_type, node_in_channels in self.in_channels.items():
            self.k_lin[node_type] = Linear(in_features=node_in_channels, out_features=out_channels, act='relu6')
            self.q_lin[node_type] = Linear(in_features=node_in_channels, out_features=out_channels, act='relu6')
            self.v_lin[node_type] = Linear(in_features=node_in_channels, out_features=out_channels, act='relu6')
            self.a_lin[node_type] = Linear(in_features=out_channels, out_features=out_channels, act='relu6')
            self.skip[node_type] = Parameter(data=tlx.ops.convert_to_tensor(1.0))

        # Per-edge-type parameters: attention matrix (a), message matrix (m)
        # and per-head scaling of the attention logits (p).
        self.a_rel = ParameterDict()
        self.m_rel = ParameterDict()
        self.p_rel = ParameterDict()

        dim = out_channels // heads
        init_a = tlx.initializers.TruncatedNormal()
        init_m = tlx.initializers.TruncatedNormal()
        for edge_type in metadata[1]:
            # Edge types are (src, relation, dst) triplets; parameter
            # containers are keyed by the '__'-joined string.
            edge_type = '__'.join(edge_type)
            self.a_rel[edge_type + 'a'] = self._get_weights(edge_type + 'a',
                                                            shape=(heads, dim, dim),
                                                            init=init_a,
                                                            order=True)
            self.m_rel[edge_type + 'm'] = self._get_weights(edge_type + 'm',
                                                            shape=(heads, dim, dim),
                                                            init=init_m,
                                                            order=True)
            self.p_rel[edge_type] = Parameter(tlx.ones(shape=(heads,)))

    def _group(self, outs):
        """Combine the per-relation outputs of one node type.

        Parameters
        ----------
        outs: list
            Tensors produced by :meth:`propagate`, one per relation that
            targets the node type.

        Returns
        -------
        tensor or None
            The element-wise reduction selected by ``self.group``, or
            :obj:`None` when ``outs`` is empty.

        """
        if not outs:
            return None
        if len(outs) == 1:
            # Every supported reduction over a single element is the identity.
            return outs[0]
        reducers = {
            "sum": tlx.ops.reduce_sum,
            "mean": tlx.ops.reduce_mean,
            "min": tlx.ops.reduce_min,
            "max": tlx.ops.reduce_max,
        }
        stacked = tlx.stack(outs)
        return reducers[self.group](stacked, axis=0, keepdims=False)

    def forward(self, x_dict, edge_index_dict):
        """Run one HGT layer over a heterogeneous graph.

        Parameters
        ----------
        x_dict: dict
            Node features keyed by node type.
        edge_index_dict: dict
            Edge indices keyed by ``(src_type, relation, dst_type)`` triplets.
            Row 0 of each edge index is used as the source index and row 1
            as the target index.

        Returns
        -------
        dict
            Output node features keyed by node type. A node type that is
            present in ``x_dict`` but is not the destination of any edge type
            in ``edge_index_dict`` maps to :obj:`None`.

        """
        H, D = self.heads, self.out_channels // self.heads

        k_dict, q_dict, v_dict, out_dict = {}, {}, {}, {}

        # Iterate over node-types: project the inputs and split into heads.
        for node_type, x in x_dict.items():
            k_dict[node_type] = tlx.ops.reshape(self.k_lin[node_type](x), (-1, H, D))
            q_dict[node_type] = tlx.ops.reshape(self.q_lin[node_type](x), (-1, H, D))
            v_dict[node_type] = tlx.ops.reshape(self.v_lin[node_type](x), (-1, H, D))
            out_dict[node_type] = []

        # Swaps the first two axes so the per-head (H, D, D) relation
        # matrices can be applied with a batched matmul. Built once, since it
        # does not depend on the edge type.
        transpose = tlx.nn.Transpose([1, 0, 2])

        # Iterate over edge-types:
        for edge_type, edge_index in edge_index_dict.items():
            src_type, _, dst_type = edge_type
            edge_type = '__'.join(edge_type)

            a_rel = self.a_rel[edge_type + 'a']
            k = transpose((transpose(k_dict[src_type]) @ a_rel))

            m_rel = self.m_rel[edge_type + 'm']
            v = transpose((transpose(v_dict[src_type]) @ m_rel))

            if tlx.BACKEND != 'tensorflow':
                edge_index = tlx.ops.convert_to_tensor(edge_index, dtype='int64')
            source_index, target_index = edge_index[0], edge_index[1]

            q_i = tlx.gather(q_dict[dst_type], target_index, axis=0)
            v_j = tlx.gather(v, source_index, axis=0)
            k_j = tlx.gather(k, source_index, axis=0)
            rel = self.p_rel[edge_type]
            out = self.propagate(edge_index=edge_index, aggr='sum', q_i=q_i, k_j=k_j, v_j=v_j,
                                 rel=rel, num_nodes=x_dict[dst_type].shape[0])
            out_dict[dst_type].append(out)

        # Iterate over node-types: merge relations, project, gated skip.
        for node_type, outs in out_dict.items():
            out = self._group(outs)
            if out is None:
                # No relation targets this node type; nothing to aggregate.
                out_dict[node_type] = None
                continue

            out = self.a_lin[node_type](out)
            x = x_dict[node_type]
            # The residual can only be added when input and output widths
            # agree; otherwise the projected output is returned on its own.
            if out.shape[-1] == x.shape[-1]:
                alpha = tlx.ops.sigmoid(self.skip[node_type])
                out = alpha * out + (1 - alpha) * x
            out_dict[node_type] = out

        return out_dict

    def propagate(self, edge_index, aggr='sum', **kwargs):
        """Compute messages, aggregate them per target node and update.

        Parameters
        ----------
        edge_index:
            Edge index whose row 1 is used as the target index.
        aggr: str, optional
            Aggregation scheme forwarded to ``self.aggregate``.
            (default: :obj:`"sum"`)
        **kwargs:
            Arguments made available to :meth:`message`. Must contain
            ``num_nodes``.

        Returns
        -------
        tensor
            The result of ``self.update`` applied to the aggregated messages.

        Raises
        ------
        KeyError
            If ``num_nodes`` is not supplied in ``kwargs``.

        """
        coll_dict = dict(kwargs)
        coll_dict['edge_index'] = edge_index
        coll_dict['aggr'] = aggr
        coll_dict['target_index'] = edge_index[1]

        # Presumably picks the entries matching message()'s signature;
        # the inspector is provided by the MessagePassing base class.
        msg_kwargs = self.inspector.distribute('message', coll_dict)
        msg = self.message(**msg_kwargs)
        x = self.aggregate(msg, edge_index, num_nodes=kwargs['num_nodes'], aggr=aggr)
        x = self.update(x)
        return x

    def message(self, k_j, q_i, v_j, rel, target_index, num_nodes):
        """Build attention-weighted messages for every edge.

        Parameters
        ----------
        k_j:
            Relation-transformed keys gathered at the source nodes.
        q_i:
            Queries gathered at the target nodes.
        v_j:
            Relation-transformed values gathered at the source nodes.
        rel:
            Per-head scaling applied to the attention logits.
        target_index:
            Target node index of every edge, used to normalise the logits.
        num_nodes: int
            Number of target nodes, passed to ``segment_softmax``.

        Returns
        -------
        tensor
            Messages reshaped to ``(-1, out_channels)``.

        """
        # Scaled dot-product attention logits, one per edge and head.
        alpha = tlx.ops.reduce_sum(k_j * q_i, axis=-1, keepdims=False)
        alpha = alpha * rel
        alpha = alpha / math.sqrt(q_i.shape[-1])
        # Softmax over the edges that share a target node, then dropout.
        alpha = self.dropout(segment_softmax(alpha, target_index, num_nodes))
        out = v_j * tlx.expand_dims(alpha, -1)
        return tlx.ops.reshape(out, (-1, self.out_channels))