Support sparse version GAT #8

Merged · merged 4 commits on Sep 19, 2018
README.md: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ A small note about initial sparse matrix operations of https://github.com/tkipf/

# Requirements

-pyGAT relies on Python 3.5 and PyTorch 0.4 (due to torch.where).
+pyGAT relies on Python 3.5 and PyTorch 0.4.1 (due to torch.sparse_coo_tensor).

# Issues/Pull Requests/Feedbacks

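For context on the README change: `torch.sparse_coo_tensor` builds a sparse COO matrix from a 2 x E tensor of indices and an E-element tensor of values, which is exactly how the new sparse layer assembles its attention matrix. A minimal sketch of the call (our own illustration, not part of the diff):

```python
import torch

# Three directed edges on a 3-node graph: 0 -> 1, 1 -> 2, 2 -> 0.
edge = torch.tensor([[0, 1, 2],    # row (source) indices
                     [1, 2, 0]])   # column (target) indices
values = torch.tensor([0.5, 0.3, 0.2])

# Assemble an N x N sparse matrix in COO format, as the sparse layer does.
e = torch.sparse_coo_tensor(edge, values, torch.Size([3, 3]))
print(e.to_dense())
# tensor([[0.0000, 0.5000, 0.0000],
#         [0.0000, 0.0000, 0.3000],
#         [0.2000, 0.0000, 0.0000]])
```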
layers.py: 66 additions & 2 deletions
@@ -17,8 +17,10 @@ def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        self.alpha = alpha
        self.concat = concat

-        self.W = nn.Parameter(nn.init.xavier_uniform(torch.Tensor(in_features, out_features).type(torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor), gain=np.sqrt(2.0)), requires_grad=True)
-        self.a = nn.Parameter(nn.init.xavier_uniform(torch.Tensor(2*out_features, 1).type(torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor), gain=np.sqrt(2.0)), requires_grad=True)
+        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
+        nn.init.xavier_uniform_(self.W.data, gain=1.414)
+        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
+        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

@@ -42,3 +44,65 @@ def forward(self, input, adj):

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


+class SpGraphAttentionLayer(nn.Module):
+    """
+    Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903
+    """
+
+    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
+        super(SpGraphAttentionLayer, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.alpha = alpha
+        self.concat = concat
+
+        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
+        nn.init.xavier_normal_(self.W.data, gain=1.414)
+
+        self.a = nn.Parameter(torch.zeros(size=(1, 2*out_features)))
+        nn.init.xavier_normal_(self.a.data, gain=1.414)
+
+        self.dropout = nn.Dropout(dropout)
+        self.leakyrelu = nn.LeakyReLU(self.alpha)
+
+    def forward(self, input, adj):
+        N = input.size()[0]
+        edge = adj.nonzero().t()
+
+        h = torch.mm(input, self.W)
+        # h: N x out
+        assert not torch.isnan(h).any()
+
+        # Self-attention on the nodes - a shared attention mechanism
+        edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t()
+        # edge_h: 2*out x E
+        edge_e = torch.exp(-self.leakyrelu(self.a.mm(edge_h).squeeze()))
+        assert not torch.isnan(edge_e).any()
+        # edge_e: E
+        e = torch.sparse_coo_tensor(edge, edge_e, torch.Size([N, N]))
+        # e: N x N sparse
+        e_rowsum = torch.matmul(e, torch.ones(size=(N, 1), device=h.device))
+        # e_rowsum: N x 1; the ones vector is built on h's device so the layer also runs on CPU
+
+        edge_e = self.dropout(edge_e)
+        # edge_e: E
+        e = torch.sparse_coo_tensor(edge, edge_e, torch.Size([N, N]))
+        # e: N x N sparse
+        h_prime = torch.matmul(e, h)
+        assert not torch.isnan(h_prime).any()
+
+        h_prime = h_prime.div(e_rowsum)
+        # h_prime: N x out
+        assert not torch.isnan(h_prime).any()
+
+        if self.concat:
+            # if this layer is not the last layer,
+            return F.elu(h_prime)
+        else:
+            # if this layer is the last layer,
+            return h_prime
+
+    def __repr__(self):
+        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
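The forward pass above never materializes a dense N x N attention matrix: it exponentiates the (negated) LeakyReLU scores only on existing edges, scatters them into a sparse matrix, and row-normalizes via a product with a ones vector. A small sanity sketch (our own, not from the PR) showing that exp-then-row-normalize over the edges reproduces a dense masked softmax when both use the same scores:

```python
import torch

N = 4
adj = torch.tensor([[1., 1., 0., 0.],
                    [1., 1., 1., 0.],
                    [0., 1., 1., 1.],
                    [0., 0., 1., 1.]])
scores = torch.randn(N, N)  # stand-in for the attention logits

# Dense reference: softmax restricted to each node's neighbourhood.
masked = torch.where(adj > 0, scores, torch.full_like(scores, -9e15))
dense_attn = torch.softmax(masked, dim=1)

# Sparse-style computation: exp on the edges only, then divide by row sums.
edge = adj.nonzero().t()                      # 2 x E edge index
edge_e = torch.exp(scores[edge[0], edge[1]])  # E per-edge weights
e = torch.sparse_coo_tensor(edge, edge_e, torch.Size([N, N])).to_dense()
e_rowsum = e.sum(dim=1, keepdim=True)         # N x 1
sparse_attn = e / e_rowsum

print(torch.allclose(dense_attn, sparse_attn, atol=1e-6))  # True
```

Note that `SpGraphAttentionLayer` feeds `-self.leakyrelu(...)` into the `exp`, so the weights it normalizes are those of the negated logits; the normalization mechanics are identical.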
models.py: 29 additions & 2 deletions
@@ -1,11 +1,12 @@
-import torch.nn as nn
+import torch
+import torch.nn as nn
import torch.nn.functional as F
-from layers import GraphAttentionLayer
+from layers import GraphAttentionLayer, SpGraphAttentionLayer


class GAT(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        """Dense version of GAT."""
        super(GAT, self).__init__()
        self.dropout = dropout

@@ -23,4 +24,30 @@ def forward(self, x, adj):
        return F.log_softmax(x, dim=1)


+class SpGAT(nn.Module):
+    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
+        """Sparse version of GAT."""
+        super(SpGAT, self).__init__()
+        self.dropout = dropout
+
+        self.attentions = [SpGraphAttentionLayer(nfeat,
+                                                 nhid,
+                                                 dropout=dropout,
+                                                 alpha=alpha,
+                                                 concat=True) for _ in range(nheads)]
+        for i, attention in enumerate(self.attentions):
+            self.add_module('attention_{}'.format(i), attention)
+
+        self.out_att = SpGraphAttentionLayer(nhid * nheads,
+                                             nclass,
+                                             dropout=dropout,
+                                             alpha=alpha,
+                                             concat=False)
+
+    def forward(self, x, adj):
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = F.elu(self.out_att(x, adj))
+        return F.log_softmax(x, dim=1)
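A quick smoke test for the new class (our own sketch; the sizes are illustrative, not values from the PR). `SpGAT` takes the same constructor arguments and the same dense adjacency input as `GAT`, so the two are interchangeable at the call site:

```python
import torch
from models import SpGAT

N, nfeat, nhid, nclass, nheads = 10, 16, 8, 3, 4
x = torch.rand(N, nfeat)

# Random symmetric adjacency with self-loops so every row sum is positive.
adj = (torch.rand(N, N) > 0.7).float()
adj = ((adj + adj.t() + torch.eye(N)) > 0).float()

model = SpGAT(nfeat=nfeat, nhid=nhid, nclass=nclass,
              dropout=0.6, alpha=0.2, nheads=nheads)
out = model(x, adj)
print(out.shape)  # torch.Size([10, 3]), per-node log-probabilities
```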

requirements.txt: 8 additions & 3 deletions
@@ -1,3 +1,8 @@
-torch==0.4.0a0+1fdb392
-scipy==0.19.1
-numpy==1.14.0
+certifi==2018.8.24
+cffi==1.11.5
+mkl-fft==1.0.4
+mkl-random==1.0.1
+numpy==1.15.1
+pycparser==2.18
+scipy==1.1.0
+torch==0.4.1.post2
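Since the whole point of the version bump is `torch.sparse_coo_tensor`, a one-line check (our own suggestion, not part of the diff) can fail fast on an older install:

```python
import torch

# The sparse GAT path needs torch.sparse_coo_tensor (PyTorch 0.4.1+).
assert hasattr(torch, 'sparse_coo_tensor'), 'PyTorch >= 0.4.1 required'
print(torch.__version__)  # 0.4.1.post2 with the pins above
```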
train.py: 24 additions & 7 deletions
@@ -1,25 +1,27 @@
from __future__ import division
from __future__ import print_function

+import os
+import glob
import time
+import random
import argparse
import numpy as np
-import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
-import os
-import glob
from torch.autograd import Variable

from utils import load_data, accuracy
-from models import GAT
+from models import GAT, SpGAT

# Training settings
parser = argparse.ArgumentParser()
parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training.')
parser.add_argument('--fastmode', action='store_true', default=False, help='Validate during training pass.')
-parser.add_argument('--seed', type=int, default=42, help='Random seed.')
+parser.add_argument('--sparse', action='store_true', default=False, help='GAT with sparse version or not.')
+parser.add_argument('--seed', type=int, default=72, help='Random seed.')
parser.add_argument('--epochs', type=int, default=10000, help='Number of epochs to train.')
parser.add_argument('--lr', type=float, default=0.005, help='Initial learning rate.')
parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay (L2 loss on parameters).')
@@ -42,8 +44,23 @@
adj, features, labels, idx_train, idx_val, idx_test = load_data()

# Model and optimizer
-model = GAT(nfeat=features.shape[1], nhid=args.hidden, nclass=int(labels.max()) + 1, dropout=args.dropout, nheads=args.nb_heads, alpha=args.alpha)
-optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+if args.sparse:
+    model = SpGAT(nfeat=features.shape[1],
+                  nhid=args.hidden,
+                  nclass=int(labels.max()) + 1,
+                  dropout=args.dropout,
+                  nheads=args.nb_heads,
+                  alpha=args.alpha)
+else:
+    model = GAT(nfeat=features.shape[1],
+                nhid=args.hidden,
+                nclass=int(labels.max()) + 1,
+                dropout=args.dropout,
+                nheads=args.nb_heads,
+                alpha=args.alpha)
+optimizer = optim.Adam(model.parameters(),
+                       lr=args.lr,
+                       weight_decay=args.weight_decay)

if args.cuda:
    model.cuda()
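With these changes the sparse path is opt-in from the command line: `python train.py --sparse` builds `SpGAT`, while the default invocation keeps the dense `GAT`; everything downstream (the optimizer, the CUDA transfer, the training loop) is shared between the two.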