import tensorlayerx as tlx
from gammagl.layers.conv import GCNConv
import tensorlayerx.nn as nn
import numpy as np
from scipy.sparse import coo_matrix
from scipy.special import softmax
# model
class CoGSLModel(nn.Module):
    r"""CoGSL model from the `"Compact Graph Structure Learning via Mutual
    Information Compression" <https://arxiv.org/pdf/2201.05540.pdf>`_ paper.

    The model is a thin facade over four sub-modules: a classifier
    (``cls``), a view estimator (``ve``), a mutual-information estimator
    (``mi``) and a view-fusion module (``fusion``).

    Parameters
    ----------
    num_feature: int
        input feature dimension.
    cls_hid: int
        Classification hidden dimension.
    num_class: int
        number of classes.
    gen_hid: int
        GenView hidden dimension.
    mi_hid: int
        MI_NCE hidden dimension.
    com_lambda_v1: float
        hyperparameter used to generate estimated view 1.
    com_lambda_v2: float
        hyperparameter used to generate estimated view 2.
    lam: float
        hyperparameter used when fusing views.
    alpha: float
        hyperparameter used when fusing views.
    cls_dropout: float
        Classification dropout rate.
    ve_dropout: float
        View_Estimator dropout rate.
    tau: float
        hyperparameter used to build the similarity matrix for the MI loss.
    ggl: bool
        whether to use the GCNConv layer of gammagl.
    big: bool
        whether the dataset is too big.
    batch: int
        sampling size used when the dataset is too big.

    """

    def __init__(self, num_feature, cls_hid, num_class, gen_hid, mi_hid,
                 com_lambda_v1, com_lambda_v2, lam, alpha, cls_dropout,
                 ve_dropout, tau, ggl, big, batch):
        super(CoGSLModel, self).__init__()
        self.cls = Classification(num_feature, cls_hid, num_class, cls_dropout, ggl)
        self.ve = View_Estimator(num_feature, gen_hid, com_lambda_v1,
                                 com_lambda_v2, ve_dropout, ggl)
        self.mi = MI_NCE(num_feature, mi_hid, tau, ggl, big, batch)
        self.fusion = Fusion(lam, alpha)

    def get_view(self, data):
        """Run the view estimator on ``data`` and return both estimated views."""
        estimated_v1, estimated_v2 = self.ve(data)
        return estimated_v1, estimated_v2

    def get_mi_loss(self, feat, views):
        """Return the output of the MI estimator for ``views`` and ``feat``."""
        return self.mi(views, feat)

    def get_cls_loss(self, v1, v2, feat):
        """Classify ``feat`` under both views.

        Returns
        -------
        tuple
            ``(logits_v1, logits_v2, prob_v1, prob_v2)`` where each ``logits``
            entry is ``log(prob + 1e-8)``.
        """
        # View 1 is evaluated before view 2, exactly as in the sequential form.
        prob_v1, prob_v2 = [
            self.cls(feat, view, tag) for view, tag in ((v1, "v1"), (v2, "v2"))
        ]
        # The small constant keeps the logarithm finite for zero probabilities.
        return tlx.log(prob_v1 + 1e-8), tlx.log(prob_v2 + 1e-8), prob_v1, prob_v2

    def get_v_cls_loss(self, v, feat):
        """Return ``log(prob + 1e-8)`` of the classifier on the view ``v``."""
        prob = self.cls(feat, v, "v")
        return tlx.log(prob + 1e-8)

    def get_fusion(self, v1, prob_v1, v2, prob_v2):
        """Fuse the two views with the fusion module and return the result."""
        return self.fusion(v1, prob_v1, v2, prob_v2)
