Source code for gammagl.utils.to_dense_batch

import tensorlayerx as tlx

def to_dense_batch(x, batch=None, fill_value=0, max_num_nodes=None, batch_size=None):
    r"""Convert a sparse batch of node features into a dense, padded batch.

    Given a sparse batch of node features
    :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}` (with
    :math:`N_i` indicating the number of nodes in graph :math:`i`), creates a
    dense node feature tensor
    :math:`\mathbf{X} \in \mathbb{R}^{B \times N_{\max} \times F}` (with
    :math:`N_{\max} = \max_i^B N_i`).
    In addition, a mask of shape
    :math:`\mathbf{M} \in \{ 0, 1 \}^{B \times N_{\max}}` is returned. It is
    :obj:`True` at every position that holds a real node and :obj:`False` at
    every padded (fake-node) position.

    Parameters
    ----------
    x: tensor
        Node feature matrix
        :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
    batch: tensor, optional
        Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which
        assigns each node to a specific example. Must be ordered, i.e. all
        nodes of one graph are contiguous. If :obj:`None` while
        :obj:`max_num_nodes` is given, every node is assigned to graph 0.
        (default: :obj:`None`)
    fill_value: float, optional
        The value for invalid entries in the resulting dense output tensor.
        (default: :obj:`0`)
    max_num_nodes: int, optional
        The size of the output node dimension. If :obj:`None`, the node count
        of the largest graph is used.
        (default: :obj:`None`)
    batch_size: int, optional
        The batch size :math:`B`. If :obj:`None`, it is inferred as
        :obj:`max(batch) + 1`. It is ignored when both :obj:`batch` and
        :obj:`max_num_nodes` are :obj:`None`.
        (default: :obj:`None`)

    Returns
    -------
    (:class:`Tensor`, :class:`BoolTensor`)
        The dense feature tensor of shape
        :obj:`[batch_size, max_num_nodes, *x.shape[1:]]` and the boolean mask
        of shape :obj:`[batch_size, max_num_nodes]`.

    """
    total_nodes = x.shape[0]

    # Trivial case: a single graph with no padding requested. Just add the
    # leading batch axis and mark every node as valid.
    if batch is None and max_num_nodes is None:
        return (
            tlx.expand_dims(x, axis=0),
            tlx.ones((1, total_nodes), dtype=tlx.bool),
        )

    # A padded size was requested without a batch vector: one single graph.
    if batch is None:
        batch = tlx.zeros((total_nodes,), dtype=tlx.int64)

    # Infer the number of graphs only when the caller did not provide it.
    # NOTE(review): a caller-supplied batch_size smaller than max(batch) + 1
    # is not validated here; it is handed to unsorted_segment_sum as-is.
    if batch_size is None:
        batch_size = tlx.reduce_max(batch) + 1

    # Per-graph node counts, then the exclusive prefix sum giving the offset
    # of each graph's first node inside ``x``.
    num_nodes = tlx.unsorted_segment_sum(
        tlx.ones((total_nodes,), dtype=batch.dtype),
        batch,
        num_segments=batch_size,
    )
    cum_nodes = tlx.concat(
        [tlx.zeros((1,), dtype=batch.dtype), tlx.cumsum(num_nodes)],
        axis=0,
    )

    if max_num_nodes is None:
        max_num_nodes = tlx.reduce_max(num_nodes)

    # Flat destination slot of every node in the
    # (batch_size * max_num_nodes) buffer:
    #   position inside its own graph + graph id * padded graph size.
    # NOTE(review): nothing is truncated here. If max_num_nodes is smaller
    # than the largest graph, the surplus nodes get indices that fall into
    # the next graph's slots (or past the end of the buffer).
    idx = tlx.arange(0, batch.shape[0], dtype=batch.dtype)
    idx = idx - tlx.gather(cum_nodes, batch) + batch * max_num_nodes

    feature_dims = list(x.shape)[1:]

    # Start from a buffer full of ``fill_value``, scatter the real node
    # features into their slots, then fold the flat axis into (B, N_max).
    shape = [batch_size * max_num_nodes] + feature_dims
    ret = tlx.constant(fill_value, shape=shape, dtype=x.dtype)
    ret = tlx.tensor_scatter_nd_update(ret, tlx.expand_dims(idx, axis=1), x)
    ret = tlx.reshape(ret, shape=[batch_size, max_num_nodes] + feature_dims)

    # The mask is built the same way: all False, True at occupied slots.
    mask = tlx.zeros((batch_size * max_num_nodes,), dtype=tlx.bool)
    if tlx.BACKEND == 'paddle':
        # The paddle backend takes the in-place indexed assignment path.
        mask[idx] = True
    else:
        mask = tlx.scatter_update(mask, idx, tlx.ones_like(idx, dtype=tlx.bool))
    mask = tlx.reshape(mask, (batch_size, max_num_nodes))

    return ret, mask