gammagl.layers.conv.HGTConv

class HGTConv(in_channels, out_channels, metadata, heads: int = 1, group: str = 'sum', dropout_rate=0.0)[source]

The Heterogeneous Graph Transformer (HGT) operator from the “Heterogeneous Graph Transformer” paper.
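For orientation, the HGT paper defines the attention score of head i for an edge e from a source node s to a target node t as below. Here τ(·) is a node's type, φ(e) is the edge's relation, d is the per-head dimension, W^ATT is a relation-specific matrix, and μ is a learnable prior per (source type, relation, target type) triplet. The heads are concatenated and normalized with a softmax over the neighbors of t. This is transcribed from the paper's formulation, not from GammaGL's source, so the implementation may differ in details.

```latex
\mathrm{ATT\text{-}head}^{i}(s, e, t) =
  \left( K^{i}(s)\, W^{\mathrm{ATT}}_{\phi(e)}\, Q^{i}(t)^{\top} \right)
  \cdot \frac{\mu_{\langle \tau(s), \phi(e), \tau(t) \rangle}}{\sqrt{d}},
\qquad
K^{i}(s) = \text{K-Linear}^{i}_{\tau(s)}\!\left(H^{(l-1)}[s]\right),
\quad
Q^{i}(t) = \text{Q-Linear}^{i}_{\tau(t)}\!\left(H^{(l-1)}[t]\right)
```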

Parameters:
  • in_channels (int or dict[str, int]) – Size of each input sample of every node type, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • metadata (tuple[list[str], list[tuple[str, str, str]]]) – The metadata of the heterogeneous graph, i.e. its node and edge types given by a list of strings and a list of string triplets, respectively. See gammagl.data.HeteroGraph.metadata for more information.

  • heads (int, optional) – Number of attention heads. (default: 1)

  • group (str, optional) – The aggregation scheme used to combine node embeddings generated by different relations: "sum", "mean", "min" or "max". (default: "sum")

  • dropout_rate (float, optional) – Dropout rate. (default: 0.0)

  • **kwargs (optional) – Additional arguments of gammagl.layers.conv.MessagePassing.
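The plain-Python sketch below illustrates the shape of metadata and what the group options mean: every relation that ends in a node type yields one embedding per destination node, and group reduces those embeddings elementwise. The node and edge type names and the helper group_by_dst are hypothetical and not part of GammaGL; the real layer operates on tensors, not lists.

```python
# Hypothetical metadata for a tiny author/paper graph: (node types, edge-type triplets).
metadata = (
    ["author", "paper"],
    [
        ("author", "writes", "paper"),
        ("paper", "written_by", "author"),
        ("paper", "cites", "paper"),
    ],
)

# Elementwise reducers matching the documented `group` options.
REDUCERS = {
    "sum": sum,
    "mean": lambda vals: sum(vals) / len(vals),
    "min": min,
    "max": max,
}


def group_by_dst(metadata, rel_outputs, group="sum"):
    """Hypothetical helper: combine per-relation embeddings of each destination node type.

    rel_outputs maps an edge type (src, rel, dst) to a list of embedding
    vectors, one per node of the dst type.
    """
    node_types, edge_types = metadata
    reduce = REDUCERS[group]
    out = {}
    for nt in node_types:
        # Collect the outputs of every relation whose destination is this node type.
        outs = [rel_outputs[et] for et in edge_types if et[2] == nt and et in rel_outputs]
        if not outs:
            continue
        num_nodes, dim = len(outs[0]), len(outs[0][0])
        out[nt] = [
            [reduce([o[n][k] for o in outs]) for k in range(dim)]
            for n in range(num_nodes)
        ]
    return out


rel_outputs = {
    ("author", "writes", "paper"): [[1.0, 2.0], [3.0, 4.0]],
    ("paper", "cites", "paper"): [[5.0, 0.0], [1.0, 6.0]],
    ("paper", "written_by", "author"): [[7.0, 8.0]],
}
print(group_by_dst(metadata, rel_outputs, "sum")["paper"])  # → [[6.0, 2.0], [4.0, 10.0]]
```

"author" is the destination of a single relation, so every group option returns its embedding unchanged.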

forward(x_dict, edge_index_dict)[source]

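Apply the HGT layer to a heterogeneous graph. x_dict maps each node type to its input node features, and edge_index_dict maps each edge type (a (source type, relation, destination type) triplet, as in metadata) to its edge indices.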

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
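The sketch below shows what the two dictionaries passed to forward might look like for a small graph, using plain lists in place of tensors. It assumes the usual 2 x num_edges edge-index layout, with source node ids in the first row and target ids in the second; the type names and the check_inputs helper are hypothetical illustrations, not GammaGL API.

```python
# Hypothetical inputs: 3 authors and 2 papers, each with 4 input features.
# x_dict[node_type] plays the role of a [num_nodes, in_channels] tensor.
x_dict = {
    "author": [[0.1] * 4 for _ in range(3)],
    "paper": [[0.2] * 4 for _ in range(2)],
}

# edge_index_dict[edge_type]: row 0 holds source ids, row 1 holds target ids.
edge_index_dict = {
    ("author", "writes", "paper"): [[0, 1, 2], [0, 0, 1]],
    ("paper", "written_by", "author"): [[0, 0, 1], [0, 1, 2]],
}


def check_inputs(x_dict, edge_index_dict):
    """Hypothetical helper: raise ValueError if any edge index is out of range."""
    for (src, rel, dst), (row, col) in edge_index_dict.items():
        if len(row) != len(col):
            raise ValueError(f"{rel}: source and target rows differ in length")
        if any(not 0 <= i < len(x_dict[src]) for i in row):
            raise ValueError(f"{rel}: source index out of range for {src!r}")
        if any(not 0 <= j < len(x_dict[dst]) for j in col):
            raise ValueError(f"{rel}: target index out of range for {dst!r}")
    return True


print(check_inputs(x_dict, edge_index_dict))  # → True
```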

propagate(edge_index, aggr='sum', **kwargs)[source]

Perform message passing.

Parameters:
  • x (tensor) – Input node features.

  • edge_index (tensor) – Edge indices, from source (src) to destination (dst) nodes.

  • aggr (str, optional) – Aggregation type: "sum", "mean" or "max". (default: "sum")

  • fuse_kernel (bool, optional) – Whether to use a fused kernel to speed up computation. (default: False)

  • kwargs (optional) – Additional keyword arguments.
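In the usual message-passing design, propagate builds one message per edge and then reduces the messages arriving at each destination node according to aggr. The stdlib sketch below shows only that reduction step for the three documented aggr values. propagate_sketch is a hypothetical name, not GammaGL's function, and filling nodes without incoming edges with zeros is a choice made for this sketch.

```python
AGGREGATORS = {
    "sum": sum,
    "mean": lambda vals: sum(vals) / len(vals),
    "max": max,
}


def propagate_sketch(messages, dst, num_nodes, aggr="sum"):
    """Hypothetical helper: reduce per-edge messages at their destination nodes.

    messages: one feature vector (list of floats) per edge.
    dst: destination node id of each edge.
    """
    if aggr not in AGGREGATORS:
        raise ValueError(f"unsupported aggr: {aggr!r}")
    reduce = AGGREGATORS[aggr]
    dim = len(messages[0]) if messages else 0
    # Bucket the messages by destination node.
    buckets = [[] for _ in range(num_nodes)]
    for msg, t in zip(messages, dst):
        buckets[t].append(msg)
    out = []
    for incoming in buckets:
        if not incoming:
            out.append([0.0] * dim)  # no incoming edges: zero vector (sketch choice)
        else:
            # zip(*incoming) yields one tuple per feature dimension.
            out.append([reduce(col) for col in zip(*incoming)])
    return out


msgs = [[1.0, 2.0], [3.0, -1.0], [5.0, 5.0]]
print(propagate_sketch(msgs, [0, 0, 2], 3, "mean"))  # → [[2.0, 0.5], [0.0, 0.0], [5.0, 5.0]]
```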

message(k_j, q_i, v_j, rel, target_index, num_nodes)[source]

Construct messages from source nodes to destination nodes.

Parameters:
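  • k_j (tensor) – Key vectors of the source node of each edge.

  • q_i (tensor) – Query vectors of the target node of each edge.

  • v_j (tensor) – Value vectors of the source node of each edge.

  • rel (tensor) – Relation-specific prior that scales the attention scores.

  • target_index (tensor) – Target node index of each edge, used to normalize the attention over each node's incoming edges.

  • num_nodes (int) – Number of target nodes.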

Returns:

  • tensor – The message matrix, of shape [num_edges, message_dim].
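The following single-head sketch illustrates how HGT-style messages can be formed from these inputs: each edge gets the score ⟨q_i, k_j⟩ · rel / √d, the scores are softmax-normalized over the edges that share a target node, and each value vector is weighted by its attention coefficient. It assumes k_j has already been multiplied by the relation-specific attention matrix and that rel is a scalar. hgt_message_sketch is a hypothetical stdlib helper; GammaGL's method works on multi-head tensors and may differ in detail.

```python
import math


def hgt_message_sketch(k_j, q_i, v_j, rel, target_index, num_nodes):
    """Hypothetical single-head HGT-style message function.

    k_j, v_j: key/value vector of the source node of each edge
              (k_j assumed already transformed by the relation matrix).
    q_i: query vector of the target node of each edge.
    rel: scalar relation prior scaling the attention scores.
    target_index: target node id of each edge.
    """
    d = len(q_i[0])
    # Raw score per edge: <q_i, k_j> * rel / sqrt(d).
    scores = [
        sum(q * k for q, k in zip(qe, ke)) * rel / math.sqrt(d)
        for qe, ke in zip(q_i, k_j)
    ]
    # Softmax over each target node's incoming edges (max subtracted for stability).
    max_per_node = [-math.inf] * num_nodes
    for s, t in zip(scores, target_index):
        max_per_node[t] = max(max_per_node[t], s)
    exps = [math.exp(s - max_per_node[t]) for s, t in zip(scores, target_index)]
    denom = [0.0] * num_nodes
    for e, t in zip(exps, target_index):
        denom[t] += e
    alpha = [e / denom[t] for e, t in zip(exps, target_index)]
    # One message row per edge: attention coefficient times value vector.
    return [[a * v for v in ve] for a, ve in zip(alpha, v_j)]
```

For example, two edges with equal scores into the same target node each receive weight 0.5, and an edge that is a node's only incoming edge receives weight 1.0.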