# base
class GCN_two(nn.Module):
    """Two stacked ``GCN_one`` layers with an activation and dropout in between.

    Parameters
    ----------
    input_dim: int
        input feature dimension.
    hid_dim1: int
        output dimension of the first layer.
    hid_dim2: int
        output dimension of the second layer.
    dropout: float
        dropout rate applied after the activation.
    activation: str
        one of ``"relu"``, ``"leaky_relu"`` or ``"elu"``.

    """

    def __init__(self, input_dim, hid_dim1, hid_dim2, dropout=0., activation="relu"):
        super(GCN_two, self).__init__()
        self.conv1 = GCN_one(input_dim, hid_dim1)
        self.conv2 = GCN_one(hid_dim1, hid_dim2)
        self.dropout = tlx.layers.Dropout(dropout)
        assert activation in ["relu", "leaky_relu", "elu"]
        # Name -> layer class; only the requested layer is instantiated.
        activation_layers = {
            "relu": nn.ReLU,
            "leaky_relu": nn.LeakyReLU,
            "elu": nn.ELU,
        }
        if activation in activation_layers:
            self.activation = activation_layers[activation]()

    def forward(self, feature, adj):
        """Apply conv1 -> activation -> dropout -> conv2 and return the result."""
        hidden = self.conv1(feature, adj)
        hidden = self.dropout(self.activation(hidden))
        return self.conv2(hidden, adj)
class GCN_one(nn.Module):
    """Single graph-convolution layer: ``activation(adj @ fc(feat) + bias)``.

    Parameters
    ----------
    in_ft: int
        input feature dimension.
    out_ft: int
        output feature dimension.
    bias: bool
        whether to add a zero-initialised bias of shape ``(1, out_ft)``.
    activation: callable, optional
        applied to the output when not ``None``.

    """

    def __init__(self, in_ft, out_ft, bias=True, activation=None):
        super(GCN_one, self).__init__()
        self.fc = nn.Linear(in_features=in_ft, out_features=out_ft, W_init='xavier_uniform')
        self.activation = activation
        if bias:
            initor = tlx.initializers.Zeros()
            self.bias = self._get_weights("bias", shape=(1, out_ft), init=initor)
        else:
            # A plain attribute is enough for the ``is not None`` check in
            # ``forward``. NOTE(review): ``register_parameter`` is a
            # torch.nn.Module method and is presumably not provided by every
            # TensorLayerX backend -- confirm.
            self.bias = None

    def forward(self, feat, adj):
        """Project ``feat``, multiply by ``adj``, then add bias and activate."""
        feat = self.fc(feat)
        out = tlx.matmul(adj, feat)
        if self.bias is not None:
            out += self.bias
        if self.activation is not None:
            out = self.activation(out)
        return out
class GCN_two_ggl(nn.Module):
    """Two stacked gammagl ``GCNConv`` layers with activation and dropout.

    Parameters
    ----------
    input_dim: int
        input feature dimension.
    hid_dim1: int
        output dimension of the first layer.
    hid_dim2: int
        output dimension of the second layer.
    dropout: float
        dropout rate applied after the activation.
    activation: str
        one of ``"relu"``, ``"leaky_relu"`` or ``"elu"``.

    """

    def __init__(self, input_dim, hid_dim1, hid_dim2, dropout=0., activation="relu"):
        super(GCN_two_ggl, self).__init__()
        self.conv1 = GCNConv(input_dim, hid_dim1)
        self.conv2 = GCNConv(hid_dim1, hid_dim2)
        self.dropout = tlx.layers.Dropout(dropout)
        assert activation in ["relu", "leaky_relu", "elu"]
        # Name -> layer class; only the requested layer is instantiated.
        activation_layers = {
            "relu": nn.ReLU,
            "leaky_relu": nn.LeakyReLU,
            "elu": nn.ELU,
        }
        if activation in activation_layers:
            self.activation = activation_layers[activation]()

    def forward(self, feature, adj):
        """Apply conv1 -> activation -> dropout -> conv2 on the edges of ``adj``.

        ``adj`` is converted to NumPy and only the positions of its non-zero
        entries are handed to the convolutions as ``edge_index``; the entry
        values themselves are not passed on.
        """
        dense_adj = tlx.convert_to_numpy(adj)
        rows, cols = np.nonzero(dense_adj)
        edge_index = tlx.convert_to_tensor([rows, cols])
        hidden = self.conv1(feature, edge_index)
        hidden = self.dropout(self.activation(hidden))
        return self.conv2(hidden, edge_index)
