gammagl.layers.conv.MessagePassing¶
- class MessagePassing[source]¶
Base class for creating message passing layers of the form
\[\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right),\]where \(\square\) denotes a differentiable, permutation invariant function, e.g., sum, mean or max, and \(\gamma_{\mathbf{\Theta}}\) and \(\phi_{\mathbf{\Theta}}\) denote differentiable functions such as MLPs.
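The scheme above can be illustrated with a minimal standalone NumPy sketch (not the gammagl implementation): here \(\phi\) returns the source feature \(\mathbf{x}_j\), optionally scaled by an edge weight, \(\square\) is a sum over incoming edges, and \(\gamma\) is the identity.

```python
import numpy as np

# Minimal NumPy illustration of the message passing scheme above.
# phi: per-edge message (source feature, optionally weighted);
# square: sum over incoming edges; gamma: identity.
# Hypothetical standalone sketch -- not the gammagl implementation.
def propagate(x, edge_index, edge_weight=None):
    src, dst = edge_index                  # edges point from src (j) to dst (i)
    msg = x[src]                           # phi(x_i, x_j, e_ji) = x_j here
    if edge_weight is not None:
        msg = msg * edge_weight[:, None]   # scale each message by its edge weight
    out = np.zeros_like(x)
    np.add.at(out, dst, msg)               # square = sum over incoming edges
    return out                             # gamma = identity

x = np.eye(3)                              # 3 nodes with one-hot features
edge_index = np.array([[0, 1, 2],          # source nodes
                       [1, 2, 0]])         # destination nodes
out = propagate(x, edge_index)             # each node receives its in-neighbor's feature
```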
- special_args = {'edge_index', 'edge_weight', 'x'}¶
- message(x, edge_index, edge_weight=None)[source]¶
Function that constructs messages from source nodes to destination nodes.
- Parameters:
x (tensor) – input node feature.
edge_index (tensor) – edges from src to dst.
edge_weight (tensor, optional) – weight of each edge.
- Returns:
the message matrix, of shape [num_edges, message_dim]
- Return type:
tensor
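A plain-NumPy sketch of what a default message function computes, assuming the message on each edge (j → i) is the source feature x_j, optionally scaled by the edge weight (a hypothetical sketch, not the gammagl source):

```python
import numpy as np

# Hypothetical sketch of a default message function: gather the source-node
# feature along each edge, optionally scaled by a per-edge weight.
def message(x, edge_index, edge_weight=None):
    src = edge_index[0]                    # source node of every edge
    msg = x[src]                           # shape: [num_edges, feature_dim]
    if edge_weight is not None:
        msg = msg * edge_weight[:, None]   # broadcast the weight over features
    return msg

x = np.array([[1.0, 2.0], [3.0, 4.0]])
edge_index = np.array([[0, 1], [1, 0]])    # edges 0 -> 1 and 1 -> 0
w = np.array([0.5, 2.0])
msg = message(x, edge_index, w)            # shape [num_edges, message_dim] = [2, 2]
```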
- aggregate(msg, edge_index, num_nodes=None, aggr='sum')[source]¶
Function that aggregates messages from edges to destination nodes.
- Parameters:
msg (tensor) – messages constructed on edges, of shape [num_edges, message_dim].
edge_index (tensor) – edges from src to dst.
num_nodes (int, optional) – number of destination nodes.
aggr (str, optional) – aggregation method ('sum', 'mean' or 'max').
- Returns:
aggregation outcome.
- Return type:
tensor
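Aggregation is a scatter operation: each edge's message is accumulated into the row of its destination node. A minimal NumPy sketch mirroring the signature above ('sum' and 'mean' only; hypothetical, not the gammagl source):

```python
import numpy as np

# Hypothetical sketch of edge-to-node aggregation, mirroring
# aggregate(msg, edge_index, num_nodes=None, aggr='sum').
def aggregate(msg, edge_index, num_nodes=None, aggr='sum'):
    dst = edge_index[1]                    # destination node of every edge
    if num_nodes is None:
        num_nodes = int(dst.max()) + 1
    out = np.zeros((num_nodes, msg.shape[1]))
    np.add.at(out, dst, msg)               # scatter-sum messages per node
    if aggr == 'mean':                     # 'max' would use np.maximum.at instead
        deg = np.bincount(dst, minlength=num_nodes).clip(min=1)
        out = out / deg[:, None]           # divide by in-degree (avoid /0)
    return out
```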
- message_aggregate(x, edge_index, edge_weight=None, aggr='sum')[source]¶
Attempts to fuse the message and aggregate steps into a single operation, avoiding the cost of materializing the per-edge message matrix.
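For sum aggregation with weighted source features, the fused step can be pictured as a (sparse) adjacency matmul: no [num_edges, message_dim] matrix is ever formed. A dense NumPy sketch of the idea (hypothetical; a real implementation would use a sparse kernel and handle duplicate edges):

```python
import numpy as np

# Hypothetical sketch: fusing message + aggregate (aggr='sum') amounts to
# computing A @ x, where A[i, j] holds the weight of edge j -> i.
# Assumes no duplicate edges (assignment below would overwrite, not add).
def message_aggregate(x, edge_index, edge_weight=None):
    n = x.shape[0]
    src, dst = edge_index
    w = edge_weight if edge_weight is not None else np.ones(src.shape[0])
    A = np.zeros((n, n))
    A[dst, src] = w                        # weighted adjacency, dst rows
    return A @ x                           # one fused gather + scatter-sum
```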
- update(x)[source]¶
Function that defines how to update node embeddings from the aggregated messages.
- Parameters:
x (tensor) – aggregated message