Source code for gammagl.models.hpn

# !/usr/bin/env python3
# -*- coding:utf-8 -*-

import tensorlayerx as tlx
from gammagl.layers.conv import HPNConv

class HPN(tlx.nn.Module):
    r"""Heterogeneous Graph Propagation Network (HPN) model.

    Implements the model from the `"Heterogeneous Graph Propagation Network"
    <https://ieeexplore.ieee.org/abstract/document/9428609>`_ paper as an
    :class:`HPNConv` layer followed by a single linear projection that is
    shared across node types.

    Parameters
    ----------
    in_channels: int
        input dimension of the feature.
    out_channels: int
        output dimension of the feature.
    metadata: Tuple[List[str], List[Tuple[str, str, str]]]
        the metadata of the heterogeneous graph.
    drop_rate: float
        dropout probability.
    iter_K: int
        number of iteration used in APPNPConv.
    hidden_channels: int, optional
        hidden dimension of the feature.
    alpha: float, optional
        parameters used in APPNPConv.
    name: str, optional
        model name.

    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 metadata,
                 drop_rate,
                 iter_K,
                 hidden_channels=128,
                 alpha=0.2,
                 name=None):
        super().__init__(name=name)
        # Heterogeneous propagation: in_channels -> hidden_channels.
        self.hpn_conv = HPNConv(
            in_channels,
            hidden_channels,
            metadata,
            iter_K=iter_K,
            alpha=alpha,
            drop_rate=drop_rate,
        )
        # Output projection: hidden_channels -> out_channels, reused for
        # every node type in ``forward``.
        self.lin = tlx.nn.Linear(
            in_features=hidden_channels,
            out_features=out_channels,
        )

    def forward(self, x_dict, edge_index_dict, num_nodes_dict):
        r"""Run the HPN convolution, then project each node type's output.

        Parameters
        ----------
        x_dict: dict
            per-node-type input features, passed unchanged to ``HPNConv``.
        edge_index_dict: dict
            per-edge-type edge indices, passed unchanged to ``HPNConv``.
        num_nodes_dict: dict
            per-node-type node counts, passed unchanged to ``HPNConv``;
            its keys also select which node types appear in the result.

        Returns
        -------
        dict
            maps every key of ``num_nodes_dict`` to the linear layer applied
            to the convolution output of that node type.

        """
        hidden = self.hpn_conv(x_dict, edge_index_dict, num_nodes_dict)
        # Only the keys of ``num_nodes_dict`` are needed here; the
        # convolution output is assumed to contain each of them
        # (NOTE(review): confirm against HPNConv).
        return {
            node_type: self.lin(hidden[node_type])
            for node_type in num_nodes_dict
        }