class GCN_one_ggl(nn.Module):
    """Single gammagl ``GCNConv`` layer followed by optional bias and activation.

    Parameters
    ----------
    in_ft: int
        input feature dimension.
    out_ft: int
        output feature dimension.
    bias: bool
        whether to add a zero-initialised bias of shape ``(1, out_ft)``.
    activation: callable, optional
        applied to the output when not ``None``.

    """

    def __init__(self, in_ft, out_ft, bias=True, activation=None):
        super(GCN_one_ggl, self).__init__()
        self.conv1 = GCNConv(in_ft, out_ft)
        self.activation = activation
        if bias:
            initor = tlx.initializers.Zeros()
            self.bias = self._get_weights("bias", shape=(1, out_ft), init=initor)
        else:
            # A plain attribute is enough for the ``is not None`` check in
            # ``forward``. NOTE(review): ``register_parameter`` is a
            # torch.nn.Module method and is presumably not provided by every
            # TensorLayerX backend -- confirm.
            self.bias = None

    def forward(self, feat, adj):
        """Convolve ``feat`` over the non-zero positions of ``adj``.

        ``adj`` is converted to NumPy and only the positions of its non-zero
        entries are handed to the convolution as ``edge_index``; the entry
        values themselves are not passed on.
        """
        adj = tlx.convert_to_numpy(adj)
        non_zero_rows, non_zero_cols = np.nonzero(adj)
        edge_index = tlx.convert_to_tensor([non_zero_rows, non_zero_cols])
        out = self.conv1(feat, edge_index)
        if self.bias is not None:
            out += self.bias
        if self.activation is not None:
            out = self.activation(out)
        return out
# cls
class Classification(nn.Module):
    """Three independent two-layer GCN classifiers, one per view.

    ``encoder_v1`` and ``encoder_v2`` classify under the two estimated views
    and ``encoder_v`` under the view tagged ``"v"``.

    Parameters
    ----------
    num_feature: int
        input feature dimension.
    cls_hid: int
        hidden dimension of each encoder.
    num_class: int
        number of classes (output dimension of each encoder).
    dropout: float
        dropout rate of each encoder.
    ggl: bool
        when equal to ``False`` the encoders are ``GCN_two``; otherwise
        ``GCN_two_ggl``.

    """

    def __init__(self, num_feature, cls_hid, num_class, dropout, ggl):
        super(Classification, self).__init__()
        # Equality test (not truthiness) kept so that every value of ``ggl``
        # selects the same encoder type as before.
        encoder_cls = GCN_two if ggl == False else GCN_two_ggl  # noqa: E712
        self.encoder_v1 = encoder_cls(num_feature, cls_hid, num_class, dropout)
        self.encoder_v2 = encoder_cls(num_feature, cls_hid, num_class, dropout)
        self.encoder_v = encoder_cls(num_feature, cls_hid, num_class, dropout)

    def forward(self, feat, view, flag):
        """Return softmax class probabilities of ``feat`` under ``view``.

        Parameters
        ----------
        feat:
            node features passed to the encoder.
        view:
            graph structure passed to the encoder.
        flag: str
            selects the encoder: ``"v1"``, ``"v2"`` or ``"v"``.

        Raises
        ------
        ValueError
            if ``flag`` is not one of the three supported values.
        """
        if flag == "v1":
            encoder = self.encoder_v1
        elif flag == "v2":
            encoder = self.encoder_v2
        elif flag == "v":
            encoder = self.encoder_v
        else:
            # Previously an unknown flag fell through and failed with an
            # UnboundLocalError on the return statement.
            raise ValueError(
                f"unknown flag {flag!r}; expected 'v1', 'v2' or 'v'"
            )
        return nn.Softmax()(encoder(feat, view))
