Skip to content

Commit

Permalink
[Example]Add P-GNN example (dmlc#3823)
Browse files Browse the repository at this point in the history
* [Model]P-GNN

* updata

* [Example]P-GNN

* Update README.md

Co-authored-by: Mufei Li <mufeili1996@gmail.com>
  • Loading branch information
RecLusIve-F and mufeili committed Mar 10, 2022
1 parent eec219a commit f908f35
Show file tree
Hide file tree
Showing 5 changed files with 526 additions and 0 deletions.
3 changes: 3 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ To quickly locate the examples of your interest, search for the tagged keywords
- <a name='geniepath'></a> Liu Z, et al. Geniepath: Graph neural networks with adaptive receptive paths. [Paper link](https://arxiv.org/abs/1802.00910).
- Example code: [PyTorch](../examples/pytorch/geniepath)
- Tags: Fraud detection, Node classification, Graph attention, LSTM, Adaptive receptive fields
- <a name='pgnn'></a> You J, et al. Position-aware graph neural networks. [Paper link](https://arxiv.org/abs/1906.04817).
- Example code: [PyTorch](../examples/pytorch/P-GNN)
- Tags: Positional encoding, Link prediction, Link-pair prediction

## 2018

Expand Down
57 changes: 57 additions & 0 deletions examples/pytorch/P-GNN/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# DGL Implementation of P-GNN

This DGL example implements the GNN model proposed in the paper [Position-aware Graph Neural Networks](http://proceedings.mlr.press/v97/you19b/you19b.pdf). For the original implementation, see [here](https://github.com/JiaxuanYou/P-GNN).

Contributor: [RecLusIve-F](https://github.com/RecLusIve-F)

## Requirements

The codebase is implemented in Python 3.8. The required package versions are listed below.

```
dgl 0.7.2
numpy 1.21.2
torch 1.10.1
networkx 2.6.3
scikit-learn 1.0.2
```

## Instructions to download datasets:

1. Download the datasets from [here](https://github.com/RecLusIve-F/P-GNN-dgl/blob/master/data.zip?raw=true).
2. Extract the zip archive into this directory.

## Instructions for experiments

### Link prediction

```bash
# Communities-T
python main.py --task link

# Communities
python main.py --task link --inductive
```

### Link pair prediction

```bash
# Communities
python main.py --task link_pair --inductive
```

## Performance

### Link prediction (Grid-T and Communities-T refer to the transductive learning setting of Grid and Communities)

| Dataset | Communities-T | Communities |
| :------------------------------: | :-----------: | :-----------: |
| ROC AUC ( P-GNN-E-2L in Table 1) | 0.988 ± 0.003 | 0.985 ± 0.008 |
| ROC AUC (DGL: P-GNN-E-2L) | 0.984 ± 0.010 | 0.991 ± 0.004 |

### Link pair prediction

| Dataset | Communities |
| :------------------------------: | :---------: |
| ROC AUC ( P-GNN-E-2L in Table 1) | 1.0 ± 0.001 |
| ROC AUC (DGL: P-GNN-E-2L) | 1.0 ± 0.000 |
147 changes: 147 additions & 0 deletions examples/pytorch/P-GNN/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import os
import dgl
import torch
import numpy as np
import torch.nn as nn
from model import PGNN
from sklearn.metrics import roc_auc_score
from utils import get_dataset, preselect_anchor

import warnings
# Silence every warning category for the whole process so the per-epoch
# training log is not interleaved with library warnings.
# NOTE(review): this also hides deprecation notices — narrow the filter when debugging.
warnings.filterwarnings('ignore')

def get_loss(p, data, out, loss_func, device, get_auc=True):
    """Score the positive and negative edges of one split and compute the loss.

    Each candidate edge is scored by the dot product of the embeddings of its
    two end nodes; positive edges are labelled 1 and negative edges 0.

    Args:
        p: Split name used to build the dictionary keys
            ``'positive_edges_<p>'`` and ``'negative_edges_<p>'``.
        data: Mapping holding the edge arrays for the split. Each array is
            indexed as ``[0, :]`` (first end nodes) and ``[1, :]`` (second
            end nodes).
        out: Node embedding tensor, indexed along dimension 0 by node id.
        loss_func: Callable taking ``(pred, label)`` and returning the loss.
        device: Kept for backward compatibility. Labels are now allocated on
            the device of ``out`` so that predictions and labels can never
            end up on different devices.
        get_auc: When true, also compute the ROC AUC of the predictions.

    Returns:
        ``loss`` when ``get_auc`` is false, otherwise ``(loss, auc)``.
    """
    positive_edges = data['positive_edges_{}'.format(p)]
    negative_edges = data['negative_edges_{}'.format(p)]
    # Positive edges come first so the labels built below line up with them.
    edge_mask = np.concatenate((positive_edges, negative_edges), axis=-1)

    first_index = torch.from_numpy(edge_mask[0, :]).long().to(out.device)
    second_index = torch.from_numpy(edge_mask[1, :]).long().to(out.device)
    nodes_first = torch.index_select(out, 0, first_index)
    nodes_second = torch.index_select(out, 0, second_index)

    # Raw (pre-sigmoid) score of every candidate edge.
    pred = torch.sum(nodes_first * nodes_second, dim=-1)

    # new_ones / new_zeros inherit dtype and device from pred, which avoids a
    # CPU allocation followed by a copy and rules out a device mismatch.
    label = torch.cat((pred.new_ones(positive_edges.shape[1]),
                       pred.new_zeros(negative_edges.shape[1])))
    loss = loss_func(pred, label)

    if not get_auc:
        return loss

    # detach() drops the autograd history before handing the scores to numpy.
    auc = roc_auc_score(label.flatten().cpu().numpy(),
                        torch.sigmoid(pred).detach().flatten().cpu().numpy())
    return loss, auc

def train_model(data, model, loss_func, optimizer, device, g_data):
    """Run a single optimisation step on the training edges.

    The model is switched to training mode, the training loss is computed
    from its output on ``g_data``, and one optimizer step is taken. Gradients
    are cleared both before the backward pass and after the step.

    Returns:
        The ``g_data`` argument, unchanged.
    """
    model.train()
    embeddings = model(g_data)

    train_loss = get_loss(
        'train', data, embeddings, loss_func, device, get_auc=False)

    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
    # Leave no stale gradients behind for whatever runs next.
    optimizer.zero_grad()

    return g_data

def eval_model(data, g_data, model, loss_func, device):
    """Evaluate the model on the train, validation and test edge splits.

    The model is put in evaluation mode and run once on ``g_data``; the same
    output is then scored against each split. The whole evaluation runs under
    ``torch.no_grad()`` because no backward pass follows, so building the
    autograd graph would only cost memory and time.

    Returns:
        ``(loss_train, auc_train, auc_val, auc_test)`` where ``loss_train`` is
        the training loss converted to a NumPy value.
    """
    model.eval()
    with torch.no_grad():
        out = model(g_data)

        # train loss and auc
        tmp_loss, auc_train = get_loss('train', data, out, loss_func, device)
        loss_train = tmp_loss.detach().cpu().numpy()

        # val auc (the loss is not reported)
        _, auc_val = get_loss('val', data, out, loss_func, device)

        # test auc (the loss is not reported)
        _, auc_test = get_loss('test', data, out, loss_func, device)

    return loss_train, auc_train, auc_val, auc_test

def main(args):
    """Train and evaluate P-GNN for several repeats and report the test AUC.

    For every repeat the dataset is loaded, anchors are pre-selected, a fresh
    model is trained for ``args.epoch_num`` epochs, and the test AUC at the
    epoch with the best validation AUC is recorded. The mean and standard
    deviation over all repeats are printed and written to
    ``results/<learning type>_<task>_<k_hop_dist>.txt``.

    Args:
        args: Parsed command-line namespace. This function reads
            ``inductive``, ``task``, ``repeat_num``, ``epoch_num``,
            ``epoch_log`` and ``k_hop_dist``; the whole namespace is also
            forwarded to ``get_dataset`` and ``preselect_anchor``.
    """
    # The mean and standard deviation of the experiment results
    # are stored in the 'results' folder. exist_ok avoids the
    # check-then-create race when several runs start at the same time.
    os.makedirs('results', exist_ok=True)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # args.inductive is a bool, so it indexes the list as 0 or 1.
    learning_type = ['Transductive', 'Inductive'][args.inductive]
    print('Learning Type: {}'.format(learning_type),
          'Task: {}'.format(args.task))

    results = []

    for repeat in range(args.repeat_num):
        data = get_dataset(args)

        # pre-sample anchor nodes and compute shortest distance values for all epochs
        g_list, anchor_eid_list, dist_max_list, edge_weight_list = preselect_anchor(data, args)

        # model
        model = PGNN(input_dim=data['feature'].shape[1]).to(device)

        # loss
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
        loss_func = nn.BCEWithLogitsLoss()

        # Node features are the same in every epoch, so convert them once per
        # repeat instead of once per epoch.
        node_feat = torch.FloatTensor(data['feature'])

        best_auc_val = -1
        best_auc_test = -1

        for epoch in range(args.epoch_num):
            # One-off learning-rate decay by a factor of 10 at epoch 200.
            if epoch == 200:
                for param_group in optimizer.param_groups:
                    param_group['lr'] /= 10

            # Each epoch uses its own pre-sampled anchor graph.
            g = dgl.graph(g_list[epoch])
            g.ndata['feat'] = node_feat
            g.edata['sp_dist'] = torch.FloatTensor(edge_weight_list[epoch])
            g_data = {
                'graph': g.to(device),
                'anchor_eid': anchor_eid_list[epoch],
                'dists_max': dist_max_list[epoch]
            }

            train_model(data, model, loss_func, optimizer, device, g_data)

            loss_train, auc_train, auc_val, auc_test = eval_model(
                data, g_data, model, loss_func, device)
            # Model selection: keep the test AUC of the best validation epoch.
            if auc_val > best_auc_val:
                best_auc_val = auc_val
                best_auc_test = auc_test

            if epoch % args.epoch_log == 0:
                print(repeat, epoch, 'Loss {:.4f}'.format(loss_train), 'Train AUC: {:.4f}'.format(auc_train),
                      'Val AUC: {:.4f}'.format(auc_val), 'Test AUC: {:.4f}'.format(auc_test),
                      'Best Val AUC: {:.4f}'.format(best_auc_val), 'Best Test AUC: {:.4f}'.format(best_auc_test))

        results.append(best_auc_test)

    results = np.array(results)
    results_mean = np.mean(results).round(6)
    results_std = np.std(results).round(6)
    print('-----------------Final-------------------')
    print(results_mean, results_std)

    result_path = 'results/{}_{}_{}.txt'.format(learning_type, args.task, args.k_hop_dist)
    with open(result_path, 'w') as f:
        f.write('{}, {}\n'.format(results_mean, results_std))

if __name__ == '__main__':
    from argparse import ArgumentParser

    # Command-line interface: experiment selection first, then run-length knobs.
    cli = ArgumentParser()

    cli.add_argument(
        '--task', type=str, default='link', choices=['link', 'link_pair'])
    cli.add_argument(
        '--inductive', action='store_true',
        help='Inductive learning or transductive learning')
    cli.add_argument(
        '--k_hop_dist', default=-1, type=int,
        help='K-hop shortest path distance, -1 means exact shortest path.')

    cli.add_argument('--epoch_num', type=int, default=2000)
    cli.add_argument('--repeat_num', type=int, default=10)
    cli.add_argument('--epoch_log', type=int, default=100)

    args = cli.parse_args()
    main(args)
55 changes: 55 additions & 0 deletions examples/pytorch/P-GNN/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch
import torch.nn as nn
import dgl.function as fn
import torch.nn.functional as F

class PGNN_layer(nn.Module):
    """One P-GNN layer producing position-aware and structure-aware outputs.

    Source and destination node features are projected by separate linear
    maps. Every edge message is the source projection scaled by the edge's
    ``'sp_dist'`` value plus the destination projection.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim

        self.linear_hidden_u = nn.Linear(input_dim, output_dim)
        self.linear_hidden_v = nn.Linear(input_dim, output_dim)
        self.linear_out_position = nn.Linear(output_dim, 1)
        self.act = nn.ReLU()

    def forward(self, graph, feature, anchor_eid, dists_max):
        """Compute the layer outputs for ``graph``.

        Args:
            graph: DGL graph carrying an ``'sp_dist'`` edge feature.
            feature: Node feature tensor fed to both linear projections.
            anchor_eid: Edge ids selecting the messages to keep; converted
                to a ``LongTensor`` on the device of ``feature``.
            dists_max: Only its first two dimensions are read; they define
                the ``(n, m)`` layout the selected messages are reshaped to.

        Returns:
            ``(out_position, out_structure)``: the per-anchor scalar output of
            shape ``(n, m)`` and the anchor-averaged message of shape ``(n, d)``.
        """
        # local_scope keeps the temporary node/edge fields out of the caller's graph.
        with graph.local_scope():
            graph.srcdata.update({'u_feat': self.linear_hidden_u(feature)})
            graph.dstdata.update({'v_feat': self.linear_hidden_v(feature)})

            # message = u_feat * sp_dist + v_feat, evaluated edge by edge.
            graph.apply_edges(fn.u_mul_e('u_feat', 'sp_dist', 'u_message'))
            graph.apply_edges(fn.v_add_e('v_feat', 'u_message', 'message'))

            edge_index = torch.LongTensor(anchor_eid).to(feature.device)
            selected = torch.index_select(graph.edata['message'], 0, edge_index)

            num_nodes, num_anchors = dists_max.shape[0], dists_max.shape[1]
            selected = selected.reshape(num_nodes, num_anchors, selected.shape[-1])

            activated = self.act(selected)  # n*m*d

            out_position = self.linear_out_position(activated).squeeze(-1)  # n*m_out
            out_structure = torch.mean(activated, dim=1)  # n*d

            return out_position, out_structure

class PGNN(nn.Module):
    """Two-layer P-GNN returning L2-normalised position embeddings.

    A linear projection lifts the input features to ``feature_dim``; two
    ``PGNN_layer`` modules follow, with dropout applied between them.
    """

    def __init__(self, input_dim, feature_dim=32, dropout=0.5):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        self.linear_pre = nn.Linear(input_dim, feature_dim)
        self.conv_first = PGNN_layer(feature_dim, feature_dim)
        self.conv_out = PGNN_layer(feature_dim, feature_dim)

    def forward(self, data):
        """Embed the nodes of ``data['graph']``.

        Args:
            data: Mapping with the keys ``'graph'`` (a graph whose node data
                holds ``'feat'``), ``'anchor_eid'`` and ``'dists_max'``; the
                last two are passed unchanged to both layers.

        Returns:
            The position output of the second layer, L2-normalised along the
            last dimension.
        """
        graph = data['graph']
        anchor_eid = data['anchor_eid']
        dists_max = data['dists_max']

        hidden = self.linear_pre(graph.ndata['feat'])
        # Only the structure output of the first layer feeds the next layer.
        _, hidden = self.conv_first(graph, hidden, anchor_eid, dists_max)

        hidden = self.dropout(hidden)
        position, _ = self.conv_out(graph, hidden, anchor_eid, dists_max)
        return F.normalize(position, p=2, dim=-1)
Loading

0 comments on commit f908f35

Please sign in to comment.