Source code for gammagl.datasets.entities

import logging
import os
import os.path as osp
from collections import Counter
from typing import Callable, List, Optional

import numpy as np
import tensorlayerx as tlx

from gammagl.data import (
    Graph,
    InMemoryDataset,
    download_url,
    extract_tar,
)


class Entities(InMemoryDataset):
    r"""The relational entities networks "AIFB", "MUTAG", "BGS" and "AM" from
    the `"Modeling Relational Data with Graph Convolutional Networks"
    <https://arxiv.org/abs/1703.06103>`_ paper.
    Training and test splits are given by node indices.

    Parameters
    ----------
    root: str, optional
        Root directory where the dataset should be saved.
    name: str, optional
        The name of the dataset (:obj:`"AIFB"`, :obj:`"MUTAG"`,
        :obj:`"BGS"`, :obj:`"AM"`). The name is lower-cased before use.
        (default: :obj:`"aifb"`)
    hetero: bool, optional
        If set to :obj:`True`, will save the dataset as a
        :class:`~gammagl.data.HeteroGraph` object.
        (default: :obj:`False`)
    transform: callable, optional
        A function/transform that takes in an
        :obj:`gammagl.data.Graph` object and returns a transformed
        version. The data object will be transformed before every access.
        (default: :obj:`None`)
    pre_transform: callable, optional
        A function/transform that takes in an
        :obj:`gammagl.data.Graph` object and returns a transformed
        version. The data object will be transformed before being saved to
        disk. (default: :obj:`None`)
    force_reload: bool, optional
        Whether to re-process the dataset.
        (default: :obj:`False`)

    """

    url = 'https://data.dgl.ai/dataset/{}.tgz'

    # Dataset name -> (label column, node column) used to read the task TSVs.
    # The spelling 'label_cateogory' is reproduced verbatim; it presumably
    # matches the header in the shipped AM files -- do not "fix" it without
    # checking the raw data.
    _TASK_COLUMNS = {
        'am': ('label_cateogory', 'proxy'),
        'aifb': ('label_affiliation', 'person'),
        'mutag': ('label_mutagenic', 'bond'),
        'bgs': ('label_lithogenesis', 'rock'),
    }

    def __init__(self, root: Optional[str] = None, name: str = 'aifb',
                 hetero: bool = False,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 force_reload: bool = False):
        self.name = name.lower()
        self.hetero = hetero
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert self.name in ['aifb', 'am', 'mutag', 'bgs'], (
            f"Unknown dataset name '{name}'")
        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        self.data, self.slices = self.load_data(self.processed_paths[0])

    @property
    def raw_dir(self) -> str:
        """Directory ``<root>/<name>/raw`` holding the extracted raw files."""
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self) -> str:
        """Directory ``<root>/<name>/processed`` holding the processed file."""
        return osp.join(self.root, self.name, 'processed')

    @property
    def num_relations(self) -> int:
        """Largest value in ``data.edge_type`` plus one."""
        return int(tlx.reduce_max(self.data.edge_type)) + 1

    @property
    def num_classes(self) -> int:
        """Largest value in ``data.train_y`` plus one.

        Only the training labels are inspected.
        """
        return int(tlx.reduce_max(self.data.train_y)) + 1

    @property
    def raw_file_names(self) -> List[str]:
        """Raw files expected in :attr:`raw_dir`: graph, task, train, test."""
        return [
            f'{self.name}_stripped.nt.gz',
            'completeDataset.tsv',
            'trainingSet.tsv',
            'testSet.tsv',
        ]

    @property
    def processed_file_names(self) -> str:
        """Processed file name, prefixed with the active TensorLayerX backend."""
        suffix = 'hetero_data.pt' if self.hetero else 'data.pt'
        return tlx.BACKEND + suffix

    def download(self):
        """Download the dataset archive, extract it into :attr:`raw_dir`,
        then delete the archive."""
        path = download_url(self.url.format(self.name), self.root)
        extract_tar(path, self.raw_dir)
        os.unlink(path)

    def process(self):
        """Parse the raw RDF graph and the task TSV files into a graph object
        and save it to ``processed_paths[0]``.

        Every triple ``(s, p, o)`` produces two edges: ``s -> o`` with type
        ``2 * rel`` and the inverse ``o -> s`` with type ``2 * rel + 1``,
        where ``rel`` is the index of ``p`` in the frequency-sorted relation
        list. Edges are then sorted by ``(source, target, type)``.
        """
        # Imported lazily, as in the original code, so these packages are
        # only needed when processing actually runs.
        import gzip

        import pandas as pd
        import rdflib as rdf

        graph_file, task_file, train_file, test_file = self.raw_paths

        # Root-logger output below ERROR is silenced while parsing.
        with hide_stdout():
            g = rdf.Graph()
            with gzip.open(graph_file, 'rb') as f:
                g.parse(file=f, format='nt')

        # Most frequent predicate first. Ties are broken by the predicate's
        # string form so relation ids do not depend on set/hash order.
        freq = Counter(g.predicates())
        relations = sorted(freq, key=lambda p: (-freq[p], str(p)))

        subjects = set(g.subjects())
        objects = set(g.objects())
        # NOTE(review): node ids still follow set iteration order, so they can
        # differ between processes. Left unchanged here.
        nodes = list(subjects.union(objects))

        N = len(nodes)
        R = 2 * len(relations)

        relations_dict = {rel: i for i, rel in enumerate(relations)}
        nodes_dict = {node: i for i, node in enumerate(nodes)}

        edges = []
        for s, p, o in g.triples((None, None, None)):
            src, dst, rel = nodes_dict[s], nodes_dict[o], relations_dict[p]
            edges.append([src, dst, 2 * rel])
            edges.append([dst, src, 2 * rel + 1])

        # Force 64-bit: the sort key below is on the order of N * N * R, which
        # overflows a 32-bit platform-default integer on large graphs.
        # reshape(-1, 3) keeps an empty edge list two-dimensional.
        edges = np.array(edges, dtype=np.int64).reshape(-1, 3).transpose()
        perm = (N * R * edges[0] + R * edges[1] + edges[2]).argsort()
        edges = edges[:, perm]

        edge_index, edge_type = edges[:2], edges[2]

        label_header, nodes_header = self._TASK_COLUMNS[self.name]

        labels_df = pd.read_csv(task_file, sep='\t')
        labels_set = set(labels_df[label_header].values.tolist())
        # Enumerate a sorted list rather than the raw set so each label gets
        # the same class id on every run; key=str tolerates non-string labels.
        labels_dict = {
            lab: i for i, lab in enumerate(sorted(labels_set, key=str))
        }

        # The TSV files reference nodes by plain string, so re-key by str().
        nodes_dict = {str(key): val for key, val in nodes_dict.items()}

        def read_split(split_file):
            # Map one split file to (node index array, class id array).
            split_df = pd.read_csv(split_file, sep='\t')
            indices = [nodes_dict[nod]
                       for nod in split_df[nodes_header].values]
            targets = [labels_dict[lab]
                       for lab in split_df[label_header].values]
            return np.array(indices), np.array(targets)

        train_idx, train_y = read_split(train_file)
        test_idx, test_y = read_split(test_file)

        data = Graph(edge_index=edge_index, edge_type=edge_type,
                     train_idx=train_idx, train_y=train_y,
                     test_idx=test_idx, test_y=test_y, num_nodes=N)

        if self.hetero:
            data = data.to_heterogeneous(node_type_names=['v'])

        self.save_data(self.collate([data]), self.processed_paths[0])

    def __repr__(self) -> str:
        return f'{self.name.upper()}{self.__class__.__name__}()'
class hide_stdout(object):
    """Context manager that raises the root logger's level to ``ERROR``.

    Despite the name, ``sys.stdout`` is not touched: only the level of the
    root logger is changed on entry, and the level recorded at entry is
    restored on exit.
    """

    def __enter__(self):
        root_logger = logging.getLogger()
        # Remember the current threshold so __exit__ can put it back.
        self.level = root_logger.level
        root_logger.setLevel(logging.ERROR)

    def __exit__(self, *args):
        # Runs on normal exit and on exceptions alike. Nothing is returned,
        # so an exception raised inside the block still propagates.
        root_logger = logging.getLogger()
        root_logger.setLevel(self.level)