# contrast
class Contrast:
    """Symmetric contrastive loss between two sets of row-aligned embeddings.

    Parameters
    ----------
    tau: float
        temperature dividing the cosine similarities before exponentiation.

    """

    def __init__(self, tau):
        self.tau = tau

    def sim(self, z1, z2):
        """Return ``exp(cosine_similarity(z1, z2) / tau)`` for all row pairs."""

        def row_norm(z):
            # L2 norm of every row, kept as a column vector.
            return tlx.sqrt(tlx.reduce_sum(tlx.square(z), axis=1, keepdims=True))

        z1_norm = row_norm(z1)
        z2_norm = row_norm(z2)
        dot_products = tlx.matmul(z1, tlx.transpose(z2))
        norm_products = tlx.matmul(z1_norm, tlx.transpose(z2_norm))
        return tlx.exp(dot_products / norm_products / self.tau)

    def cal(self, z1_proj, z2_proj):
        """Return the mean of the z1->z2 and z2->z1 contrastive losses.

        For each direction the similarity matrix is row-normalised and the
        loss is the negative mean log of its diagonal (row ``i`` of one input
        is the positive of row ``i`` of the other).
        """

        def directed_loss(similarity):
            row_totals = tlx.reshape(tlx.reduce_sum(similarity, axis=1), [-1, 1])
            normalised = similarity / (row_totals + 1e-8)
            return -tlx.reduce_mean(tlx.log(tlx.diag(normalised) + 1e-8))

        sim_12 = self.sim(z1_proj, z2_proj)
        sim_21 = tlx.transpose(sim_12)
        loss_12 = directed_loss(sim_12)
        loss_21 = directed_loss(sim_21)
        return (loss_12 + loss_21) / 2
# fusion
class Fusion(nn.Module):
    """Fuse two views with per-row weights derived from classifier confidence.

    Parameters
    ----------
    lam: float
        balance between the log of the top probability and the log of the
        margin between the top two probabilities.
    alpha: float
        scale applied to the combined log term before exponentiation.

    """

    def __init__(self, lam, alpha):
        super(Fusion, self).__init__()
        self.lam = lam
        self.alpha = alpha

    def get_weight(self, prob):
        """Return one weight per row of ``prob`` from its two largest entries."""
        top_two, _ = tlx.topk(prob, 2, dim=1, largest=True, sorted=True)
        best = top_two[:, 0]
        runner_up = top_two[:, 1]
        log_best = self.lam * tlx.log(best + 1e-8)
        log_margin = (1 - self.lam) * tlx.log(best - runner_up + 1e-8)
        return tlx.exp(self.alpha * (log_best + log_margin))

    def forward(self, v1, prob_v1, v2, prob_v2):
        """Return the row-wise convex combination of ``v1`` and ``v2``."""
        w_v1 = self.get_weight(prob_v1)
        w_v2 = self.get_weight(prob_v2)
        total = w_v1 + w_v2
        # Column vectors so each weight scales a whole row of its view.
        beta_v1 = tlx.reshape(w_v1 / total, (-1, 1))
        beta_v2 = tlx.reshape(w_v2 / total, (-1, 1))
        return beta_v1 * v1 + beta_v2 * v2
