gammagl.layers.conv.HGTConv
- class HGTConv(in_channels, out_channels, metadata, heads: int = 1, group: str = 'sum', dropout_rate=0.0)
The Heterogeneous Graph Transformer (HGT) operator from the “Heterogeneous Graph Transformer” paper.
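For reference, the HGT paper defines one layer roughly as follows. Here τ(·) is the type of a node, φ(e) is the type of edge e, h is the number of heads, d is the per-head dimension, and μ is a learnable prior for each ⟨source type, edge type, target type⟩ triplet. This is the paper's formulation; how HGTConv organizes these projections internally is not documented on this page and may differ.

```latex
\begin{aligned}
K^i(s) &= \text{K-Linear}^i_{\tau(s)}\big(H^{(l-1)}[s]\big), \qquad
Q^i(t) = \text{Q-Linear}^i_{\tau(t)}\big(H^{(l-1)}[t]\big) \\
\text{ATT-head}^i(s, e, t) &= \Big(K^i(s)\, W^{ATT}_{\phi(e)}\, Q^i(t)^{\top}\Big) \cdot \frac{\mu_{\langle \tau(s), \phi(e), \tau(t) \rangle}}{\sqrt{d}} \\
\text{Attention}_{HGT}(s, e, t) &= \operatorname{Softmax}_{\forall s \in N(t)}\Big(\big\Vert_{i \in [1, h]} \text{ATT-head}^i(s, e, t)\Big) \\
\text{MSG-head}^i(s, e, t) &= \text{M-Linear}^i_{\tau(s)}\big(H^{(l-1)}[s]\big)\, W^{MSG}_{\phi(e)} \\
\text{Message}_{HGT}(s, e, t) &= \big\Vert_{i \in [1, h]} \text{MSG-head}^i(s, e, t) \\
\tilde{H}^{(l)}[t] &= \bigoplus_{\forall s \in N(t)} \Big(\text{Attention}_{HGT}(s, e, t) \cdot \text{Message}_{HGT}(s, e, t)\Big) \\
H^{(l)}[t] &= \text{A-Linear}_{\tau(t)}\big(\sigma(\tilde{H}^{(l)}[t])\big) + H^{(l-1)}[t]
\end{aligned}
```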
- Parameters:
in_channels (int, dict[str, int]) – Size of each input sample of every node type, or -1 to derive the size from the first input(s) to the forward method.
out_channels (int) – Size of each output sample.
metadata (tuple[list[str], list[tuple[str, str, str]]]) – The metadata of the heterogeneous graph, i.e. its node and edge types given by a list of strings and a list of string triplets, respectively. See gammagl.data.HeteroGraph.metadata for more information.
heads (int, optional) – Number of attention heads. (default: 1)
group (str, optional) – The aggregation scheme used to combine node embeddings generated by different relations ("sum", "mean", "min", "max"). (default: "sum")
dropout_rate (float, optional) – Dropout rate. (default: 0.0)
**kwargs (optional) – Additional arguments of gammagl.layers.conv.MessagePassing.
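To make the metadata and group parameters concrete, here is a small sketch. The node types, edge types and channel sizes are made up for illustration, and group_embeddings is a hypothetical NumPy helper that only mimics what the four group options mean (reducing per-relation embeddings of one node type); it is not GammaGL code.

```python
import numpy as np

# Hypothetical metadata for a tiny author/paper graph: a list of node types
# and a list of (src_type, relation, dst_type) edge-type triplets.
metadata = (
    ["author", "paper"],
    [("author", "writes", "paper"), ("paper", "written_by", "author")],
)
node_types, edge_types = metadata

# in_channels may be one int for all node types, or a dict per node type
# (the sizes here are illustrative only).
in_channels = {"author": 16, "paper": 32}
assert set(in_channels) == set(node_types)


def group_embeddings(xs, aggr="sum"):
    """Combine embeddings of one node type produced by different relations.

    xs: list of arrays of shape [num_nodes, channels], one per relation.
    """
    if not xs:
        return None
    stacked = np.stack(xs, axis=0)  # [num_relations, num_nodes, channels]
    if aggr == "sum":
        return stacked.sum(axis=0)
    if aggr == "mean":
        return stacked.mean(axis=0)
    if aggr == "min":
        return stacked.min(axis=0)
    if aggr == "max":
        return stacked.max(axis=0)
    raise ValueError(f"unsupported aggregation: {aggr}")


# Two relations produce embeddings for the same two destination nodes.
out_a = np.array([[1.0, 4.0], [2.0, 0.0]])
out_b = np.array([[3.0, 2.0], [0.0, 6.0]])
summed = group_embeddings([out_a, out_b], "sum")  # [[4., 6.], [2., 6.]]
maxed = group_embeddings([out_a, out_b], "max")   # [[3., 4.], [2., 6.]]
```

"sum" and "mean" keep contributions from every relation, while "min" and "max" keep only the extreme value per feature.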
- forward(x_dict, edge_index_dict)
Define the computation performed at every call: x_dict maps each node type to its input node features, and edge_index_dict maps each edge type to its edge index.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
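The note about calling the instance rather than forward can be illustrated with a minimal stand-in. The Module class below is hypothetical (not the TensorLayerX or GammaGL class); its register_forward_hook is only named after PyTorch's method of the same name, and it shows that hooks run through __call__ but are skipped by a direct forward() call.

```python
class Module:
    """Hypothetical minimal module: __call__ runs forward, then hooks."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, x):
        raise NotImplementedError

    def __call__(self, *args):
        out = self.forward(*args)
        # Hooks only run on this path, never on a direct .forward() call.
        for hook in self._forward_hooks:
            hook(self, args, out)
        return out


class Double(Module):
    def forward(self, x):
        return 2 * x


calls = []
layer = Double()
layer.register_forward_hook(lambda mod, args, out: calls.append(out))

layer(3)          # returns 6 and the hook records it
layer.forward(3)  # returns 6 but the hook is silently skipped
```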
- propagate(edge_index, aggr='sum', **kwargs)
Function that performs message passing.
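As a rough picture of what one message-passing step does, the NumPy function below gathers source-node features along each edge and scatters them into destination nodes with the chosen aggregation. propagate_sketch is a hypothetical standalone helper with an assumed edge_index layout (row 0 sources, row 1 destinations); it is not GammaGL's MessagePassing.propagate, which also calls message and supports extra keyword arguments.

```python
import numpy as np


def propagate_sketch(x, edge_index, aggr="sum"):
    """Toy message passing (hypothetical helper, not GammaGL's API).

    x: [num_nodes, channels] node features.
    edge_index: [2, num_edges]; row 0 = source nodes, row 1 = destinations.
    """
    x = np.asarray(x, dtype=float)
    src, dst = edge_index
    num_nodes, channels = x.shape
    msg = x[src]  # message step: each edge carries its source node's features
    if aggr == "sum":
        out = np.zeros((num_nodes, channels))
        np.add.at(out, dst, msg)  # scatter-add messages into destinations
    elif aggr == "mean":
        out = np.zeros((num_nodes, channels))
        np.add.at(out, dst, msg)
        deg = np.bincount(dst, minlength=num_nodes).reshape(-1, 1)
        out = out / np.maximum(deg, 1)  # isolated nodes stay at zero
    elif aggr == "max":
        out = np.full((num_nodes, channels), -np.inf)
        np.maximum.at(out, dst, msg)
        out[np.isinf(out)] = 0.0  # nodes without incoming edges
    else:
        raise ValueError(f"unsupported aggregation: {aggr}")
    return out


x = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
edge_index = np.array([[0, 1, 2], [2, 2, 0]])  # edges 0->2, 1->2, 2->0
out = propagate_sketch(x, edge_index, "sum")  # [[2., 2.], [0., 0.], [1., 1.]]
```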
- message(k_j, q_i, v_j, rel, target_index, num_nodes)
Function that constructs messages from source nodes to destination nodes.
- Parameters:
x (tensor) – input node feature.
edge_index (tensor) – edges from src to dst.
edge_weight (tensor, optional) – weight of each edge.
- Returns:
tensor – the output message matrix, with shape [num_edges, message_dim].
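The message signature suggests the HGT attention step. The sketch below is a NumPy approximation under assumptions this page does not confirm: k_j, q_i and v_j are per-edge key, query and value tensors of shape [num_edges, heads, head_dim]; rel is a per-head relation prior (μ in the paper); and target_index with num_nodes drive a softmax over each destination node's incoming edges. hgt_message and segment_softmax are hypothetical names.

```python
import numpy as np


def segment_softmax(scores, index, num_nodes):
    """Softmax of edge scores over edges that share a destination node.

    scores: [num_edges, heads]; index: [num_edges] destination node ids.
    """
    seg_max = np.full((num_nodes, scores.shape[1]), -np.inf)
    np.maximum.at(seg_max, index, scores)
    exp = np.exp(scores - seg_max[index])  # subtract per-node max for stability
    seg_sum = np.zeros((num_nodes, scores.shape[1]))
    np.add.at(seg_sum, index, exp)
    return exp / seg_sum[index]


def hgt_message(k_j, q_i, v_j, rel, target_index, num_nodes):
    """HGT-style attention message (hypothetical sketch).

    k_j, q_i, v_j: [num_edges, heads, head_dim]; rel: [heads] prior.
    Returns messages of shape [num_edges, heads, head_dim].
    """
    head_dim = q_i.shape[-1]
    # Scaled dot-product score per edge and head, weighted by the prior.
    alpha = (q_i * k_j).sum(axis=-1) * rel / np.sqrt(head_dim)
    alpha = segment_softmax(alpha, target_index, num_nodes)
    # Weight each value vector by its normalized attention coefficient.
    return v_j * alpha[..., None]


# Three edges, one head, head_dim 2; edges 0 and 1 both point at node 0.
ones = np.ones((3, 1, 2))
msg = hgt_message(ones, ones, ones, np.ones(1), np.array([0, 0, 1]), 2)
```

With identical keys, the two edges into node 0 split the attention equally, while the single edge into node 1 receives all of it.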