# Source code for gammagl.layers.conv.mgnni_m_iter

import numpy as np
import tensorlayerx as tlx
import tensorlayerx.nn as nn

from gammagl.utils import degree
from gammagl.layers.conv import MessagePassing


class MGNNI_m_iter(MessagePassing):
    r"""The mgnni operator from the `"Multiscale Graph Neural Networks with
    Implicit Layers" <https://arxiv.org/abs/2210.08353>`_ paper

    .. math::
        Z^{(l+1)} = \gamma\, g(F)\, Z^{(l)} S^{k} + f(X;G)

    where :math:`\gamma` denotes the contraction factor, :math:`S^{k}` is the
    :math:`k`-th power of the degree-normalized adjacency matrix :math:`S`
    (:math:`k` sets the graph scale, i.e. how many hops are aggregated per
    iteration), and :math:`f(X;G)` is a parameterized transformation on input
    features and graphs, passed to :meth:`forward` as ``X``. The normalized
    weight matrix :math:`g(F)` is computed as

    .. math::
        g(F) = \frac{1}{\|F^\top F\|_\text{F} + \epsilon_F} F^\top F

    The fixed point of this update is found by forward iteration starting
    from :math:`Z^{(0)} = 0`.

    Parameters
    ----------
    m: int
        Number of feature channels. The learnable weight :math:`F` is an
        ``m x m`` matrix, so the ``X`` passed to :meth:`forward` must have
        ``m`` rows.
    k: int
        Power of the normalized adjacency matrix, i.e. the number of
        propagation steps per iteration. The larger ``k``, the farther away
        the information that is captured.
    threshold: float
        Convergence tolerance. Iteration stops once the Frobenius norm of the
        difference between two consecutive iterates is below this value.
    max_iter: int
        Maximum number of fixed-point iterations.
    gamma: float
        The contraction factor. A smaller ``gamma`` contracts faster but
        captures a smaller range; a larger ``gamma`` captures a larger range
        but converges more slowly and may fail to converge.
    layer_norm: bool, optional
        Whether to use layer norm. Stored as ``self.layer_norm`` but not read
        by this class. (default: :obj:`False`)
    """

    def __init__(self, m, k, threshold, max_iter, gamma, layer_norm=False):
        super().__init__()
        self.F = nn.Parameter(tlx.convert_to_tensor(np.zeros((m, m)), dtype=tlx.float32))
        self.layer_norm = layer_norm
        self.gamma = gamma
        self.k = k
        self.max_iter = max_iter
        self.threshold = threshold
        self.f_solver = fwd_solver
        # Kept for API compatibility; not used by this class.
        self.b_solver = fwd_solver

    def reset_parameters(self):
        """Re-initialize ``F`` to an all-zeros matrix of the same shape."""
        initor = tlx.initializers.Zeros()
        self.F = self._get_weights("F", shape=self.F.shape, init=initor)

    def _normalize_edges(self, edge_index, edge_weight, num_nodes):
        """Return ``(edge_index, weights)`` with degree-normalized edge weights.

        For an edge ``(i, j)`` the weight becomes
        ``w_ij / sqrt(deg_out(i) * deg_in(j))``, where the degrees count edges
        (they ignore ``edge_weight``). A missing ``edge_weight`` means all ones.
        """
        ei = tlx.ops.convert_to_tensor(edge_index)
        src, dst = ei[0], ei[1]
        if edge_weight is None:
            edge_weight = tlx.ones(shape=(ei.shape[1], 1))
        edge_weight = tlx.reshape(edge_weight, (-1,))
        # Every node gathered below appears in src/dst at least once, so its
        # degree is >= 1 and the inverse square root stays finite.
        src_inv_sqrt = tlx.pow(degree(src, num_nodes=num_nodes, dtype=tlx.float32), -0.5)
        dst_inv_sqrt = tlx.pow(degree(dst, num_nodes=num_nodes, dtype=tlx.float32), -0.5)
        weights = tlx.ops.gather(src_inv_sqrt, src) * edge_weight * tlx.ops.gather(dst_inv_sqrt, dst)
        return ei, weights

    def _iterate(self, Z, X, ei, weights, num_nodes):
        """Apply one update ``gamma * g(F) @ (Z S^k) + X`` with precomputed edges."""
        # Propagation runs over nodes, which are the columns of Z.
        P = tlx.ops.transpose(Z)
        for _ in range(self.k):
            P = self.propagate(P, ei, edge_weight=weights, num_nodes=num_nodes)
        return self.gamma * g(self.F) @ tlx.ops.transpose(P) + X

    def _inner_func(self, Z, X, edge_index, edge_weight, num_nodes):
        """One MGNNI update from raw edges (normalizes, then calls ``_iterate``)."""
        ei, weights = self._normalize_edges(edge_index, edge_weight, num_nodes)
        return self._iterate(Z, X, ei, weights, num_nodes)

    def forward(self, X, edge_index, edge_weight, num_nodes):
        r"""Solve for the fixed point of the MGNNI update.

        Parameters
        ----------
        X: tensor
            The input term :math:`f(X;G)`. Expected shape ``(m, num_nodes)``:
            rows are multiplied by ``g(F)`` and columns are propagated over
            nodes.
        edge_index: tensor
            Edge list with source nodes in row 0 and target nodes in row 1.
        edge_weight: tensor or None
            One weight per edge (reshaped to 1-D). ``None`` means all ones.
        num_nodes: int
            Number of nodes in the graph.

        Returns
        -------
        tensor
            The approximate fixed point, with the same shape as ``X``. In
            training mode the solver output is passed once more through the
            update, so the returned tensor depends on ``F`` and ``X`` through
            differentiable ops.
        """
        # The edge normalization does not depend on Z, so compute it once
        # instead of once per solver iteration.
        ei, weights = self._normalize_edges(edge_index, edge_weight, num_nodes)

        def step(Z):
            return self._iterate(Z, X, ei, weights, num_nodes)

        Z, _ = self.f_solver(step, z_init=tlx.zeros_like(X), threshold=self.threshold, max_iter=self.max_iter)
        Z = tlx.convert_to_tensor(Z)
        new_Z = Z
        if self.is_train:
            if tlx.BACKEND != 'paddle':
                Z = tlx.Variable(Z, 'Z')
            else:
                Z.stop_gradient = False
            new_Z = step(Z)
        return new_Z
