Source code for gammagl.datasets.tu_dataset

import os
import os.path as osp
import shutil
import tensorlayerx as tlx
from typing import Callable, List, Optional

try:
    import cPickle as pickle
except ImportError:
    import pickle

from gammagl.data import InMemoryDataset, download_url, extract_zip
from gammagl.io.tu import read_tu_data


class TUDataset(InMemoryDataset):
    r"""A variety of graph kernel benchmark datasets, *e.g.*
    :obj:`"IMDB-BINARY"`, :obj:`"REDDIT-BINARY"` or :obj:`"PROTEINS"`,
    collected from the `TU Dortmund University
    <https://chrsmrrs.github.io/datasets>`_.
    In addition, this dataset wrapper provides `cleaned dataset versions
    <https://github.com/nd7141/graph_datasets>`_ as motivated by the
    `"Understanding Isomorphism Bias in Graph Data Sets"
    <https://arxiv.org/abs/1910.12091>`_ paper, containing only
    non-isomorphic graphs.

    .. note::
        Some datasets may not come with any node labels.
        You can then either make use of the argument :obj:`use_node_attr`
        to load additional continuous node attributes (if present) or
        provide synthetic node features using transforms such as
        :class:`gammagl.transforms.Constant` or
        :class:`gammagl.transforms.OneHotDegree`.

    Parameters
    ----------
    root: str, optional
        Root directory where the dataset should be saved.
    name: str, optional
        The `name
        <https://chrsmrrs.github.io/datasets/docs/datasets/>`_ of the
        dataset. (default: :obj:`"MUTAG"`)
    transform: callable, optional
        A function/transform that takes in a
        :obj:`gammagl.data.Graph` object and returns a transformed
        version. The data object will be transformed before every access.
        (default: :obj:`None`)
    pre_transform: callable, optional
        A function/transform that takes in a
        :obj:`gammagl.data.Graph` object and returns a transformed
        version. The data object will be transformed before being saved
        to disk. (default: :obj:`None`)
    pre_filter: callable, optional
        A function that takes in a :obj:`gammagl.data.Graph` object and
        returns a boolean value, indicating whether the data object
        should be included in the final dataset. (default: :obj:`None`)
    use_node_attr: bool, optional
        If :obj:`True`, the dataset will contain additional continuous
        node attributes (if present). (default: :obj:`False`)
    use_edge_attr: bool, optional
        If :obj:`True`, the dataset will contain additional continuous
        edge attributes (if present). (default: :obj:`False`)
    cleaned: bool, optional
        If :obj:`True`, the dataset will contain only non-isomorphic
        graphs. (default: :obj:`False`)
    force_reload: bool, optional
        Whether to re-process the dataset. (default: :obj:`False`)

    Tip
    ---
    .. list-table::
        :widths: 20 10 10 10 10 10
        :header-rows: 1

        * - Name
          - #graphs
          - #nodes
          - #edges
          - #features
          - #classes
        * - MUTAG
          - 188
          - ~17.9
          - ~39.6
          - 7
          - 2
        * - ENZYMES
          - 600
          - ~32.6
          - ~124.3
          - 3
          - 6
        * - PROTEINS
          - 1,113
          - ~39.1
          - ~145.6
          - 3
          - 2
        * - COLLAB
          - 5,000
          - ~74.5
          - ~4914.4
          - 0
          - 3
        * - IMDB-BINARY
          - 1,000
          - ~19.8
          - ~193.1
          - 0
          - 2
        * - REDDIT-BINARY
          - 2,000
          - ~429.6
          - ~995.5
          - 0
          - 2
        * - ...
          -
          -
          -
          -
          -
    """

    url = 'https://www.chrsmrrs.com/graphkerneldatasets'
    cleaned_url = ('https://raw.githubusercontent.com/nd7141/'
                   'graph_datasets/master/datasets')

    def __init__(self, root: Optional[str] = None, name: str = 'MUTAG',
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 pre_filter: Optional[Callable] = None,
                 use_node_attr: bool = False, use_edge_attr: bool = False,
                 cleaned: bool = False, force_reload: bool = False):
        # ``raw_dir`` / ``processed_dir`` read ``self.name`` and
        # ``self.cleaned``, so both are set before the base constructor runs
        # (presumably it resolves those paths to download/process).
        self.name = name
        self.cleaned = cleaned
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        self.data, self.slices, self.sizes = self.load_data(
            self.processed_paths[0])

        # The first ``num_*_attributes`` columns hold the continuous
        # attributes; they are dropped unless explicitly requested. The
        # remaining columns presumably hold the encoded labels produced by
        # ``read_tu_data`` — confirm there.
        if self.data.x is not None and not use_node_attr:
            num_node_attributes = self.num_node_attributes
            self.data.x = self.data.x[:, num_node_attributes:]
        if self.data.edge_attr is not None and not use_edge_attr:
            num_edge_attributes = self.num_edge_attributes
            self.data.edge_attr = self.data.edge_attr[:, num_edge_attributes:]

    @property
    def raw_dir(self) -> str:
        # Cleaned and original variants are kept in separate directories.
        name = f'raw{"_cleaned" if self.cleaned else ""}'
        return osp.join(self.root, self.name, name)

    @property
    def processed_dir(self) -> str:
        name = f'processed{"_cleaned" if self.cleaned else ""}'
        return osp.join(self.root, self.name, name)

    @property
    def num_node_labels(self) -> int:
        return self.sizes['num_node_labels']

    @property
    def num_node_attributes(self) -> int:
        return self.sizes['num_node_attributes']

    @property
    def num_edge_labels(self) -> int:
        return self.sizes['num_edge_labels']

    @property
    def num_edge_attributes(self) -> int:
        return self.sizes['num_edge_attributes']

    @property
    def raw_file_names(self) -> List[str]:
        names = ['A', 'graph_indicator']
        return [f'{self.name}_{name}.txt' for name in names]

    @property
    def processed_file_names(self) -> str:
        # Prefixed with the active TensorLayerX backend, so each backend gets
        # its own processed file.
        return tlx.BACKEND + '_graph.pt'

    def download(self):
        """Download ``<name>.zip`` and unpack it into :obj:`raw_dir`."""
        url = self.cleaned_url if self.cleaned else self.url
        folder = osp.join(self.root, self.name)
        path = download_url(f'{url}/{self.name}.zip', folder)
        extract_zip(path, folder)
        os.unlink(path)
        # The archive is expected to unpack into ``<folder>/<name>``; move it
        # to ``raw_dir``, replacing whatever directory is already there. The
        # existence check avoids a FileNotFoundError when ``raw_dir`` has not
        # been created yet.
        if osp.isdir(self.raw_dir):
            shutil.rmtree(self.raw_dir)
        os.rename(osp.join(folder, self.name), self.raw_dir)

    def process(self):
        """Read the raw TU files, apply ``pre_filter`` / ``pre_transform``
        and save ``(data, slices, sizes)`` to ``processed_paths[0]``."""
        self.data, self.slices, sizes = read_tu_data(self.raw_dir, self.name)

        if self.pre_filter is not None or self.pre_transform is not None:
            # Split the collated storage into individual graphs once, filter
            # and transform that single list, then re-collate once. Two
            # separate get()/collate() rounds are redundant and, if get()
            # caches graphs per index, the second round would read stale
            # (unfiltered) graphs.
            data_list = [self.get(idx) for idx in range(len(self))]
            if self.pre_filter is not None:
                data_list = [data for data in data_list if self.pre_filter(data)]
            if self.pre_transform is not None:
                data_list = [self.pre_transform(data) for data in data_list]
            self.data, self.slices = self.collate(data_list)
            # NOTE(review): assumes InMemoryDataset.get() memoises graphs in
            # ``_data_list`` (as in PyG); reset it so later accesses do not
            # return graphs split from the pre-processing storage. Harmless
            # if the base class does not use this attribute.
            self._data_list = None

        self.save_data((self.data, self.slices, sizes), self.processed_paths[0])

    def __repr__(self) -> str:
        return f'{self.name}({len(self)})'