# mi_nce
class MI_NCE(nn.Module):
    """Contrastive mutual-information estimator between three views.

    Each view is encoded by its own single-layer GCN, passed through a shared
    projection head, and compared pairwise with ``Contrast``.

    Parameters
    ----------
    num_feature: int
        input feature dimension.
    mi_hid: int
        hidden dimension of the encoders and the projection head.
    tau: float
        temperature passed to ``Contrast``.
    ggl: bool
        when equal to ``False`` the encoders are ``GCN_one``; otherwise
        ``GCN_one_ggl``.
    big: bool
        whether to estimate the losses on a random subset of nodes.
    batch: int
        number of nodes sampled when ``big`` is set.

    """

    def __init__(self, num_feature, mi_hid, tau, ggl, big, batch):
        super(MI_NCE, self).__init__()
        # Equality test (not truthiness) kept so that every value of ``ggl``
        # selects the same encoder type as before.
        encoder_cls = GCN_one if ggl == False else GCN_one_ggl  # noqa: E712
        self.gcn = encoder_cls(num_feature, mi_hid, activation=nn.PRelu())
        self.gcn1 = encoder_cls(num_feature, mi_hid, activation=nn.PRelu())
        self.gcn2 = encoder_cls(num_feature, mi_hid, activation=nn.PRelu())
        self.proj = nn.Sequential(
            nn.Linear(in_features=mi_hid, out_features=mi_hid),
            nn.ELU(),
            nn.Linear(in_features=mi_hid, out_features=mi_hid)
        )
        self.con = Contrast(tau)
        self.big = big
        self.batch = batch

    def forward(self, views, feat):
        """Return the three pairwise contrastive losses.

        Parameters
        ----------
        views:
            indexable with three entries; ``views[0]`` is encoded by ``gcn``,
            ``views[1]`` by ``gcn1`` and ``views[2]`` by ``gcn2``.
        feat:
            node features shared by all three encoders.

        Returns
        -------
        tuple
            ``(vv1, vv2, v1v2)``: the losses between embeddings 0 and 1,
            0 and 2, and 1 and 2.
        """
        v_emb = self.proj(self.gcn(feat, views[0]))
        v1_emb = self.proj(self.gcn1(feat, views[1]))
        v2_emb = self.proj(self.gcn2(feat, views[2]))
        # if dataset is so big, we will randomly sample part of nodes to perform MI estimation
        if self.big == True:  # noqa: E712 - equality semantics kept
            num_nodes = feat.shape[0]
            # Sampling without replacement raises when more nodes are
            # requested than exist, so cap the sample at the node count.
            sample_size = min(self.batch, num_nodes)
            idx = np.random.choice(num_nodes, sample_size, replace=False)
            idx.sort()
            v_emb = v_emb[idx]
            v1_emb = v1_emb[idx]
            v2_emb = v2_emb[idx]
        vv1 = self.con.cal(v_emb, v1_emb)
        vv2 = self.con.cal(v_emb, v2_emb)
        v1v2 = self.con.cal(v1_emb, v2_emb)
        return vv1, vv2, v1v2
# view_estimator
class GenView(nn.Module):
    """Estimate a new view by adding row-softmaxed edge scores to a view.

    Parameters
    ----------
    num_feature: int
        input feature dimension.
    hid: int
        hidden dimension of the node encoder.
    com_lambda: float
        weight of the estimated edge probabilities added to the input view.
    dropout: float
        dropout rate applied to the concatenated endpoint embeddings.
    ggl: bool
        when equal to ``False`` the encoder is ``GCN_one``; otherwise
        ``GCN_one_ggl``.

    """

    def __init__(self, num_feature, hid, com_lambda, dropout, ggl):
        super(GenView, self).__init__()
        if ggl == False:  # noqa: E712 - equality semantics kept
            self.gen_gcn = GCN_one(num_feature, hid, activation=nn.ReLU())
        else:
            self.gen_gcn = GCN_one_ggl(num_feature, hid, activation=nn.ReLU())
        self.gen_mlp = nn.Linear(in_features=2 * hid, out_features=1, W_init='xavier_normal')
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(axis=1)
        self.com_lambda = com_lambda
        self.dropout = tlx.layers.Dropout(dropout)

    def forward(self, v_ori, feat, v_indices, num_node):
        """Return ``v_ori + com_lambda * pi``.

        ``pi`` is a ``(num_node, num_node)`` matrix holding, for every
        candidate edge in ``v_indices``, the softmax of its score over the
        candidate edges of the same row; every other entry is zero.

        Parameters
        ----------
        v_ori:
            the original view, also used as the graph of the node encoder.
        feat:
            node features.
        v_indices:
            pair ``(rows, cols)`` of candidate-edge endpoints.
        num_node: int
            number of nodes.
        """
        emb = self.gen_gcn(feat, v_ori)
        f1 = tlx.gather(emb, v_indices[0])
        f2 = tlx.gather(emb, v_indices[1])
        ff = tlx.concat([f1, f2], axis=1)
        temp = tlx.reshape(self.gen_mlp(self.dropout(ff)), (-1,))

        # NOTE(review): from here on the scores are handled by NumPy/SciPy and
        # ``pi`` is rebuilt from a NumPy array, so presumably no gradient
        # reaches ``gen_mlp``/``gen_gcn`` through ``pi`` -- confirm intended.
        scores = tlx.convert_to_numpy(temp.cpu())
        rows = np.asarray(v_indices[0].cpu())
        cols = np.asarray(v_indices[1].cpu())
        shape = (num_node, num_node)
        logits = coo_matrix((scores, (rows, cols)), shape=shape).toarray()

        # Mask by position rather than by ``value == 0`` so that a candidate
        # edge whose score happens to be exactly 0 still takes part.
        has_edge = np.zeros(shape, dtype=bool)
        has_edge[rows, cols] = True
        # ``np.NINF`` was removed in NumPy 2.0; ``-np.inf`` works everywhere.
        logits[~has_edge] = -np.inf

        # A row without any candidate edge is all ``-inf`` and softmax yields
        # NaN for it; silence the warning and zero those entries afterwards.
        with np.errstate(invalid="ignore"):
            pi = softmax(logits, axis=1)
        pi[~has_edge] = 0

        pi = tlx.convert_to_tensor(pi)
        gen_v = v_ori + self.com_lambda * pi
        return gen_v
