Source code for gammagl.models.compgcn

import tensorlayerx as tlx
from gammagl.layers.conv import CompConv
class CompGCN(tlx.nn.Module):
    """Composition-based multi-relational graph convolutional network.

    Two ``CompConv`` layers are stacked with a ReLU in between. The model
    does not take node features as input: entity and relation embeddings are
    drawn from ``tlx.random_normal`` once, at construction time, and reused
    on every forward pass.

    Parameters
    ----------
    feature_dim
        Width of the initial entity embeddings; also passed as the first
        size argument of the first ``CompConv`` layer.
    hidden_dim
        Second size argument of the first ``CompConv`` layer and first size
        argument of the second one.
    num_class
        Second size argument of the second ``CompConv`` layer.
    op
        Composition operator, stored on the model and forwarded unchanged to
        both ``CompConv`` layers.
    num_relations
        Relation count forwarded to both ``CompConv`` layers; the initial
        relation embedding table has ``num_relations + 1`` rows.
    num_entity
        Number of rows of the initial entity embedding table.
    name
        Optional module name handed to ``tlx.nn.Module``.
    """

    def __init__(self, feature_dim, hidden_dim, num_class, op,
                 num_relations, num_entity, name=None):
        super().__init__(name=name)
        self.op = op

        # Layers are created in the same order as before so that weight
        # initialisation consumes the random stream identically.
        self.conv1 = CompConv(feature_dim, hidden_dim, num_relations, op)
        self.conv2 = CompConv(hidden_dim, num_class, num_relations, op)
        self.relu = tlx.ReLU()

        # Randomly initialised entity embeddings, one row per entity.
        entity_shape = (num_entity, feature_dim)
        self.init_input = tlx.random_normal(shape=entity_shape,
                                            dtype=tlx.float32)

        # Randomly initialised relation embeddings.
        # NOTE(review): the width is hard-coded to 32 rather than tied to
        # feature_dim -- confirm this matches what CompConv expects.
        relation_shape = (num_relations + 1, 32)
        self.ref_input = tlx.random_normal(shape=relation_shape,
                                           dtype=tlx.float32)

    def forward(self, edge_index, edge_type):
        """Run both convolution layers and return the final entity output.

        Parameters
        ----------
        edge_index
            Graph connectivity, forwarded unchanged to both ``CompConv``
            layers.
        edge_type
            Per-edge relation information, forwarded unchanged to both
            ``CompConv`` layers.

        Returns
        -------
        The first element of the pair returned by the second ``CompConv``
        layer; the relation output of that layer is discarded.
        """
        node_emb, rel_emb = self.conv1(self.init_input, edge_index,
                                       edge_type, self.ref_input)
        node_emb = self.relu(node_emb)
        # The relation embeddings updated by the first layer feed the second.
        node_emb, _ = self.conv2(node_emb, edge_index, edge_type, rel_emb)
        return node_emb