Source code for gammagl.utils.mask

# !/usr/bin/env python3
# -*- coding:utf-8 -*-

# @Time    : 2022/05/18 16:02
# @Author  : clear
# @FileName: mask.py

import tensorlayerx as tlx
from typing import Optional
import numpy as np

def index_to_mask(index, size: Optional[int] = None):
    r"""Converts indices to a mask representation.

    Parameters
    ----------
    index: tensor
        The indices. They are flattened to one dimension before use.
    size: int, optional
        The size of the mask. If set to :obj:`None`, a minimal sized
        output mask is returned, i.e. ``max(index) + 1``, or ``0`` when
        :obj:`index` is empty.

    Returns
    -------
    tensor
        A one-dimensional boolean mask of length :obj:`size` that is
        :obj:`True` at every position listed in :obj:`index`.

    """
    index = tlx.convert_to_numpy(index).reshape((-1,))
    if size is None:
        # ``max`` of a zero-size array raises ValueError, so an empty index
        # maps to an empty mask instead of crashing.
        size = int(index.max()) + 1 if index.size > 0 else 0
    mask = np.zeros(size, dtype=bool)
    if index.size > 0:
        # Skipped for an empty index: nothing to set, and an empty array may
        # carry a non-integer dtype that NumPy rejects as an index.
        mask[index] = True
    return tlx.convert_to_tensor(mask)
def mask_to_index(mask):
    r"""Converts a mask to an index representation.

    Parameters
    ----------
    mask: tensor
        The mask.

    Returns
    -------
    tensor
        An ``int64`` tensor holding the positions of the non-zero entries
        of :obj:`mask` along its first dimension.

    """
    mask_array = tlx.convert_to_numpy(mask)
    # ``nonzero`` yields one index array per dimension; only the first
    # dimension's indices are kept.
    per_dim_indices = mask_array.nonzero()
    first_dim_indices = per_dim_indices[0]
    return tlx.convert_to_tensor(first_dim_indices, dtype=tlx.int64)