def _residual_norm(z_prev, z):
    """Return the flattened 2-norm (Frobenius for matrices) of ``z_prev - z`` as a float.

    Used only for the convergence check, so it reads the value through NumPy
    and does not create any framework variable.
    """
    return float(np.linalg.norm(tlx.convert_to_numpy(z_prev - z)))


def fwd_solver(f, z_init, threshold, max_iter):
    """Find an approximate fixed point of ``f`` by forward iteration ``z <- f(z)``.

    Parameters
    ----------
    f: callable
        The update map; must accept and return tensors of the same shape.
    z_init: tensor
        Starting iterate.
    threshold: float
        Stop once the norm of the difference between two consecutive
        iterates is below this value.
    max_iter: int
        Maximum number of iterations after the initial ``f(z_init)``, so
        ``f`` is called at most ``max_iter + 1`` times.

    Returns
    -------
    tuple
        ``(z, abs_diff)``: the last iterate and the last measured residual
        (a Python float). If the loop runs out without converging, a message
        is printed.
    """
    z_prev, z = z_init, f(z_init)
    # Seed the residual so it is defined even when max_iter <= 0; otherwise
    # the print/return below would raise UnboundLocalError.
    abs_diff = _residual_norm(z_prev, z)
    nstep = 0
    while nstep < max_iter:
        z_prev, z = z, f(z)
        abs_diff = _residual_norm(z_prev, z)
        if abs_diff < threshold:
            break
        nstep += 1
    if nstep == max_iter:
        print(f'step {nstep}, not converged, abs_diff: {abs_diff}')
    return z, abs_diff


def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):
    """Compute a vector/matrix norm with ``numpy.linalg.norm`` and wrap it as a tensor.

    ``p``, ``dim`` and ``keepdim`` map to NumPy's ``ord``, ``axis`` and
    ``keepdims``. ``out`` and ``dtype`` are accepted for signature
    compatibility and ignored.

    NOTE: the value is computed in NumPy, so the result is a fresh tensor
    (marked trainable / ``stop_gradient=False``) that is not connected to
    ``input`` in the autograd graph.
    """
    norm_np = np.linalg.norm(tlx.convert_to_numpy(input), ord=p, axis=dim, keepdims=keepdim)
    op = tlx.convert_to_tensor(norm_np)
    if tlx.BACKEND == "paddle":
        op.stop_gradient = False
    else:
        op = tlx.Variable(op, 'op')
    return op


# Small constant that keeps g(F) finite when F^T F is the zero matrix
# (e.g. for the all-zeros initialization of F).
epsilon_F = 10 ** (-12)


def g(F):
    """Return ``F^T F / (||F^T F||_F + epsilon_F)``.

    NOTE(review): the Frobenius norm is computed through ``norm`` (via NumPy),
    so backpropagation treats the denominator as a constant and gradients
    flow only through the ``F^T F`` numerator. A plain ``sqrt(sum(x**2))``
    replacement would give NaN gradients at ``F == 0``, so change this with
    care.
    """
    FF = tlx.ops.transpose(F) @ F
    FF_norm = norm(FF, p='fro')
    return (1 / (FF_norm + epsilon_F)) * FF