import numpy as np
from typing import List
from gammagl.data import download_url
from gammagl.utils import read_embeddings
[docs]
class CA_GrQc():
r"""The CA-GrQc datasets used in the `"GraphGAN: Graph Representation Learning with Generative Adversarial Nets"
<https://arxiv.org/pdf/1711.08267.pdf>`_ paper. arXiv-GrQc is from arXiv and covers scientific collaborations
between authors with papers submitted to the General Relativity and Quantum Cosmology categories. This graph has
5,242 vertices and 14,496 edges.
Parameters
----------
dir: str
Root directory where the dataset should be saved.
n_emb: int
Dimension of node embeddings
"""
url = 'https://raw.githubusercontent.com/hwwang55/GraphGAN/master'
def __init__(self, dir: str, n_emb: int):
self.download(dir)
self.n_node, self.graph = self.read_edges(f'{dir}/CA-GrQc_train.txt', f'{dir}/CA-GrQc_test.txt')
self.test_edges = self.read_edges_from_file(f'{dir}/CA-GrQc_test.txt')
self.test_edges_neg = self.read_edges_from_file(f'{dir}/CA-GrQc_test_neg.txt')
filename=f'{dir}/CA-GrQc_pre_train.emb'
with open(filename, "r") as f:
lines = f.readlines()[1:]
embedding_matrix_d = np.random.rand(self.n_node, n_emb)
for line in lines:
emd = line.split()
embedding_matrix_d[int(emd[0]), :] = [float(item) for item in emd[1:]]
embedding_matrix_g = embedding_matrix_d.copy()
self.node_embed_init_d = embedding_matrix_d
self.node_embed_init_g = embedding_matrix_g
@property
def file_names(self) -> List[str]:
names = ['data/link_prediction/CA-GrQc_train.txt', 'data/link_prediction/CA-GrQc_test.txt',
'data/link_prediction/CA-GrQc_test_neg.txt', 'pre_train/link_prediction/CA-GrQc_pre_train.emb']
return [f'{name}' for name in names]
[docs]
def download(self, dir):
for name in self.file_names:
download_url(f'{self.url}/{name}', dir)
[docs]
def read_edges(self, train_filename, test_filename):
"""read data from downloaded files
Parameters
----------
train_filename:
training file name
test_filename:
test file name
Returns
-------
(:obj:`int`, :obj:`dict`): number of nodes in the graph and node_id -> list of neighbors in the graph
"""
graph = {}
nodes = set()
train_edges = self.read_edges_from_file(train_filename)
test_edges = self.read_edges_from_file(test_filename) if test_filename != "" else []
for edge in train_edges:
nodes.add(edge[0])
nodes.add(edge[1])
if graph.get(edge[0]) is None:
graph[edge[0]] = []
if graph.get(edge[1]) is None:
graph[edge[1]] = []
graph[edge[0]].append(edge[1])
graph[edge[1]].append(edge[0])
for edge in test_edges:
nodes.add(edge[0])
nodes.add(edge[1])
if graph.get(edge[0]) is None:
graph[edge[0]] = []
if graph.get(edge[1]) is None:
graph[edge[1]] = []
return len(nodes), graph
[docs]
def str_list_to_int(self, str_list):
return [int(item) for item in str_list]
[docs]
def read_edges_from_file(self, filename):
with open(filename, "r") as f:
lines = f.readlines()
edges = [self.str_list_to_int(line.split()) for line in lines]
return edges