import tensorlayerx as tlx
from gammagl.layers.conv import GCNConv, SAGEConv
from gammagl.layers.pool import global_sort_pool
import math
class DGCNN(tlx.nn.Module):
    r"""DGCNN proposed in `"An End-to-End Deep Learning Architecture for Graph
    Classification" <https://dl.acm.org/doi/pdf/10.5555/3504035.3504579>`_ paper.

    The model stacks ``num_layers`` hidden-width graph convolutions plus one
    final single-channel convolution, concatenates their ``tanh`` outputs,
    applies SortPooling to keep ``k`` nodes per graph, and then runs a 1-D
    CNN followed by a two-layer MLP that emits one value per graph.

    Parameters
    ----------
    feature_dim: int
        input feature dimension.
    hidden_dim: int
        hidden dimension.
    num_layers: int
        number of hidden-width graph convolution layers (must be >= 1).
        An extra convolution projecting to a single channel is always
        appended after them.
    gcn_type: str
        convolution layer type: ``'gcn'`` selects ``GCNConv``; any other
        value selects ``SAGEConv``.
    k: int or float
        The number of nodes to hold for each graph in SortPooling.
        If ``0 < k < 1`` it is interpreted as a fraction: ``k`` becomes the
        node count at that position of the sorted ``train_dataset`` node
        counts, clamped to at least 10 (or 30 when ``train_dataset`` is
        ``None``). If ``k >= 1`` it is used directly and must be >= 10.
    train_dataset: dataset
        iterable of graphs whose ``num_nodes`` are used to turn a
        fractional ``k`` into a node count. Ignored when ``k >= 1``.
    dropout: float
        dropout rate.
    name: str
        model name.

    Raises
    ------
    ValueError
        If ``num_layers < 1``, ``k <= 0``, an explicit ``k`` is below 10,
        or ``train_dataset`` is empty while ``k`` is fractional.
    """

    # conv1 leaves k positions, maxpool1d(2, 2) leaves k // 2, and conv2
    # (kernel 5, VALID padding) needs k // 2 >= 5, hence k >= 10.
    _MIN_K = 10

    def __init__(self, feature_dim,
                 hidden_dim,
                 num_layers,
                 gcn_type='gcn',
                 k=0.6,
                 train_dataset=None,
                 dropout=0.5,
                 name=None):
        super().__init__(name=name)
        if num_layers < 1:
            # total_latent_dim below assumes at least one hidden-width layer;
            # otherwise it no longer matches the concatenated feature width.
            raise ValueError(f"num_layers must be at least 1, got {num_layers}")

        # Any gcn_type other than 'gcn' falls back to GraphSAGE.
        GNN = GCNConv if gcn_type == 'gcn' else SAGEConv
        self.k = self._resolve_k(k, train_dataset)

        # num_layers hidden-width convolutions, then a 1-channel projection.
        self.convs = tlx.nn.ModuleList()
        self.convs.append(GNN(feature_dim, hidden_dim))
        for _ in range(num_layers - 1):
            self.convs.append(GNN(hidden_dim, hidden_dim))
        self.convs.append(GNN(hidden_dim, 1))

        conv1d_channels = [16, 32]
        # Width of the per-node concatenation built in forward().
        total_latent_dim = hidden_dim * num_layers + 1
        conv1d_kws = [total_latent_dim, 5]
        # kernel_size == stride == total_latent_dim, so each window covers
        # one node's concatenated features (assuming the SortPool output is
        # flattened node-major).
        self.conv1 = tlx.nn.Conv1d(
            out_channels=conv1d_channels[0],
            kernel_size=conv1d_kws[0],
            stride=conv1d_kws[0],
            act=tlx.nn.ReLU,
            padding='VALID'
        )
        self.maxpool1d = tlx.nn.MaxPool1d(2, 2, 'VALID')
        self.conv2 = tlx.nn.Conv1d(
            out_channels=conv1d_channels[1],
            kernel_size=conv1d_kws[1],
            act=tlx.nn.ReLU,
            padding='VALID'
        )
        self.lin1 = tlx.nn.Linear(out_features=128, act=tlx.nn.ReLU)
        self.drop = tlx.nn.Dropout(p=dropout)
        self.lin2 = tlx.nn.Linear(out_features=1)

    @staticmethod
    def _resolve_k(k, train_dataset):
        """Convert the ``k`` constructor argument into an integer SortPool size.

        Raises ``ValueError`` for non-positive ``k``, for an explicit ``k``
        smaller than ``DGCNN._MIN_K``, and for an empty ``train_dataset``
        when ``k`` is fractional.
        """
        if k <= 0:
            raise ValueError(f"k must be positive, got {k}")
        if k < 1:
            # Fractional k: pick the node count at that fraction of the
            # sorted training-graph sizes.
            if train_dataset is None:
                return 30
            num_nodes = sorted(g.num_nodes for g in train_dataset)
            if not num_nodes:
                raise ValueError(
                    "train_dataset is empty; cannot derive k from a fraction"
                )
            k = num_nodes[int(math.ceil(k * len(num_nodes))) - 1]
            return int(max(DGCNN._MIN_K, k))
        if k < DGCNN._MIN_K:
            raise ValueError(
                f"k must be at least {DGCNN._MIN_K} so the Conv1d/MaxPool1d "
                f"stack has a non-empty output, got {k}"
            )
        return int(k)

    def forward(self, x, edge_index, batch):
        """Compute one output value per pooled graph.

        Parameters
        ----------
        x:
            node feature matrix (presumably ``[num_nodes, feature_dim]``).
        edge_index:
            graph connectivity, passed unchanged to every convolution.
        batch:
            node-to-graph assignment, passed to ``global_sort_pool``.

        Returns
        -------
        Tensor of shape ``[B, 1]`` taken straight from ``lin2`` (no final
        activation), where ``B`` is the first dimension of the SortPool
        output (presumably the number of graphs).
        """
        xs = [x]
        for conv in self.convs:
            # tanh after every graph convolution; each layer consumes the
            # previous layer's output.
            xs += [tlx.tanh(conv(xs[-1], edge_index))]
        # Drop the raw input; concatenate all layer outputs per node.
        x = tlx.concat(xs[1:], axis=-1)
        # Global pooling.
        x = global_sort_pool(x, batch, self.k)
        x = tlx.expand_dims(x, -1)  # presumably [num_graphs, k * total_latent_dim, 1]
        x = self.conv1(x)
        x = self.maxpool1d(x)
        x = self.conv2(x)
        x = tlx.reshape(x, (x.shape[0], -1))  # [num_graphs, dense_dim]
        # MLP.
        x = self.lin1(x)
        x = self.drop(x)
        x = self.lin2(x)
        return x