Source code for gammagl.models.specformer

import math
import tensorlayerx as tlx

def transpose_qkv(X, num_heads):
    """
    To split the q, k, v with multiheads
    
    Parameters
    ----------
    X: 
       The feature of shape: [bsz, query, embed_dim]
    num_heads:
        The number of heads
    
    Returns
    -------
    Tensor
        The tensor of shape: [bsz, query, num_heads, embed_dim / num_heads]
    """
    X = tlx.reshape(X, (tlx.get_tensor_shape(X)[0],
                        tlx.get_tensor_shape(X)[1], num_heads, -1))
    X = tlx.convert_to_tensor(tlx.convert_to_numpy(X).transpose((0, 2, 1, 3)))
    X = tlx.reshape(X, (-1, tlx.get_tensor_shape(X)[2], tlx.get_tensor_shape(X)[3]))
    return X

def transepose_output(X, num_heads):
    X = tlx.reshape(X, (-1, num_heads, tlx.get_tensor_shape(X)[1],
                        tlx.get_tensor_shape(X)[2]))
    X = tlx.convert_to_tensor(tlx.convert_to_numpy(X).transpose((0, 2, 1, 3)))
    X = tlx.reshape(X, (tlx.get_tensor_shape(X)[0],
                        tlx.get_tensor_shape(X)[1], -1))

    return X

