# Source code for gammagl.layers.conv.hetero_wrapper

import tensorlayerx as tlx
import gammagl.mpops as mpops
import warnings
from collections import defaultdict
from typing import Dict, Optional

def group(xs, aggr):
    """Combine the outputs of all relations that target one node type.

    Parameters
    ----------
    xs: list[tensor]
        Outputs produced by the relations sharing the same destination
        node type.
    aggr: str or None
        ``None`` stacks the outputs along a new axis 1. Otherwise the name
        is resolved to ``tlx.reduce_<aggr>`` (e.g. ``"sum"``, ``"mean"``,
        ``"min"``, ``"max"``), which is applied over the stacked outputs.

    Returns
    -------
    tensor or None
        ``None`` when ``xs`` is empty. The single element, unchanged, when
        ``len(xs) == 1`` and ``aggr`` is not ``None``. Otherwise the stacked
        (``aggr is None``) or reduced tensor.
    """
    if len(xs) == 0:
        return None
    elif aggr is None:
        return tlx.stack(xs, axis=1)
    elif len(xs) == 1:
        return xs[0]
    else:
        out = tlx.stack(xs, axis=0)
        # TensorLayerX reductions take ``axis`` (PyTorch-style ``dim`` is not
        # a parameter of ``tlx.reduce_*``); reduce over the stacking axis.
        out = getattr(tlx, 'reduce_' + aggr)(out, axis=0)
        # Defensive: unwrap a (values, indices)-style tuple should a backend
        # reduction ever return one (as PyTorch's max/min do with ``dim``).
        out = out[0] if isinstance(out, tuple) else out
        return out

class HeteroConv(tlx.nn.Module):
    r"""A generic wrapper for computing graph convolution on heterogeneous
    graphs.

    This layer will pass messages from source nodes to target nodes based on
    the bipartite GNN layer given for a specific edge type.
    If multiple relations point to the same destination, their results will
    be aggregated according to :attr:`aggr`.

    .. code:: python

        >>> hetero_conv = HeteroConv({
        ...     ('paper', 'cites', 'paper'): GCNConv(64, 16),
        ...     ('author', 'writes', 'paper'): SAGEConv((128, 64), 64),
        ...     ('paper', 'written_by', 'author'): GATConv((64, 128), 64),
        ... }, aggr='sum')
        >>> out_dict = hetero_conv(x_dict, edge_index_dict)
        >>> print(list(out_dict.keys()))
        ['paper', 'author']

    Parameters
    ----------
    convs: dict[tuple[str, str, str], nn.module]
        A dictionary holding a bipartite
        :class:`~gammagl.layers.conv.MessagePassing` layer for each
        individual edge type.
    aggr: str, optional
        The aggregation scheme to use for grouping node embeddings generated
        by different relations.
        (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
        :obj:`None`). (default: :obj:`"sum"`)
    """

    def __init__(self, convs: dict, aggr: Optional[str] = "sum"):
        super().__init__()

        src_node_types = {key[0] for key in convs.keys()}
        dst_node_types = {key[-1] for key in convs.keys()}
        # Node types that only ever appear as a source are never written to
        # ``out_dict`` in ``forward``, so their features are not updated.
        not_updated = src_node_types - dst_node_types
        if len(not_updated) > 0:
            warnings.warn(
                f"There exist node types ({not_updated}) "
                f"whose representations do not get updated during message "
                f"passing as they do not occur as destination type in any "
                f"edge type. This may lead to unexpected behaviour.")

        # Edge-type tuples are joined into one string key
        # (e.g. 'paper__cites__paper'); ``forward`` rebuilds the same key to
        # look the layer up.
        self.convs = tlx.nn.ModuleDict(
            {'__'.join(k): v for k, v in convs.items()})
        self.aggr = aggr

    def reset_parameters(self):
        """Call ``reset_parameters()`` on every wrapped convolution layer."""
        for conv in self.convs.values():
            conv.reset_parameters()

    @staticmethod
    def _select(value_dict, edge_type, src, dst):
        """Pick the entry of ``value_dict`` that applies to ``edge_type``.

        Lookup order: the full edge-type tuple; then the shared node type
        when ``src == dst``; then a ``(src_value, dst_value)`` pair (the
        missing side filled with ``None``) when either endpoint type is a
        key.

        Returns
        -------
        tuple[bool, object]
            ``(True, value)`` if an entry applies, else ``(False, None)``.
        """
        if edge_type in value_dict:
            return True, value_dict[edge_type]
        if src == dst and src in value_dict:
            return True, value_dict[src]
        if src in value_dict or dst in value_dict:
            return True, (value_dict.get(src, None),
                          value_dict.get(dst, None))
        return False, None

    def forward(
        self,
        x_dict,
        edge_index_dict,
        *args_dict,
        **kwargs_dict,
    ):
        r"""
        Parameters
        ----------
        x_dict: dict[str, tensor]
            A dictionary holding node feature information for each individual
            node type.
        edge_index_dict: dict[tuple[str, str, str], tensor]
            A dictionary holding graph connectivity information for each
            individual edge type.
        *args_dict: optional
            Additional forward arguments of individual
            :class:`gammagl.layers.conv.MessagePassing` layers.
        **kwargs_dict: optional
            Additional forward arguments of individual
            :class:`gammagl.layers.conv.MessagePassing` layers.
            For example, if a specific GNN layer at edge type
            :obj:`edge_type` expects edge attributes :obj:`edge_attr` as a
            forward argument, then you can pass them to
            :meth:`~gammagl.layers.conv.HeteroConv.forward` via
            :obj:`edge_attr_dict = { edge_type: edge_attr }`.

        Returns
        -------
        dict[str, tensor]
            For each destination node type reached by at least one registered
            edge type, the relation outputs combined according to
            :attr:`aggr`.
        """
        out_dict = defaultdict(list)
        for edge_type, edge_index in edge_index_dict.items():
            src, rel, dst = edge_type

            str_edge_type = '__'.join(edge_type)
            if str_edge_type not in self.convs:
                # No layer was registered for this relation; skip it.
                continue

            args = []
            for value_dict in args_dict:
                found, value = self._select(value_dict, edge_type, src, dst)
                if found:
                    args.append(value)

            kwargs = {}
            for arg, value_dict in kwargs_dict.items():
                # Keyword names are expected as `{name}_dict`; the last five
                # characters are dropped to get the layer's argument name.
                arg = arg[:-5]
                found, value = self._select(value_dict, edge_type, src, dst)
                if found:
                    kwargs[arg] = value

            conv = self.convs[str_edge_type]

            if src == dst:
                out = conv(x_dict[src], edge_index, *args, **kwargs)
            else:
                out = conv((x_dict[src], x_dict[dst]), edge_index, *args,
                           **kwargs)

            out_dict[dst].append(out)

        # Return a plain dict so a missing node type raises KeyError instead
        # of silently producing an empty list.
        return {key: group(value, self.aggr)
                for key, value in out_dict.items()}

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_relations={len(self.convs)})'