Source code for gammagl.layers.conv.compgcn_conv

# !/usr/bin/env python3
# -*- coding:utf-8 -*-

# @Time    : 2022/11/8 23:47
# @Author  : yijian
# @FileName: compgcn_conv.py

import tensorlayerx as tlx
from gammagl.layers.conv import MessagePassing
from gammagl.mpops import *

def masked_edge_index(edge_index, edge_mask):
    """Keep only the edges of ``edge_index`` selected by ``edge_mask``.

    Parameters
    ----------
    edge_index: tensor
        edge index with one edge per column (the non-MindSpore branch masks
        the rows of its transpose, i.e. the columns of ``edge_index``).
    edge_mask: tensor
        boolean mask with one entry per edge (per column of ``edge_index``).

    Returns
    -------
    tensor
        the columns of ``edge_index`` whose mask entry is true.
    """
    if tlx.BACKEND == 'mindspore':
        # Presumably boolean-mask indexing is unavailable on this backend, so
        # turn the mask into explicit integer edge positions first.
        idx = tlx.convert_to_tensor([i for i, v in enumerate(edge_mask) if v], dtype=tlx.int64)
        # Edges are columns, so gather along axis 1; the default axis 0 would
        # index the rows of edge_index instead of its edges.
        return tlx.gather(edge_index, idx, axis=1)
    return tlx.transpose(tlx.transpose(edge_index)[edge_mask])


