Source code for gammagl.layers.conv.pna_conv

import tensorlayerx as tlx
import tensorlayerx.nn
from tensorlayerx import reduce_sum
from tensorlayerx.nn import ReLU, Linear, ModuleList, Sequential
from gammagl.layers.pool.glob import global_sum_pool, global_mean_pool, global_max_pool, global_min_pool
from gammagl.layers.conv.message_passing import MessagePassing
from gammagl.utils.degree import degree


class PNAConv(MessagePassing):
    r"""The Principal Neighbourhood Aggregation graph convolution operator
    from the `"Principal Neighbourhood Aggregation for Graph Nets"
    <https://arxiv.org/abs/2004.05718>`_ paper

    .. math::
        \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left(
        \mathbf{x}_i, \underset{j \in \mathcal{N}(i)}{\bigoplus}
        h_{\mathbf{\Theta}} \left( \mathbf{x}_i, \mathbf{x}_j \right)
        \right)

    with

    .. math::
        \bigoplus = \underbrace{\begin{bmatrix}
            1 \\
            S(\mathbf{D}, \alpha=1) \\
            S(\mathbf{D}, \alpha=-1)
        \end{bmatrix} }_{\text{scalers}}
        \otimes \underbrace{\begin{bmatrix}
            \mu \\
            \sigma \\
            \max \\
            \min
        \end{bmatrix}}_{\text{aggregators}},

    where :math:`\gamma_{\mathbf{\Theta}}` and :math:`h_{\mathbf{\Theta}}`
    denote MLPs.

    Parameters
    ----------
    in_channels: int
        Size of each input sample.
    out_channels: int
        Size of each output sample. Must be divisible by :obj:`towers`.
    aggregators: list[str]
        Set of aggregation function identifiers, namely :obj:`"sum"`,
        :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, :obj:`"var"` and
        :obj:`"std"`.
    scalers: list[str]
        Set of scaling function identifiers, namely :obj:`"identity"`,
        :obj:`"amplification"`, :obj:`"attenuation"`, :obj:`"linear"` and
        :obj:`"inverse_linear"`.
    deg: tensor
        Histogram of in-degrees of nodes in the training set, used by
        scalers to normalize.
    edge_dim: int, optional
        Edge feature dimensionality (in case there are any).
        (default :obj:`None`)
    towers: int, optional
        Number of towers (default: :obj:`1`).
    pre_layers: int, optional
        Number of transformation layers before aggregation
        (default: :obj:`1`).
    post_layers: int, optional
        Number of transformation layers after aggregation
        (default: :obj:`1`).
    divide_input: bool, optional
        Whether the input features should be split between towers or not
        (default: :obj:`False`).

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})`,
          edge indices :math:`(2, |\mathcal{E}|)`,
          edge features :math:`(|\mathcal{E}|, D)` *(optional)*
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})`
    """

    def __init__(self, in_channels, out_channels, aggregators, scalers, deg,
                 edge_dim=None, towers=1, pre_layers=1, post_layers=1,
                 divide_input=False):
        super(PNAConv, self).__init__()

        if divide_input:
            assert in_channels % towers == 0
        assert out_channels % towers == 0

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aggregators = aggregators
        self.scalers = scalers
        self.edge_dim = edge_dim
        self.towers = towers
        self.divide_input = divide_input

        # Per-tower feature widths.
        self.F_in = in_channels // towers if divide_input else in_channels
        self.F_out = self.out_channels // towers

        # Mul operation in TensorFlow requires both tensors to have the same
        # type, so the histogram is converted to float32 up front.
        if tlx.BACKEND == 'tensorflow':
            deg = tlx.constant(tlx.convert_to_numpy(deg), dtype=tlx.float32)

        # `deg[d]` counts the nodes with in-degree `d`; the weighted sums
        # below are therefore averages of d, log(d + 1) and exp(d) over nodes.
        num_nodes = int(reduce_sum(deg))
        bin_degrees = tlx.arange(start=0, limit=len(deg), dtype=tlx.float32)
        self.avg_deg = {
            'lin': float(reduce_sum(bin_degrees * deg)) / num_nodes,
            'log': float(reduce_sum(tlx.log(bin_degrees + 1) * deg)) / num_nodes,
            'exp': float(reduce_sum(tlx.exp(bin_degrees) * deg)) / num_nodes,
        }

        # Use the same `is not None` test for the encoder and for the pre-MLP
        # input width so the two can never disagree.
        has_edge_features = self.edge_dim is not None
        if has_edge_features:
            self.edge_encoder = Linear(in_features=edge_dim,
                                       out_features=self.F_in)

        # Message input is [x_i, x_j] (+ encoded edge features).
        pre_in_features = (3 if has_edge_features else 2) * self.F_in
        # Post-MLP input is [x, one slice per (aggregator, scaler) pair].
        post_in_features = (len(aggregators) * len(scalers) + 1) * self.F_in

        self.pre_nns = ModuleList()
        self.post_nns = ModuleList()
        for _tower in range(towers):
            pre_nn = Sequential([Linear(in_features=pre_in_features,
                                        out_features=self.F_in)])
            for _layer in range(pre_layers - 1):
                pre_nn.append(ReLU())
                pre_nn.append(Linear(in_features=self.F_in,
                                     out_features=self.F_in))
            self.pre_nns.append(pre_nn)

            post_nn = Sequential([Linear(in_features=post_in_features,
                                         out_features=self.F_out)])
            for _layer in range(post_layers - 1):
                post_nn.append(ReLU())
                post_nn.append(Linear(in_features=self.F_out,
                                      out_features=self.F_out))
            self.post_nns.append(post_nn)

        # Final mixing layer applied to the concatenated tower outputs.
        self.lin = Linear(in_features=out_channels, out_features=out_channels)

    def forward(self, x, edge_index, edge_attr=None):
        """Run one PNA convolution.

        Parameters
        ----------
        x: tensor
            Node features.
        edge_index: tensor
            Edge indices; row 0 is read as the source node and row 1 as the
            target node of every edge.
        edge_attr: tensor, optional
            Edge features, passed through ``self.edge_encoder`` in
            :meth:`message`. (default :obj:`None`)

        Returns
        -------
        tensor
            Output of ``self.lin`` applied to the concatenated tower outputs.
        """
        if self.divide_input:
            # Split the feature axis into `towers` chunks of width F_in.
            x = tlx.reshape(x, (-1, self.towers, self.F_in))
        else:
            # Every tower sees the full feature vector.
            x = tlx.stack([x] * self.towers, axis=1)

        out = self.propagate(x=x, edge_index=edge_index, edge_attr=edge_attr)
        out = tlx.concat([x, out], axis=-1)

        outs = [nn(out[:, i]) for i, nn in enumerate(self.post_nns)]
        out = tlx.concat(outs, axis=1)
        return self.lin(out)

    def message(self, x, edge_index, edge_attr=None):
        """Build the per-edge, per-tower messages.

        Gathers source (``edge_index[0]``) and target (``edge_index[1]``)
        features, optionally appends the encoded edge features, and feeds
        tower ``i`` of the result through ``self.pre_nns[i]``.

        Returns
        -------
        tensor
            The tower messages stacked along axis 1.
        """
        x_j = tlx.gather(x, edge_index[0, :])
        x_i = tlx.gather(x, edge_index[1, :])

        if edge_attr is not None:
            edge_attr = self.edge_encoder(edge_attr)
            # The encoded edge features are shared by all towers.
            edge_attr = tlx.stack([edge_attr] * self.towers, axis=1)
            h = tlx.concat([x_i, x_j, edge_attr], axis=-1)
        else:
            h = tlx.concat([x_i, x_j], axis=-1)

        hs = [nn(h[:, i]) for i, nn in enumerate(self.pre_nns)]
        return tlx.stack(hs, axis=1)

    def aggregate(self, inputs, index, num_nodes=None, aggr=None):
        """Aggregate messages with every configured aggregator and scaler.

        Parameters
        ----------
        inputs: tensor
            Messages produced by :meth:`message`.
        index: tensor
            Edge indices; row 1 (the target nodes) selects the node each
            message is reduced into.
        num_nodes: int, optional
            Accepted for interface compatibility; not used here.
        aggr: str, optional
            Accepted for interface compatibility; the aggregators given at
            construction time are used instead.

        Returns
        -------
        tensor
            For every scaler, the concatenated aggregator outputs multiplied
            by that scaler, all concatenated along the last axis.

        Raises
        ------
        ValueError
            If an aggregator or scaler identifier is unknown.
        """
        dst_index = index[1, :]

        outs = []
        for aggregator in self.aggregators:
            if aggregator == 'sum':
                out = global_sum_pool(inputs, dst_index)
            elif aggregator == 'mean':
                out = global_mean_pool(inputs, dst_index)
            elif aggregator == 'min':
                out = global_min_pool(inputs, dst_index)
            elif aggregator == 'max':
                out = global_max_pool(inputs, dst_index)
            elif aggregator == 'var' or aggregator == 'std':
                # Var[x] = E[x^2] - E[x]^2
                mean = global_mean_pool(inputs, dst_index)
                mean_squares = global_mean_pool(inputs * inputs, dst_index)
                out = mean_squares - mean * mean
                if aggregator == 'std':
                    # relu guards against tiny negative variances caused by
                    # rounding; the epsilon keeps sqrt away from zero.
                    out = tlx.sqrt(tlx.relu(out) + 1e-5)
            else:
                raise ValueError(f'Unknown aggregator "{aggregator}".')
            outs.append(out)
        aggregated = tlx.concat(outs, axis=-1)

        # No explicit size is passed, matching the pooling calls above which
        # also receive only the index.
        deg = degree(dst_index, None, dtype=inputs.dtype)
        # Replace degrees that are not greater than 1 by 1 so that log and
        # division below stay well defined, then reshape for broadcasting.
        deg = tlx.reshape(tlx.where(deg > 1, deg, tlx.ones_like(deg)),
                          (-1, 1, 1))

        # Every scaler is applied to the *unscaled* aggregation. Rebinding the
        # running tensor instead would compound the scalers (e.g. attenuation
        # after amplification would cancel out and duplicate "identity").
        outs = []
        for scaler in self.scalers:
            if scaler == 'identity':
                scaled = aggregated
            elif scaler == 'amplification':
                scaled = aggregated * (tlx.log(deg + 1) / self.avg_deg['log'])
            elif scaler == 'attenuation':
                scaled = aggregated * (self.avg_deg['log'] / tlx.log(deg + 1))
            elif scaler == 'linear':
                scaled = aggregated * (deg / self.avg_deg['lin'])
            elif scaler == 'inverse_linear':
                scaled = aggregated * (self.avg_deg['lin'] / deg)
            else:
                raise ValueError(f'Unknown scaler "{scaler}".')
            outs.append(scaled)
        return tlx.concat(outs, axis=-1)