Source code for gammagl.transforms.compose

from typing import Callable, List, Union

from gammagl.data import Graph, HeteroGraph
from gammagl.transforms import BaseTransform


[docs] class Compose(BaseTransform): """Composes several transforms together. Parameters ---------- transforms: List[Callable] List of transforms to compose. """ def __init__(self, transforms: List[Callable]): self.transforms = transforms def __call__(self, graph: Graph): for transform in self.transforms: if isinstance(graph, (list, tuple)): graph = [transform(d) for d in graph] else: graph = transform(graph) return graph def __repr__(self) -> str: args = [f' {transform}' for transform in self.transforms] return '{}([\n{}\n])'.format(self.__class__.__name__, ',\n'.join(args))