class MultiHeadAttention(tlx.nn.Module):
    def __init__(self, hidden_dim, n_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = n_heads

        self.W_q = tlx.nn.Linear(in_features=hidden_dim, out_features=hidden_dim, act=tlx.relu)
        self.W_k = tlx.nn.Linear(in_features=hidden_dim, out_features=hidden_dim, act=tlx.relu)
        self.W_v = tlx.nn.Linear(in_features=hidden_dim, out_features=hidden_dim, act=tlx.relu)
        self.W_o = tlx.nn.Linear(in_features=hidden_dim, out_features=hidden_dim, act=tlx.relu)
    
    def dot_product_attention(self, querys, keys=None, values=None):
        keys = querys
        values = querys
        d = tlx.get_tensor_shape(querys)[-1]
        scores = tlx.bmm(querys, tlx.transpose(keys, perm=[0, 2, 1])) / math.sqrt(d)
        self.attn_weights = tlx.nn.Softmax(axis=-1)(scores)
        return tlx.bmm(self.attn_weights, values)

    def forward(self, q, k, v):
        is_batched = len(tlx.get_tensor_shape(q)) == 3

        if not is_batched:  # if input with no batchs
            q = tlx.expand_dims(q, axis=0)
            k = tlx.expand_dims(k, axis=0)
            v = tlx.expand_dims(v, axis=0)

        q = transpose_qkv(self.W_q(q), self.num_heads)
        k = transpose_qkv(self.W_k(k), self.num_heads)
        v = transpose_qkv(self.W_v(v), self.num_heads)

        output = self.dot_product_attention(q, k, v)
        output_concat = transepose_output(output, self.num_heads)
        res = self.W_o(output_concat)

        if not is_batched:  # if input with no batchs, remove the batchs which is added at the beginning
            res = tlx.squeeze(res, axis=0)

        return res

class SineEncoding(tlx.nn.Module):
    def __init__(self, hidden_dim=16):
        super(SineEncoding, self).__init__()
        self.constant = 100
        self.hidden_dim = hidden_dim
        initor = tlx.nn.initializers.XavierNormal()
        self.eig_w = tlx.nn.Linear(in_features=hidden_dim + 1,
                                   out_features=hidden_dim,
                                   act=tlx.ReLU,
                                   W_init=initor)

    def forward(self, e):
        ee = e * self.constant
        div = tlx.exp(tlx.arange(0, self.hidden_dim, 2, dtype=tlx.float32)
                      * (-math.log(10000) / self.hidden_dim))
        pe = tlx.expand_dims(ee, axis=1) * div
        eeig = tlx.concat(
            (tlx.expand_dims(e, axis=1), tlx.sin(pe), tlx.cos(pe)),
            axis=1
            )
        eeig = tlx.cast(eeig, dtype=tlx.float32)

        return self.eig_w(eeig)

class FeedForwardNetwork(tlx.nn.Module):

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(FeedForwardNetwork, self).__init__()
        self.layer1 = tlx.nn.Linear(in_features=input_dim, out_features=hidden_dim, act=tlx.relu)
        self.layer2 = tlx.nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x

class SpecLayer(tlx.nn.Module):
    def __init__(self, nbases, ncombines, prop_dropout=0.0, norm='none'):
        super(SpecLayer, self).__init__()
        self.prop_dropout = tlx.nn.Dropout(p=prop_dropout)

        # self.weight with shape: [1, m, d]
        if norm == 'none':
            self.weight = tlx.nn.Parameter(data=tlx.ones((1, nbases, ncombines)))
        else:
            self.weight = tlx.nn.Parameter(tlx.initializers.random_normal(mean=0.0, stddev=0.01) \
                                               (shape=(1, nbases, ncombines)))

        if norm == 'layer':
            self.norm = tlx.nn.LayerNorm(ncombines)
        elif norm == 'batch':
            self.norm = tlx.nn.BatchNorm1d(num_features=ncombines)
        else:
            self.norm = None

    def forward(self, x):
        x = self.prop_dropout(x) * self.weight
        x = tlx.reduce_sum(x, axis=1)

        if self.norm is not None:
            x = self.norm(x)
            x = tlx.relu(x)

        return x

[docs] class Specformer(tlx.nn.Module): r"""The Specformer from the `"Specformer:Spectral Graph Neural Networks Meet Transformers" <https://openreview.net/pdf?id=0pdSt3oyJa1>`_ paper Parameters ---------- nclass: int the number of node classes n_feat: int the node feature input dimension n_layer: int number of Speclayers hidden_dim: int the eigvalue representation dimension and the node feature dimension n_heads: int the number of attention heads tran_dropout: int the probability of dropout feat_dropout: int the probability of dropout prop_dropout: float the probability of dropout """ def __init__(self, nclass, n_feat, n_layer=2, hidden_dim=32, n_heads=4, tran_dropout=0.2, feat_dropout=0.4, prop_dropout=0.5, norm='none'): super(Specformer, self).__init__() # node feat: n_feat, eigenvalue repre shape: hidden_dim self.norm = norm self.n_feat = n_feat self.n_layer = n_layer self.n_heads = n_heads self.hidden_dim = hidden_dim self.feat_encoder = tlx.nn.Sequential( tlx.nn.Linear(in_features=n_feat, out_features=hidden_dim, act=tlx.relu), tlx.nn.Linear(in_features=hidden_dim, out_features=nclass), ) self.linear_encoder = tlx.nn.Linear(in_features=n_feat, out_features=hidden_dim, act=tlx.relu) self.classify = tlx.nn.Linear(in_features=hidden_dim, out_features=nclass, act=tlx.relu) self.eig_encoder = SineEncoding(hidden_dim) self.decoder = tlx.nn.Linear(in_features=hidden_dim, out_features=n_heads, act=tlx.relu) self.mha_norm = tlx.nn.LayerNorm(hidden_dim) self.ffn_norm = tlx.nn.LayerNorm(hidden_dim) self.mha_dropout = tlx.nn.Dropout(p=tran_dropout) self.ffn_dropout = tlx.nn.Dropout(p=tran_dropout) self.mha = MultiHeadAttention(hidden_dim=hidden_dim, n_heads=n_heads) self.ffn = FeedForwardNetwork(hidden_dim, hidden_dim, hidden_dim) self.feat_dp1 = tlx.nn.Dropout(p=feat_dropout) self.feat_dp2 = tlx.nn.Dropout(p=feat_dropout) list_layer = [] if norm == "none": for _ in range(n_layer): list_layer.append(SpecLayer(n_heads+1, nclass, prop_dropout, norm=norm)) else: for _ in range(n_layer): list_layer.append(SpecLayer(n_heads+1, hidden_dim, prop_dropout, norm=norm)) self.layers = list_layer
[docs] def forward(self, x, edge_index, e, u): ut = tlx.transpose(u, perm=[1, 0]) x = tlx.cast(x=x, dtype=tlx.float32) ut = tlx.cast(x=ut, dtype=tlx.float32) e = tlx.cast(x=e, dtype=tlx.float32) u = tlx.cast(x=u, dtype=tlx.float32) # make conversion of node feat if self.norm == 'none': h = self.feat_dp1(x) h = tlx.cast(x=h, dtype=tlx.float32) h = self.feat_encoder(h) h = self.feat_dp2(h) else: h = self.feat_dp1(x) h = self.linear_encoder(h) # conversion of eigenvalues repre eig = self.eig_encoder(e) mha_eig = self.mha_norm(eig) mha_eig = self.mha(mha_eig, mha_eig, mha_eig) eig = eig + self.mha_dropout(mha_eig) ffn_eig = self.ffn_norm(eig) ffn_eig = self.ffn(ffn_eig) eig = eig + self.ffn_dropout(ffn_eig) new_e = self.decoder(eig) for i in range(self.n_layer): basic_feats = [h] utx = ut @ h for j in range(self.n_heads): basic_feats.append(u @ (tlx.expand_dims(input=new_e[:, j], axis=1) * utx)) basic_feats = tlx.stack(values=basic_feats, axis=1) h = self.layers[i](basic_feats) if self.norm == 'none': return h else: h = self.feat_dp2(h) h = self.classify(h) return h