# Source code for gammagl.layers.conv.message_passing

import tensorlayerx as tlx
from gammagl.mpops import *
from gammagl.utils import Inspector


class MessagePassing(tlx.nn.Module):
    r"""Base class for creating message passing layers of the form

    .. math::
        \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i,
        \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}}
        \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right),

    where :math:`\square` denotes a differentiable, permutation invariant
    function, *e.g.*, sum, mean or max, and :math:`\gamma_{\mathbf{\Theta}}`
    and :math:`\phi_{\mathbf{\Theta}}` denote differentiable functions such as
    MLPs.
    """

    # Argument names removed from the inspected parameter set when
    # ``__user_args__`` is computed in ``__init__``.
    special_args = {
        'edge_index',
        'x',
        'edge_weight'
    }

    def __init__(self):
        """Register the signatures of ``message`` and ``message_aggregate``.

        The inspector records the parameters of both (possibly overridden)
        methods so that ``propagate`` can later hand each of them exactly the
        keyword arguments it declares.
        """
        super().__init__()
        self.inspector = Inspector(self)
        self.inspector.inspect(self.message)
        self.inspector.inspect(self.message_aggregate)
        # Parameters declared by the two methods, minus ``special_args``.
        self.__user_args__ = self.inspector.keys(
            ['message', 'message_aggregate']).difference(self.special_args)

    def message(self, x, edge_index, edge_weight=None):
        """
        Function that construct message from source nodes to destination nodes.

        Parameters
        ----------
        x: tensor
            input node feature.
        edge_index: tensor
            edges from src to dst.
        edge_weight: tensor, optional
            weight of each edge.

        Returns
        -------
        tensor
            the message matrix, and the shape is [num_edges, message_dim]

        """
        # Row 0 of ``edge_index`` selects the source node of every edge.
        msg = tlx.gather(x, edge_index[0, :])
        if edge_weight is not None:
            # Add a trailing axis so the per-edge weight multiplies the whole
            # feature vector of the corresponding message.
            edge_weight = tlx.expand_dims(edge_weight, -1)
            return msg * edge_weight
        else:
            return msg

    def aggregate(self, msg, edge_index, num_nodes=None, aggr='sum'):
        """
        Function that aggregates message from edges to destination nodes.

        Parameters
        ----------
        msg: tensor
            message construct by message function.
        edge_index: tensor
            edges from src to dst.
        num_nodes: int, optional
            number of nodes of the graph.
        aggr: str, optional
            aggregation type, default = 'sum', optional=['sum', 'mean', 'max'].

        Returns
        -------
        tensor
            aggregation outcome.

        Raises
        ------
        NotImplementedError
            if ``aggr`` is not one of 'sum', 'mean' or 'max'.

        """
        # Row 1 of ``edge_index`` selects the destination node of every edge.
        dst_index = edge_index[1, :]
        if aggr == 'sum':
            return unsorted_segment_sum(msg, dst_index, num_nodes)
        elif aggr == 'mean':
            return unsorted_segment_mean(msg, dst_index, num_nodes)
        elif aggr == 'max':
            return unsorted_segment_max(msg, dst_index, num_nodes)
        else:
            raise NotImplementedError('Not support for this opearator')

    def message_aggregate(self, x, edge_index, edge_weight=None, aggr='sum'):
        """
        Try to fuse message and aggregate to reduce expensed edge information.

        Parameters
        ----------
        x: tensor
            input node feature.
        edge_index: tensor
            edges from src to dst.
        edge_weight: tensor, optional
            weight of each edge.
        aggr: str, optional
            aggregation type, default = 'sum'.

        Returns
        -------
        tensor
            aggregation outcome.

        """
        # use_ext is defined in mpops
        if use_ext:
            if edge_weight is None:
                # Imported locally: this file has no module-level torch import,
                # and only this branch needs it.
                import torch
                # The fused kernel takes an explicit weight per edge, so
                # unweighted graphs get a weight of one on every edge.
                edge_weight = torch.ones(edge_index.shape[1],
                                         device=x.device,
                                         dtype=x.dtype)
            out = gspmm(edge_index, edge_weight, x, aggr)
        else:
            msg = self.message(x, edge_index, edge_weight)
            # Forward ``aggr`` so the unfused fallback honours the requested
            # reduction instead of always summing.
            # NOTE(review): ``num_nodes`` is not part of this signature, so
            # ``aggregate`` still receives its default of None here.
            out = self.aggregate(msg, edge_index, aggr=aggr)
        return out

    def update(self, x):
        """
        Function defines how to update node embeddings.

        Parameters
        ----------
        x: tensor
            aggregated message

        Returns
        -------
        tensor
            the input, unchanged in this base implementation.

        """
        return x

    def propagate(self, x, edge_index, aggr='sum', **kwargs):
        """
        Function that perform message passing.

        Parameters
        ----------
        x: tensor
            input node feature.
        edge_index: tensor
            edges from src to dst.
        aggr: str, optional
            aggregation type, default='sum', optional=['sum', 'mean', 'max'].
        kwargs: optional
            other parameters dict. ``num_nodes`` is read from here and
            defaults to ``x.shape[0]`` when missing or None.

        Returns
        -------
        tensor
            result of ``update`` applied to the aggregated messages.

        """
        if kwargs.get('num_nodes') is None:
            kwargs['num_nodes'] = x.shape[0]

        coll_dict = self.__collect__(x, edge_index, aggr, kwargs)
        # The fused path is taken only on the torch backend and only when the
        # instance's own class defines ``message_aggregate`` (an inherited
        # definition does not appear in ``__class__.__dict__``).
        if tlx.BACKEND == "torch" and 'message_aggregate' in self.__class__.__dict__:
            msg_agg_kwargs = self.inspector.distribute('message_aggregate', coll_dict)
            x = self.message_aggregate(**msg_agg_kwargs)
        else:
            msg_kwargs = self.inspector.distribute('message', coll_dict)
            msg = self.message(**msg_kwargs)
            x = self.aggregate(msg, edge_index, num_nodes=kwargs['num_nodes'], aggr=aggr)

        x = self.update(x)
        return x

    def __collect__(self, x, edge_index, aggr, kwargs):
        """Merge the explicit ``propagate`` arguments with ``kwargs``.

        Returns a new dict holding every entry of ``kwargs`` plus ``'x'``,
        ``'edge_index'`` and ``'aggr'``; the explicit arguments win over
        same-named entries in ``kwargs``.
        """
        out = dict(kwargs)
        out['x'] = x
        out['edge_index'] = edge_index
        out['aggr'] = aggr
        return out