gammagl.layers.conv.MessagePassing

class MessagePassing

Base class for creating message passing layers of the form

\[\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right),\]

where \(\square\) denotes a differentiable, permutation-invariant function, e.g., sum, mean or max, and \(\gamma_{\mathbf{\Theta}}\) and \(\phi_{\mathbf{\Theta}}\) denote differentiable functions such as MLPs.
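As one concrete instance of this scheme (an illustrative choice, not one taken from this page), take \(\square\) to be a sum, treat the edge feature as a scalar weight \(e_{j,i}\), let \(\phi\) scale the neighbor feature by that weight, and let \(\gamma\) return the aggregated message \(\mathbf{m}_i\) unchanged. The layer then reduces to a weighted neighbor sum:

\[\phi_{\mathbf{\Theta}}\left(\mathbf{x}_i, \mathbf{x}_j, e_{j,i}\right) = e_{j,i}\,\mathbf{x}_j, \qquad \gamma_{\mathbf{\Theta}}\left(\mathbf{x}_i, \mathbf{m}_i\right) = \mathbf{m}_i \quad\Longrightarrow\quad \mathbf{x}_i^{\prime} = \sum_{j \in \mathcal{N}(i)} e_{j,i}\,\mathbf{x}_j.\]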

special_args = {'edge_index', 'edge_weight', 'x'}
message(x, edge_index, edge_weight=None)

Constructs messages from source nodes to destination nodes.

Parameters:
  • x (tensor) – input node features.

  • edge_index (tensor) – edges from src to dst.

  • edge_weight (tensor, optional) – weight of each edge.

Returns:

the message matrix, of shape [num_edges, message_dim].

Return type:

tensor
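The following is a minimal pure-Python sketch of what a message function with this signature could compute. It assumes that edge_index is a pair [src_list, dst_list] and that each message is the source node's feature vector, scaled by the edge weight when one is given. Plain lists stand in for tensors, and the name default_message is hypothetical, not GammaGL's code.

```python
def default_message(x, edge_index, edge_weight=None):
    """Hypothetical sketch: build one message per edge from its source node.

    x: list of node feature vectors (lists of floats)
    edge_index: [src_list, dst_list], both of length num_edges
    edge_weight: optional list of per-edge scalars
    Returns num_edges message vectors, i.e. shape [num_edges, message_dim].
    """
    src = edge_index[0]
    # Gather the feature vector of each edge's source node.
    msg = [list(x[j]) for j in src]
    if edge_weight is not None:
        # Scale each message by the weight of its edge.
        msg = [[w * v for v in m] for m, w in zip(msg, edge_weight)]
    return msg


x = [[1.0, 0.0], [0.0, 2.0], [3.0, 1.0]]
edge_index = [[0, 1, 2], [1, 2, 1]]  # edges 0->1, 1->2, 2->1
print(default_message(x, edge_index, edge_weight=[0.5, 1.0, 2.0]))
# [[0.5, 0.0], [0.0, 2.0], [6.0, 2.0]]
```

The output has one row per edge, not per node. Turning these rows into per-node results is the job of aggregate.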

aggregate(msg, edge_index, num_nodes=None, aggr='sum')

Aggregates messages from edges into their destination nodes.

Parameters:
  • msg (tensor) – messages constructed by the message function.

  • edge_index (tensor) – edges from src to dst.

  • num_nodes (int, optional) – number of nodes of the graph.

  • aggr (str, optional) – aggregation type, one of 'sum', 'mean' or 'max' (default: 'sum').

Returns:

the aggregated messages, with one row per node.

Return type:

tensor
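Below is a pure-Python sketch of the three documented reductions. Per-edge messages are grouped by destination index (edge_index[1]) and reduced with sum, mean or max. Lists stand in for tensors. Giving a node with no incoming edges a zero row is an assumption of this sketch, not documented behavior.

```python
def aggregate(msg, edge_index, num_nodes=None, aggr="sum"):
    """Hypothetical sketch: reduce per-edge messages into one row per node.

    msg: list of num_edges message vectors (lists of floats)
    edge_index: [src_list, dst_list]; messages are grouped by dst
    """
    if aggr not in ("sum", "mean", "max"):
        raise ValueError(f"unsupported aggr: {aggr!r}")
    dst = edge_index[1]
    if num_nodes is None:
        # Infer the node count from the largest destination index.
        num_nodes = max(dst) + 1
    dim = len(msg[0])
    # Collect the incoming messages of every destination node.
    buckets = [[] for _ in range(num_nodes)]
    for m, i in zip(msg, dst):
        buckets[i].append(m)
    out = []
    for group in buckets:
        if not group:
            # Assumption: nodes with no incoming edges get a zero row.
            out.append([0.0] * dim)
        elif aggr == "sum":
            out.append([sum(col) for col in zip(*group)])
        elif aggr == "mean":
            out.append([sum(col) / len(group) for col in zip(*group)])
        else:  # aggr == "max"
            out.append([max(col) for col in zip(*group)])
    return out


msg = [[0.5, 0.0], [0.0, 2.0], [6.0, 2.0]]
edge_index = [[0, 1, 2], [1, 2, 1]]  # node 1 receives two messages
print(aggregate(msg, edge_index, num_nodes=3, aggr="mean"))
# [[0.0, 0.0], [3.25, 1.0], [0.0, 2.0]]
```

Grouping by destination is what makes the reduction permutation invariant. The order in which edges arrive at a node does not change the sum, mean or max.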

message_aggregate(x, edge_index, edge_weight=None, aggr='sum')

Tries to fuse message and aggregate into a single step, avoiding the cost of materializing the per-edge message tensor.
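The idea behind fusion can be sketched in pure Python for the sum case. Each weighted source feature is added directly into its destination row, so the intermediate [num_edges, message_dim] list of messages is never built. This is an illustration of the general technique using lists, not GammaGL's fused kernel. Supporting only 'sum' and taking num_nodes as len(x) are both assumptions of the sketch.

```python
def message_aggregate(x, edge_index, edge_weight=None, aggr="sum"):
    """Hypothetical fused sketch (sum only).

    Accumulates each (optionally weighted) source feature straight into its
    destination row, so no per-edge message list is ever created.
    """
    if aggr != "sum":
        raise NotImplementedError("this sketch only fuses sum aggregation")
    src, dst = edge_index
    dim = len(x[0])
    out = [[0.0] * dim for _ in range(len(x))]
    for e, (j, i) in enumerate(zip(src, dst)):
        w = 1.0 if edge_weight is None else edge_weight[e]
        row = out[i]
        for k in range(dim):
            row[k] += w * x[j][k]
    return out


x = [[1.0, 0.0], [0.0, 2.0], [3.0, 1.0]]
edge_index = [[0, 1, 2], [1, 2, 1]]
print(message_aggregate(x, edge_index, edge_weight=[0.5, 1.0, 2.0]))
# [[0.0, 0.0], [6.5, 2.0], [0.0, 2.0]]
```

The result matches what an unfused message-then-sum pipeline would produce for the same inputs. The saving is only in the intermediate storage, which grows with the number of edges.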

update(x)

Defines how node embeddings are updated from the aggregated messages.

Parameters:

x (tensor) – the aggregated messages produced by aggregate.
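Here is a minimal sketch of two update functions over list-of-lists node features. The first is a pass-through. It assumes, without confirmation from this page, that the base update simply returns its input. The second is a hypothetical ReLU override of the kind a subclass might define.

```python
def identity_update(x):
    """Sketch of a pass-through update: return the aggregated rows as-is."""
    return x


def relu_update(x):
    """Hypothetical override: apply ReLU element-wise to every node row."""
    return [[max(0.0, v) for v in row] for row in x]


aggregated = [[-1.0, 2.0], [0.5, -3.0]]
print(relu_update(aggregated))  # [[0.0, 2.0], [0.5, 0.0]]
```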

propagate(x, edge_index, aggr='sum', fuse_kernel=False, **kwargs)

Performs one round of message passing: constructs messages, aggregates them at destination nodes, and updates the node embeddings.

Parameters:
  • x (tensor) – input node features.

  • edge_index (tensor) – edges from src to dst.

  • aggr (str, optional) – aggregation type, one of 'sum', 'mean' or 'max' (default: 'sum').

  • fuse_kernel (bool, optional) – whether to use a fused kernel to speed up computation (default: False).

  • kwargs (optional) – additional keyword arguments, such as edge_weight.
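To show how the pieces fit together, here is a hypothetical pure-Python stand-in that mirrors the method names and signatures documented above. In this sketch, propagate runs message, then aggregate, then update. That ordering is inferred from the formula and the method names on this page. The class is not GammaGL's implementation. Lists replace tensors, only 'sum' aggregation is supported, and fuse_kernel is accepted but ignored. ReluLayer is a hypothetical subclass that overrides only update.

```python
class ToyMessagePassing:
    """Hypothetical pure-Python stand-in mirroring the documented interface.

    Lists replace tensors, and edge_index is [src_list, dst_list].
    """

    def message(self, x, edge_index, edge_weight=None):
        # One message per edge: the (optionally weighted) source feature.
        msg = [list(x[j]) for j in edge_index[0]]
        if edge_weight is not None:
            msg = [[w * v for v in m] for m, w in zip(msg, edge_weight)]
        return msg

    def aggregate(self, msg, edge_index, num_nodes=None, aggr="sum"):
        # Sum-only reduction by destination node, to keep the sketch short.
        if aggr != "sum":
            raise NotImplementedError("this sketch supports aggr='sum' only")
        dst = edge_index[1]
        if num_nodes is None:
            num_nodes = max(dst) + 1
        dim = len(msg[0]) if msg else 0
        out = [[0.0] * dim for _ in range(num_nodes)]
        for m, i in zip(msg, dst):
            out[i] = [a + b for a, b in zip(out[i], m)]
        return out

    def update(self, x):
        # Pass-through update; subclasses may override it.
        return x

    def propagate(self, x, edge_index, aggr="sum", fuse_kernel=False, **kwargs):
        # fuse_kernel is accepted for signature parity but ignored here.
        msg = self.message(x, edge_index, kwargs.get("edge_weight"))
        out = self.aggregate(msg, edge_index, num_nodes=len(x), aggr=aggr)
        return self.update(out)


class ReluLayer(ToyMessagePassing):
    # Hypothetical layer that overrides only update().
    def update(self, x):
        return [[max(0.0, v) for v in row] for row in x]


x = [[-1.0, 1.0], [2.0, -2.0]]
edge_index = [[0, 1], [1, 0]]  # edges 0->1 and 1->0
print(ReluLayer().propagate(x, edge_index))  # [[2.0, 0.0], [0.0, 1.0]]
```

A new layer only needs to override the hook that differs from the base behavior. Here that hook is update, while message and aggregate are inherited unchanged.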