class CompConv(MessagePassing):
    """Composition-based multi-relational graph convolution (CompGCN layer).

    Paper: Composition-based Multi-Relational Graph Convolutional Networks
    Code: https://github.com/MichSchli/RelationPrediction

    ``forward`` runs three message-passing passes:
    - over the first half of the edges ("in" edges),
    - over the second half ("out" edges),
    - over one self loop per node.

    Each pass has its own linear transform. The three results are averaged,
    and a bias is optionally added. The relation embeddings are also
    projected with a separate linear layer.

    Parameters
    ----------
    in_channels: int
        the input dimension of the features (relation embeddings are fed to
        ``w_rel`` with the same input size and are composed element-wise
        with the node features).
    out_channels: int
        the output dimension of the features.
    num_relations: int
        the number of relations in the graph. Self loops are tagged with the
        extra relation id ``num_relations``.
    op: str
        the composition of entity and relation embeddings used in message
        creation: ``'sub'`` (entity - relation) or ``'mult'``
        (element-wise entity * relation).
    add_bias: bool
        whether to add a learnable bias to the node output.

    Raises
    ------
    ValueError
        if ``op`` is not one of the supported composition operations.
    """

    # Composition operators implemented in ``message``.
    _SUPPORTED_OPS = ('sub', 'mult')

    def __init__(self, in_channels, out_channels, num_relations, op='sub', add_bias=True):
        super().__init__()
        # Fail fast: an unknown op would otherwise only surface as an
        # UnboundLocalError inside ``message`` during the first forward pass.
        if op not in self._SUPPORTED_OPS:
            raise ValueError(f"Unsupported composition op {op!r}; "
                             f"expected one of {self._SUPPORTED_OPS}.")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.op = op
        self.add_bias = add_bias

        # Separate bias-free projections for self loops, "in" edges,
        # "out" edges and relation embeddings.
        self.w_loop = tlx.layers.Linear(out_features=out_channels, in_features=in_channels,
                                        W_init='xavier_uniform', b_init=None)
        self.w_in = tlx.layers.Linear(out_features=out_channels, in_features=in_channels,
                                      W_init='xavier_uniform', b_init=None)
        self.w_out = tlx.layers.Linear(out_features=out_channels, in_features=in_channels,
                                       W_init='xavier_uniform', b_init=None)
        self.w_rel = tlx.layers.Linear(out_features=out_channels, in_features=in_channels,
                                       W_init='xavier_uniform', b_init=None)

        self.initor = tlx.initializers.truncated_normal()
        if self.add_bias:
            self.bias = self._get_weights(var_name="bias", shape=(out_channels,), init=self.initor)

    def forward(self, x, edge_index, edge_type=None, ref_emb=None):
        """Run one CompGCN layer.

        Parameters
        ----------
        x: tensor
            node features, one row per node.
        edge_index: tensor
            edges, one per column. The first half of the columns is treated
            as "in" edges and the second half as "out" edges. With an odd
            edge count the extra edge lands in the "out" half.
        edge_type: tensor
            relation id of every edge, aligned with the columns of
            ``edge_index``. Required.
        ref_emb: tensor
            relation embeddings indexed by relation id. Self loops use id
            ``num_relations``, so this presumably needs ``num_relations + 1``
            rows (NOTE(review): confirm against the caller). Required.

        Returns
        -------
        tuple
            ``(node_out, rel_out)``: the averaged node representations (plus
            bias if enabled) and ``ref_emb`` projected by ``w_rel``.

        Raises
        ------
        ValueError
            if ``edge_type`` or ``ref_emb`` is not given.
        """
        if edge_type is None or ref_emb is None:
            raise ValueError("CompConv.forward requires both `edge_type` and `ref_emb`.")

        num_nodes = x.shape[0]
        edge_half_num = edge_index.shape[1] // 2
        edge_in_index = edge_index[:, :edge_half_num]
        edge_out_index = edge_index[:, edge_half_num:]
        edge_in_type = edge_type[:edge_half_num]
        edge_out_type = edge_type[edge_half_num:]

        # One self loop (i -> i) per node, all tagged with relation id num_relations.
        node_ids = tlx.ops.convert_to_tensor(list(range(num_nodes)))
        loop_index = tlx.ops.stack([node_ids, node_ids])
        loop_type = tlx.ops.convert_to_tensor([self.num_relations] * num_nodes)

        in_res = self.propagate(x, edge_in_index, edge_type=edge_in_type,
                                linear=self.w_in, rel_emb=ref_emb)
        out_res = self.propagate(x, edge_out_index, edge_type=edge_out_type,
                                 linear=self.w_out, rel_emb=ref_emb)
        loop_res = self.propagate(x, loop_index, edge_type=loop_type,
                                  linear=self.w_loop, rel_emb=ref_emb)
        rel_out = self.w_rel(ref_emb)

        # Equal-weight average of the three passes.
        res = in_res * (1 / 3) + out_res * (1 / 3) + loop_res * (1 / 3)
        if self.add_bias:
            res = res + self.bias
        return res, rel_out

    def message(self, x, edge_index, edge_type, edge_weight=None, rel_emb=None, linear=None):
        """Construct messages from source nodes to destination nodes.

        Parameters
        ----------
        x: tensor
            input node feature.
        edge_index: tensor
            edges from src (``edge_index[0]``) to dst (``edge_index[1]``).
        edge_type: tensor
            relation id of each edge, used to look up ``rel_emb``.
        edge_weight: tensor, optional
            weight of each edge; multiplies the message when given.
        rel_emb: tensor
            relation embeddings indexed by relation id.
        linear: callable
            projection applied to the composed embedding (``w_in``,
            ``w_out`` or ``w_loop``, chosen by ``forward``).

        Returns
        -------
        tensor
            output message, one row per edge.

        Raises
        ------
        ValueError
            if ``self.op`` is not a supported composition operation.
        """
        rel_emb = tlx.gather(rel_emb, edge_type)
        # Compose each edge's *source* embedding with its relation embedding.
        # This class does not override ``aggregate``; the inherited
        # MessagePassing.aggregate is assumed to scatter onto the destination
        # edge_index[1] (NOTE(review): confirm against the gammagl version in
        # use). Gathering from edge_index[1] as well would feed every node
        # only its own features.
        x_emb = tlx.gather(x, edge_index[0])
        if self.op == 'sub':
            x_rel_emb = x_emb - rel_emb
        elif self.op == 'mult':
            x_rel_emb = x_emb * rel_emb
        else:
            raise ValueError(f"Unsupported composition op {self.op!r}; "
                             f"expected one of {self._SUPPORTED_OPS}.")
        msg = linear(x_rel_emb)
        if edge_weight is not None:
            # Broadcast one scalar weight per edge across the feature axis.
            edge_weight = tlx.expand_dims(edge_weight, -1)
            return msg * edge_weight
        return msg