class View_Estimator(nn.Module):
    """Generate and normalise the two estimated views.

    Parameters
    ----------
    num_feature: int
        input feature dimension.
    gen_hid: int
        hidden dimension of both ``GenView`` generators.
    com_lambda_v1: float
        ``com_lambda`` of the generator for view 1.
    com_lambda_v2: float
        ``com_lambda`` of the generator for view 2.
    dropout: float
        dropout rate of both generators.
    ggl: bool
        forwarded to both generators.

    """

    def __init__(self, num_feature, gen_hid, com_lambda_v1, com_lambda_v2, dropout, ggl):
        super(View_Estimator, self).__init__()
        self.v1_gen = GenView(num_feature, gen_hid, com_lambda_v1, dropout, ggl)
        self.v2_gen = GenView(num_feature, gen_hid, com_lambda_v2, dropout, ggl)

    def forward(self, data):
        """Return the normalised estimates of both views.

        ``data`` is read through the keys ``'name'``, ``'x'``,
        ``'num_nodes'``, ``'view1'``, ``'v1_indice'``, ``'view2'`` and
        ``'v2_indice'``.
        """
        new_v1 = self.normalize(
            data['name'],
            self.v1_gen(data['view1'], data['x'], data['v1_indice'], data['num_nodes']),
        )
        new_v2 = self.normalize(
            data['name'],
            self.v2_gen(data['view2'], data['x'], data['v2_indice'], data['num_nodes']),
        )
        return new_v1, new_v2

    def normalize(self, dataset, adj):
        """Symmetrise ``adj`` and, for most datasets, degree-normalise it.

        For ``"wikics"``, ``"ms"`` and ``"citeseer"`` only ``adj + adj^T`` is
        returned; for every other dataset the identity is added and the
        result is passed through ``_normalize``.
        """
        adj_ = adj + tlx.transpose(adj)
        if dataset in ["wikics", "ms", "citeseer"]:
            return adj_
        identity = tlx.convert_to_tensor(np.eye(adj_.shape[0]), dtype=tlx.float32)
        return self._normalize(adj_ + identity)

    def _normalize(self, mx):
        """Return ``D^{-1/2} mx D^{-1/2}`` with ``D`` the row sums of ``mx``."""
        rowsum = tlx.reduce_sum(mx, axis=1) + 1e-6  # avoid NaN
        r_inv = tlx.pow(rowsum, -1/2)
        # Kept from the original. NOTE(review): on some backends this call
        # may copy/detach ``r_inv`` -- confirm whether it is needed.
        r_inv = tlx.convert_to_tensor(r_inv)
        # Scaling row i and column j by r_inv[i] and r_inv[j] equals
        # diag(r_inv) @ mx @ diag(r_inv) without building the dense diagonal
        # matrix or running two n x n matrix products.
        row_scale = tlx.reshape(r_inv, (-1, 1))
        col_scale = tlx.reshape(r_inv, (1, -1))
        return row_scale * mx * col_scale