Source code for gammagl.models.rgcn

import tensorlayerx as tlx
from gammagl.layers.conv import RGCNConv


class RGCN(tlx.nn.Module):
    """Two-layer relational graph convolution network (R-GCN).

    The model stacks two ``RGCNConv`` layers with a ReLU in between.

    Parameters
    ----------
    feature_dim : int
        First positional size argument of the first ``RGCNConv``
        (presumably its input dimension; confirm against ``RGCNConv``).
    hidden_dim : int
        Output size of the first layer and input size of the second layer.
    num_class : int
        Output size of the second (final) layer.
    num_relations : int
        Number of relation / edge types, passed to both ``RGCNConv`` layers.
    name : str, optional
        Module name forwarded to ``tlx.nn.Module``.
    """

    def __init__(self, feature_dim, hidden_dim, num_class, num_relations, name=None):
        super().__init__(name=name)
        # Layers are created in this fixed order: conv1, conv2, relu.
        self.conv1 = RGCNConv(feature_dim, hidden_dim, num_relations)
        self.conv2 = RGCNConv(hidden_dim, num_class, num_relations)
        self.relu = tlx.ReLU()

    def forward(self, edge_index, edge_type):
        """Run both R-GCN layers over the graph.

        Parameters
        ----------
        edge_index
            Graph connectivity, passed unchanged to both layers.
        edge_type
            Per-edge relation ids, passed unchanged to both layers.

        Returns
        -------
        The output of the second ``RGCNConv`` layer (presumably per-node
        scores of width ``num_class``; TODO confirm against ``RGCNConv``).
        """
        # The first layer receives no node features (x=None).
        # NOTE(review): presumably RGCNConv then falls back to learned
        # per-node parameters; verify in RGCNConv.
        hidden = self.conv1(None, edge_index, edge_type)
        hidden = self.relu(hidden)
        return self.conv2(hidden, edge_index, edge_type)