From f908f35c38a1ae533fd61550c166f7dcf70d3935 Mon Sep 17 00:00:00 2001
From: RecLusIve-F <53204271+RecLusIve-F@users.noreply.github.com>
Date: Thu, 10 Mar 2022 17:39:32 +0800
Subject: [PATCH] [Example] Add P-GNN example (#3823)

* [Model] P-GNN

* update

* [Example] P-GNN

* Update README.md

Co-authored-by: Mufei Li
---
 examples/README.md               |   3 +
 examples/pytorch/P-GNN/README.md |  57 +++++++
 examples/pytorch/P-GNN/main.py   | 147 +++++++++++++++++
 examples/pytorch/P-GNN/model.py  |  55 +++++++
 examples/pytorch/P-GNN/utils.py  | 264 +++++++++++++++++++++++++++++++
 5 files changed, 526 insertions(+)
 create mode 100644 examples/pytorch/P-GNN/README.md
 create mode 100644 examples/pytorch/P-GNN/main.py
 create mode 100644 examples/pytorch/P-GNN/model.py
 create mode 100644 examples/pytorch/P-GNN/utils.py

diff --git a/examples/README.md b/examples/README.md
index 9d2a82b2447e..7ea98c256a67 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -181,6 +181,9 @@ To quickly locate the examples of your interest, search for the tagged keywords
 - Liu Z, et al. Geniepath: Graph neural networks with adaptive receptive paths. [Paper link](https://arxiv.org/abs/1802.00910).
   - Example code: [PyTorch](../examples/pytorch/geniepath)
   - Tags: Fraud detection, Node classification, Graph attention, LSTM, Adaptive receptive fields
+- You J, et al. Position-aware graph neural networks. [Paper link](https://arxiv.org/abs/1906.04817).
+  - Example code: [PyTorch](../examples/pytorch/P-GNN)
+  - Tags: Positional encoding, Link prediction, Link-pair prediction
 
 ## 2018
 
diff --git a/examples/pytorch/P-GNN/README.md b/examples/pytorch/P-GNN/README.md
new file mode 100644
index 000000000000..4549e9973a8a
--- /dev/null
+++ b/examples/pytorch/P-GNN/README.md
@@ -0,0 +1,57 @@
+# DGL Implementation of P-GNN
+
+This DGL example implements the GNN model proposed in the paper [Position-aware Graph Neural Networks](http://proceedings.mlr.press/v97/you19b/you19b.pdf). For the original implementation, see [here](https://github.com/JiaxuanYou/P-GNN).
+
+Contributor: [RecLusIve-F](https://github.com/RecLusIve-F)
+
+## Requirements
+
+The codebase is implemented in Python 3.8. For the package version requirements, see below.
+
+```
+dgl 0.7.2
+numpy 1.21.2
+torch 1.10.1
+networkx 2.6.3
+scikit-learn 1.0.2
+```
+
+## Instructions to download datasets
+
+1. Download the datasets from [here](https://github.com/RecLusIve-F/P-GNN-dgl/blob/master/data.zip?raw=true)
+2. Extract the zip file in this directory
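+
+   For example, from the shell (a minimal sketch; any equivalent download tool works):
+
+   ```bash
+   wget -O data.zip "https://github.com/RecLusIve-F/P-GNN-dgl/blob/master/data.zip?raw=true"
+   unzip data.zip
+   ```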
+
+## Instructions for experiments
+
+### Link prediction
+
+```bash
+# Communities-T
+python main.py --task link
+
+# Communities
+python main.py --task link --inductive
+```
+
+### Link pair prediction
+
+```bash
+# Communities
+python main.py --task link_pair --inductive
+```
+
+## Performance
+
+### Link prediction (Communities-T refers to the transductive learning setting of Communities)
+
+| Dataset                         | Communities-T | Communities   |
+| :-----------------------------: | :-----------: | :-----------: |
+| ROC AUC (P-GNN-E-2L in Table 1) | 0.988 ± 0.003 | 0.985 ± 0.008 |
+| ROC AUC (DGL: P-GNN-E-2L)       | 0.984 ± 0.010 | 0.991 ± 0.004 |
+
+### Link pair prediction
+
+| Dataset                         | Communities |
+| :-----------------------------: | :---------: |
+| ROC AUC (P-GNN-E-2L in Table 1) | 1.0 ± 0.001 |
+| ROC AUC (DGL: P-GNN-E-2L)       | 1.0 ± 0.000 |
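+
+## Programmatic usage
+
+The sketch below is a hypothetical, minimal end-to-end run (not part of the shipped
+scripts): it reuses `get_dataset`, `preselect_anchor`, and `PGNN` from this example to
+embed one sampled anchor-set graph. The `__main__` guard matters because
+`preselect_anchor` spawns worker processes.
+
+```python
+import dgl
+import torch
+from argparse import Namespace
+
+from model import PGNN
+from utils import get_dataset, preselect_anchor
+
+if __name__ == '__main__':
+    # Mimic the command-line arguments of main.py, for a single epoch
+    args = Namespace(task='link', inductive=False, k_hop_dist=-1, epoch_num=1)
+    data = get_dataset(args)
+    g_list, anchor_eids, dists_max_list, edge_weights = preselect_anchor(data, args)
+
+    g = dgl.graph(g_list[0])
+    g.ndata['feat'] = torch.FloatTensor(data['feature'])
+    g.edata['sp_dist'] = torch.FloatTensor(edge_weights[0])
+
+    model = PGNN(input_dim=data['feature'].shape[1])
+    emb = model({'graph': g, 'anchor_eid': anchor_eids[0], 'dists_max': dists_max_list[0]})
+    print(emb.shape)  # (num_nodes, num_anchor_sets)
+```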
diff --git a/examples/pytorch/P-GNN/main.py b/examples/pytorch/P-GNN/main.py
new file mode 100644
index 000000000000..0ec45e2fe6b4
--- /dev/null
+++ b/examples/pytorch/P-GNN/main.py
@@ -0,0 +1,147 @@
+import os
+import dgl
+import torch
+import numpy as np
+import torch.nn as nn
+from model import PGNN
+from sklearn.metrics import roc_auc_score
+from utils import get_dataset, preselect_anchor
+
+import warnings
+warnings.filterwarnings('ignore')
+
+def get_loss(p, data, out, loss_func, device, get_auc=True):
+    edge_mask = np.concatenate((data['positive_edges_{}'.format(p)],
+                                data['negative_edges_{}'.format(p)]), axis=-1)
+
+    nodes_first = torch.index_select(out, 0, torch.from_numpy(edge_mask[0, :]).long().to(out.device))
+    nodes_second = torch.index_select(out, 0, torch.from_numpy(edge_mask[1, :]).long().to(out.device))
+
+    # Score a node pair by the inner product of its embeddings
+    pred = torch.sum(nodes_first * nodes_second, dim=-1)
+
+    label_positive = torch.ones([data['positive_edges_{}'.format(p)].shape[1], ], dtype=pred.dtype)
+    label_negative = torch.zeros([data['negative_edges_{}'.format(p)].shape[1], ], dtype=pred.dtype)
+    label = torch.cat((label_positive, label_negative)).to(device)
+    loss = loss_func(pred, label)
+
+    if get_auc:
+        auc = roc_auc_score(label.flatten().cpu().numpy(),
+                            torch.sigmoid(pred).flatten().data.cpu().numpy())
+        return loss, auc
+    else:
+        return loss
+
+def train_model(data, model, loss_func, optimizer, device, g_data):
+    model.train()
+    out = model(g_data)
+
+    loss = get_loss('train', data, out, loss_func, device, get_auc=False)
+
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+
+def eval_model(data, g_data, model, loss_func, device):
+    model.eval()
+    out = model(g_data)
+
+    # train loss and auc
+    tmp_loss, auc_train = get_loss('train', data, out, loss_func, device)
+    loss_train = tmp_loss.cpu().data.numpy()
+
+    # val loss and auc
+    _, auc_val = get_loss('val', data, out, loss_func, device)
+
+    # test loss and auc
+    _, auc_test = get_loss('test', data, out, loss_func, device)
+
+    return loss_train, auc_train, auc_val, auc_test
+
+def main(args):
+    # The mean and standard deviation of the experiment results
+    # are stored in the 'results' folder
+    if not os.path.isdir('results'):
+        os.mkdir('results')
+
+    if torch.cuda.is_available():
+        device = 'cuda:0'
+    else:
+        device = 'cpu'
+
+    print('Learning Type: {}'.format(['Transductive', 'Inductive'][args.inductive]),
+          'Task: {}'.format(args.task))
+
+    results = []
+
+    for repeat in range(args.repeat_num):
+        data = get_dataset(args)
+
+        # Pre-sample anchor nodes and compute shortest-path distances for all epochs
+        g_list, anchor_eid_list, dist_max_list, edge_weight_list = preselect_anchor(data, args)
+
+        # Model
+        model = PGNN(input_dim=data['feature'].shape[1]).to(device)
+
+        # Optimizer and loss
+        optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
+        loss_func = nn.BCEWithLogitsLoss()
+
+        best_auc_val = -1
+        best_auc_test = -1
+
+        for epoch in range(args.epoch_num):
+            if epoch == 200:
+                # Divide the learning rate by 10 once, after 200 epochs
+                for param_group in optimizer.param_groups:
+                    param_group['lr'] /= 10
+
+            g = dgl.graph(g_list[epoch])
+            g.ndata['feat'] = torch.FloatTensor(data['feature'])
+            g.edata['sp_dist'] = torch.FloatTensor(edge_weight_list[epoch])
+            g_data = {
+                'graph': g.to(device),
+                'anchor_eid': anchor_eid_list[epoch],
+                'dists_max': dist_max_list[epoch]
+            }
+
+            train_model(data, model, loss_func, optimizer, device, g_data)
+
+            loss_train, auc_train, auc_val, auc_test = eval_model(
+                data, g_data, model, loss_func, device)
+            if auc_val > best_auc_val:
+                best_auc_val = auc_val
+                best_auc_test = auc_test
+
+            if epoch % args.epoch_log == 0:
+                print(repeat, epoch, 'Loss {:.4f}'.format(loss_train),
+                      'Train AUC: {:.4f}'.format(auc_train),
+                      'Val AUC: {:.4f}'.format(auc_val),
+                      'Test AUC: {:.4f}'.format(auc_test),
+                      'Best Val AUC: {:.4f}'.format(best_auc_val),
+                      'Best Test AUC: {:.4f}'.format(best_auc_test))
+
+        results.append(best_auc_test)
+
+    results = np.array(results)
+    results_mean = np.mean(results).round(6)
+    results_std = np.std(results).round(6)
+    print('-----------------Final-------------------')
+    print(results_mean, results_std)
+
+    with open('results/{}_{}_{}.txt'.format(['Transductive', 'Inductive'][args.inductive],
+                                            args.task, args.k_hop_dist), 'w') as f:
+        f.write('{}, {}\n'.format(results_mean, results_std))
+
+if __name__ == '__main__':
+    from argparse import ArgumentParser
+
+    parser = ArgumentParser()
+    parser.add_argument('--task', type=str, default='link', choices=['link', 'link_pair'])
+    parser.add_argument('--inductive', action='store_true',
+                        help='Inductive learning or transductive learning')
+    parser.add_argument('--k_hop_dist', default=-1, type=int,
+                        help='K-hop shortest path distance; -1 means exact shortest path.')
+
+    parser.add_argument('--epoch_num', type=int, default=2000)
+    parser.add_argument('--repeat_num', type=int, default=10)
+    parser.add_argument('--epoch_log', type=int, default=100)
+
+    args = parser.parse_args()
+    main(args)
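In `main()`, the learning rate is divided by 10 once at epoch 200 by mutating the optimizer's param groups. For reference, the same schedule can be written with `torch.optim.lr_scheduler.MultiStepLR`; a minimal equivalent sketch (not what `main.py` ships; the stand-in model is hypothetical):

```python
import torch
import torch.nn as nn

model = nn.Linear(32, 32)  # stand-in for the PGNN model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
# One milestone at epoch 200 with gamma=0.1 reproduces param_group['lr'] /= 10
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[200], gamma=0.1)

for epoch in range(2000):
    # ... train_model / eval_model as in main.py ...
    scheduler.step()
```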
diff --git a/examples/pytorch/P-GNN/model.py b/examples/pytorch/P-GNN/model.py
new file mode 100644
index 000000000000..428a0b355d51
--- /dev/null
+++ b/examples/pytorch/P-GNN/model.py
@@ -0,0 +1,55 @@
+import torch
+import torch.nn as nn
+import dgl.function as fn
+import torch.nn.functional as F
+
+class PGNN_layer(nn.Module):
+    def __init__(self, input_dim, output_dim):
+        super(PGNN_layer, self).__init__()
+        self.input_dim = input_dim
+
+        self.linear_hidden_u = nn.Linear(input_dim, output_dim)
+        self.linear_hidden_v = nn.Linear(input_dim, output_dim)
+        self.linear_out_position = nn.Linear(output_dim, 1)
+        self.act = nn.ReLU()
+
+    def forward(self, graph, feature, anchor_eid, dists_max):
+        with graph.local_scope():
+            u_feat = self.linear_hidden_u(feature)
+            v_feat = self.linear_hidden_v(feature)
+            graph.srcdata.update({'u_feat': u_feat})
+            graph.dstdata.update({'v_feat': v_feat})
+
+            # Message on each anchor->node edge: sp_dist * W_u h_u + W_v h_v
+            graph.apply_edges(fn.u_mul_e('u_feat', 'sp_dist', 'u_message'))
+            graph.apply_edges(fn.v_add_e('v_feat', 'u_message', 'message'))
+
+            messages = torch.index_select(graph.edata['message'], 0,
+                                          torch.LongTensor(anchor_eid).to(feature.device))
+            messages = messages.reshape(dists_max.shape[0], dists_max.shape[1],
+                                        messages.shape[-1])
+
+            messages = self.act(messages)  # n * m * d
+
+            out_position = self.linear_out_position(messages).squeeze(-1)  # n * m
+            out_structure = torch.mean(messages, dim=1)  # n * d
+
+            return out_position, out_structure
+
+class PGNN(nn.Module):
+    def __init__(self, input_dim, feature_dim=32, dropout=0.5):
+        super(PGNN, self).__init__()
+        self.dropout = nn.Dropout(dropout)
+
+        self.linear_pre = nn.Linear(input_dim, feature_dim)
+        self.conv_first = PGNN_layer(feature_dim, feature_dim)
+        self.conv_out = PGNN_layer(feature_dim, feature_dim)
+
+    def forward(self, data):
+        x = data['graph'].ndata['feat']
+        graph = data['graph']
+        x = self.linear_pre(x)
+        x_position, x = self.conv_first(graph, x, data['anchor_eid'], data['dists_max'])
+
+        x = self.dropout(x)
+        # Only the position-aware output of the final layer is returned
+        x_position, x = self.conv_out(graph, x, data['anchor_eid'], data['dists_max'])
+        x_position = F.normalize(x_position, p=2, dim=-1)
+        return x_position
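The two `apply_edges` calls in `PGNN_layer` compute, on every anchor-to-node edge carrying a reciprocal distance `s`, the message `s * W_u h_u + W_v h_v`. A tiny standalone check of that arithmetic (toy numbers, independent of this example's data; it relies on DGL broadcasting the scalar edge weight over the feature dimension, just as the pipeline above does):

```python
import dgl
import dgl.function as fn
import torch

# Two anchors (nodes 0 and 1), each with one edge into node 2
g = dgl.graph(([0, 1], [2, 2]))
g.ndata['u_feat'] = torch.tensor([[1.0], [2.0], [0.0]])  # projected anchor features
g.ndata['v_feat'] = torch.tensor([[0.0], [0.0], [5.0]])  # projected node features
g.edata['sp_dist'] = torch.tensor([0.5, 1.0])            # reciprocal shortest-path distances

g.apply_edges(fn.u_mul_e('u_feat', 'sp_dist', 'u_message'))
g.apply_edges(fn.v_add_e('v_feat', 'u_message', 'message'))
print(g.edata['message'])  # tensor([[5.5000], [7.0000]]) = [0.5*1+5, 1.0*2+5]
```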
diff --git a/examples/pytorch/P-GNN/utils.py b/examples/pytorch/P-GNN/utils.py
new file mode 100644
index 000000000000..2074fff607c0
--- /dev/null
+++ b/examples/pytorch/P-GNN/utils.py
@@ -0,0 +1,264 @@
+import torch
+import random
+import numpy as np
+import networkx as nx
+from tqdm.auto import tqdm
+import multiprocessing as mp
+from multiprocessing import get_context
+
+def get_communities(remove_feature):
+    community_size = 20
+
+    # Create 20 cliques (communities) of size 20,
+    # then rewire a single edge in each clique to a node in an adjacent clique
+    graph = nx.connected_caveman_graph(20, community_size)
+
+    # Randomly rewire 1% of the edges; snapshot the edge list first,
+    # since the graph is modified during iteration
+    node_list = list(graph.nodes)
+    for u, v in list(graph.edges()):
+        if random.random() < 0.01:
+            x = random.choice(node_list)
+            if graph.has_edge(u, x):
+                continue
+            graph.remove_edge(u, v)
+            graph.add_edge(u, x)
+
+    # Remove self-loops
+    graph.remove_edges_from(nx.selfloop_edges(graph))
+    edge_index = np.array(list(graph.edges))
+    # Add (i, j) for every edge (j, i)
+    edge_index = np.concatenate((edge_index, edge_index[:, ::-1]), axis=0)
+    edge_index = torch.from_numpy(edge_index).long().permute(1, 0)
+
+    n = graph.number_of_nodes()
+    label = np.zeros((n, n), dtype=int)
+    for u in node_list:
+        # The node IDs are simply consecutive integers from 0
+        for v in range(u):
+            if u // community_size == v // community_size:
+                label[u, v] = 1
+
+    if remove_feature:
+        feature = torch.ones((n, 1))
+    else:
+        rand_order = np.random.permutation(n)
+        feature = np.identity(n)[:, rand_order]
+
+    data = {
+        'edge_index': edge_index,
+        'feature': feature,
+        'positive_edges': np.stack(np.nonzero(label)),
+        'num_nodes': feature.shape[0]
+    }
+
+    return data
+
+def to_single_directed(edges):
+    edges_new = np.zeros((2, edges.shape[1] // 2), dtype=int)
+    j = 0
+    for i in range(edges.shape[1]):
+        if edges[0, i] < edges[1, i]:
+            edges_new[:, j] = edges[:, i]
+            j += 1
+
+    return edges_new
+
+# Every node is expected to remain in the graph after the split
+def split_edges(p, edges, data, non_train_ratio=0.2):
+    e = edges.shape[1]
+    edges = edges[:, np.random.permutation(e)]
+    split1 = int((1 - non_train_ratio) * e)
+    split2 = int((1 - non_train_ratio / 2) * e)
+
+    data.update({
+        '{}_edges_train'.format(p): edges[:, :split1],  # 80%
+        '{}_edges_val'.format(p): edges[:, split1:split2],  # 10%
+        '{}_edges_test'.format(p): edges[:, split2:]  # 10%
+    })
+
+def to_bidirected(edges):
+    return np.concatenate((edges, edges[::-1, :]), axis=-1)
+
+def get_negative_edges(positive_edges, num_nodes, num_negative_edges):
+    positive_edge_set = []
+    positive_edges = to_bidirected(positive_edges)
+    for i in range(positive_edges.shape[1]):
+        positive_edge_set.append(tuple(positive_edges[:, i]))
+    positive_edge_set = set(positive_edge_set)
+
+    # Rejection-sample node pairs that do not form a positive edge
+    negative_edges = np.zeros((2, num_negative_edges), dtype=positive_edges.dtype)
+    for i in range(num_negative_edges):
+        while True:
+            mask_temp = tuple(np.random.choice(num_nodes, size=(2,), replace=False))
+            if mask_temp not in positive_edge_set:
+                negative_edges[:, i] = mask_temp
+                break
+
+    return negative_edges
+
+def get_pos_neg_edges(data, infer_link_positive=True):
+    if infer_link_positive:
+        data['positive_edges'] = to_single_directed(data['edge_index'].numpy())
+    split_edges('positive', data['positive_edges'], data)
+
+    # Sample one negative edge per positive edge
+    negative_edges = get_negative_edges(data['positive_edges'], data['num_nodes'],
+                                        num_negative_edges=data['positive_edges'].shape[1])
+    split_edges('negative', negative_edges, data)
+
+    return data
+
+def shortest_path(graph, node_range, cutoff):
+    dists_dict = {}
+    for node in tqdm(node_range, leave=False):
+        dists_dict[node] = nx.single_source_shortest_path_length(graph, node, cutoff)
+    return dists_dict
+
+def merge_dicts(dicts):
+    result = {}
+    for dictionary in dicts:
+        result.update(dictionary)
+    return result
+
+def all_pairs_shortest_path(graph, cutoff=None, num_workers=4):
+    nodes = list(graph.nodes)
+    random.shuffle(nodes)
+    pool = mp.Pool(processes=num_workers)
+    interval_size = len(nodes) / num_workers
+    results = [pool.apply_async(shortest_path, args=(
+        graph, nodes[int(interval_size * i): int(interval_size * (i + 1))], cutoff))
+        for i in range(num_workers)]
+    output = [p.get() for p in results]
+    dists_dict = merge_dicts(output)
+    pool.close()
+    pool.join()
+    return dists_dict
+
+def precompute_dist_data(edge_index, num_nodes, approximate=0):
+    """Compute reciprocal shortest-path distances between all node pairs.
+
+    Each entry is 1 / (shortest_path_length + 1), so a higher value means
+    closer; 0 means disconnected (or beyond the k-hop cutoff).
+    """
+    graph = nx.Graph()
+    edge_list = edge_index.transpose(1, 0).tolist()
+    graph.add_edges_from(edge_list)
+
+    n = num_nodes
+    dists_array = np.zeros((n, n))
+    dists_dict = all_pairs_shortest_path(graph, cutoff=approximate if approximate > 0 else None)
+    node_list = graph.nodes()
+    for node_i in node_list:
+        shortest_dist = dists_dict[node_i]
+        for node_j in node_list:
+            dist = shortest_dist.get(node_j, -1)
+            if dist != -1:
+                dists_array[node_i, node_j] = 1 / (dist + 1)
+    return dists_array
+
+def get_dataset(args):
+    # Generate graph data
+    data_info = get_communities(args.inductive)
+    # Get positive and negative edges
+    data = get_pos_neg_edges(data_info, infer_link_positive=True if args.task == 'link' else False)
+    # Pre-compute reciprocal shortest-path distances
+    if args.task == 'link':
+        # For link prediction, distances are computed on the training graph only
+        dists_removed = precompute_dist_data(data['positive_edges_train'], data['num_nodes'],
+                                             approximate=args.k_hop_dist)
+        data['dists'] = torch.from_numpy(dists_removed).float()
+        data['edge_index'] = torch.from_numpy(to_bidirected(data['positive_edges_train'])).long()
+    else:
+        dists = precompute_dist_data(data['edge_index'].numpy(), data['num_nodes'],
+                                     approximate=args.k_hop_dist)
+        data['dists'] = torch.from_numpy(dists).float()
+
+    return data
+
+def get_anchors(n):
+    """Get a list of NumPy arrays, each of which is an anchor node set."""
+    m = int(np.log2(n))
+    anchor_set_id = []
+    for i in range(m):
+        anchor_size = int(n / np.exp2(i + 1))
+        for _ in range(m):
+            anchor_set_id.append(np.random.choice(n, size=anchor_size, replace=False))
+    return anchor_set_id
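+
+# For intuition: on the 400-node communities graph (20 cliques of 20 nodes),
+# m = int(np.log2(400)) = 8, so get_anchors(400) samples 8 copies each of
+# anchor sets of size int(400 / 2 ** (i + 1)), 64 sets in total. A quick,
+# hypothetical check (not used by the pipeline):
+#
+#     sets = get_anchors(400)
+#     assert len(sets) == 64
+#     assert sorted({s.shape[0] for s in sets}) == [1, 3, 6, 12, 25, 50, 100, 200]
+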
+def get_dist_max(anchor_set_id, dist):
+    # N x K, where N is the number of nodes and K is the number of anchor sets
+    dist_max = torch.zeros((dist.shape[0], len(anchor_set_id)))
+    dist_argmax = torch.zeros((dist.shape[0], len(anchor_set_id))).long()
+    for i in range(len(anchor_set_id)):
+        temp_id = torch.as_tensor(anchor_set_id[i], dtype=torch.long)
+        # Get the reciprocal of the shortest distance to each node in the i-th anchor set
+        dist_temp = torch.index_select(dist, 1, temp_id)
+        # For each node in the graph, find its closest anchor node in the set
+        # and the reciprocal of the shortest distance
+        dist_max_temp, dist_argmax_temp = torch.max(dist_temp, dim=-1)
+        dist_max[:, i] = dist_max_temp
+        dist_argmax[:, i] = torch.index_select(temp_id, 0, dist_argmax_temp)
+    return dist_max, dist_argmax
+
+def get_a_graph(dists_max, dists_argmax):
+    src = []
+    dst = []
+    real_src = []
+    real_dst = []
+    edge_weight = []
+    dists_max = dists_max.numpy()
+    for i in range(dists_max.shape[0]):
+        # Get the unique closest anchor nodes for node i across all anchor sets
+        tmp_dists_argmax, tmp_dists_argmax_idx = np.unique(dists_argmax[i, :], True)
+        src.extend([i] * tmp_dists_argmax.shape[0])
+        real_src.extend([i] * dists_argmax[i, :].shape[0])
+        real_dst.extend(list(dists_argmax[i, :].numpy()))
+        dst.extend(list(tmp_dists_argmax))
+        edge_weight.extend(dists_max[i, tmp_dists_argmax_idx].tolist())
+    # Map each (anchor, node) pair to its edge ID in the deduplicated graph
+    eid_dict = {(u, v): i for i, (u, v) in enumerate(list(zip(dst, src)))}
+    anchor_eid = [eid_dict.get((u, v)) for u, v in zip(real_dst, real_src)]
+    g = (dst, src)
+    return g, anchor_eid, edge_weight
+
+def get_graphs(data, anchor_sets):
+    graphs = []
+    anchor_eids = []
+    dists_max_list = []
+    edge_weights = []
+    for anchor_set in tqdm(anchor_sets, leave=False):
+        dists_max, dists_argmax = get_dist_max(anchor_set, data['dists'])
+        g, anchor_eid, edge_weight = get_a_graph(dists_max, dists_argmax)
+        graphs.append(g)
+        anchor_eids.append(anchor_eid)
+        dists_max_list.append(dists_max)
+        edge_weights.append(edge_weight)
+
+    return graphs, anchor_eids, dists_max_list, edge_weights
+
+def merge_result(outputs):
+    graphs = []
+    anchor_eids = []
+    dists_max_list = []
+    edge_weights = []
+
+    for g, anchor_eid, dists_max, edge_weight in outputs:
+        graphs.extend(g)
+        anchor_eids.extend(anchor_eid)
+        dists_max_list.extend(dists_max)
+        edge_weights.extend(edge_weight)
+
+    return graphs, anchor_eids, dists_max_list, edge_weights
+
+def preselect_anchor(data, args, num_workers=4):
+    pool = get_context('spawn').Pool(processes=num_workers)
+    # Pre-compute anchor sets, a collection of anchor sets per epoch
+    anchor_set_ids = [get_anchors(data['num_nodes']) for _ in range(args.epoch_num)]
+    interval_size = len(anchor_set_ids) / num_workers
+    results = [pool.apply_async(get_graphs, args=(
+        data, anchor_set_ids[int(interval_size * i):int(interval_size * (i + 1))],))
+        for i in range(num_workers)]
+
+    output = [p.get() for p in results]
+    graphs, anchor_eids, dists_max_list, edge_weights = merge_result(output)
+    pool.close()
+    pool.join()
+
+    return graphs, anchor_eids, dists_max_list, edge_weights
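To make the per-epoch graph construction concrete, here is a small hand-worked run of `get_dist_max` and `get_a_graph` with a toy reciprocal-distance matrix (hypothetical usage, not part of the shipped scripts):

```python
import numpy as np
import torch

from utils import get_a_graph, get_dist_max

# Reciprocal-distance matrix for 3 nodes, and two anchor sets {0} and {1, 2}
dists = torch.tensor([[1.00, 0.50, 0.33],
                      [0.50, 1.00, 0.50],
                      [0.33, 0.50, 1.00]])
anchor_sets = [np.array([0]), np.array([1, 2])]

# dists_max[i, k]: node i's reciprocal distance to its closest anchor in set k
dists_max, dists_argmax = get_dist_max(anchor_sets, dists)

g, anchor_eid, edge_weight = get_a_graph(dists_max, dists_argmax)
dst, src = g  # edges run from each node's distinct closest anchors (dst) to the node (src)
print(list(zip(dst, src)))  # [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (2, 2)]
print(edge_weight)          # [1.0, 0.5, 0.5, 1.0, 0.33..., 1.0]
```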