Source code for gammagl.models.hardgat

import tensorlayerx as tlx
from gammagl.layers.conv import HardGATConv
from tensorlayerx.nn import ModuleList

[docs] class HardGATModel(tlx.nn.Module): r"""The graph hard attentional operator from the `"Graph Representation Learning via Hard and Channel-Wise Attention Networks" <https://dl.acm.org/doi/pdf/10.1145/3292500.3330897>`_ paper. Parameters ---------- feature_dim: int input feature dimension. hidden_dim: int hidden dimension. num_class: int number of classes. heads: int number of attention heads. drop_rate: float dropout rate. k: int number of neighbors to attention. name: str, optional model name. """ def __init__(self, feature_dim, hidden_dim, num_class, heads, drop_rate, k, num_layers, name=None): super().__init__(name=name) self.hardgat_list = ModuleList() if (num_layers == 1): hidden_dim = num_class for i in range(num_layers): if i==0: self.hardgat_list.append( HardGATConv(in_channels=feature_dim, out_channels=hidden_dim, heads=heads, dropout_rate=drop_rate, k=k, concat=True) ) elif i==(num_layers-1): self.hardgat_list.append( HardGATConv(in_channels=hidden_dim * heads, out_channels=num_class, heads=heads, dropout_rate=drop_rate, k=k, concat=False) ) else: self.hardgat_list.append( HardGATConv(in_channels=hidden_dim * heads, out_channels=hidden_dim, heads=heads, dropout_rate=drop_rate, k=k, concat=True) ) self.elu = tlx.layers.ELU() self.dropout = tlx.layers.Dropout(drop_rate)
[docs] def forward(self, x, edge_index, num_nodes): for index, hardgat in enumerate(self.hardgat_list): x = self.dropout(x) x = hardgat(x, edge_index, num_nodes) if index < (self.hardgat_list.__len__() - 1): x = self.elu(x) return x