Support sparse version GAT #8

Merged · 4 commits · Sep 19, 2018

Conversation

@sh0416 (Contributor) commented Sep 14, 2018

I developed a sparse version of GAT using PyTorch 0.4.1, which improves training speed significantly.

It now runs at 0.07 sec per epoch on a Titan Xp.

@Diego999 (Owner)

Thank you for the pull request!

Everything looks fine. I'm just wondering if there is a reason to change the weights W and a?

@sh0416 (Contributor, Author) commented Sep 15, 2018

Actually, I don't know why you use Xavier initialization. I think normal initialization is simpler and more understandable to other users, but I can change it if you want.

FYI, softmax is not supported in sparse format, so I use an exponential, which can make the weights unstable. For example, std=0.1 doesn't work because NaNs occur during training.
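For reference, a minimal sketch (not the PR code) of how a per-edge exponential plus a scatter-add stands in for a row-wise softmax over a sparse adjacency; the node count, edge list, and variable names below are illustrative only:

```python
import torch

N = 4                                    # number of nodes (toy example)
edge = torch.tensor([[0, 0, 1, 2, 3],    # source node index of each edge
                     [1, 2, 2, 3, 0]])   # target node index of each edge
scores = torch.randn(edge.size(1))       # unnormalized attention score per edge

exp_scores = torch.exp(scores)                                 # softmax numerator, per edge
row_sum = torch.zeros(N).scatter_add_(0, edge[0], exp_scores)  # softmax denominator, per node
attention = exp_scores / row_sum[edge[0]]                      # normalized weight per edge
# With unlucky initialization the scores can be large enough that torch.exp
# overflows, and the division above produces NaN; that is why the std of the
# weight initialization matters here.
```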

@Diego999 (Owner)

Maybe normal initialization is simpler, but depending on the activation function it's not theoretically the best ;)

What happens with the Xavier weights and the exponential? Is that, in the end, the reason for using normal initialization with a small std?

@sh0416 (Contributor, Author) commented Sep 15, 2018

Yes, you're right, Xavier is more theoretically sound.

I tried the Xavier weights and they don't generate NaN during the exponential.

The only reason I used a small std was to avoid exponential overflow, no other reason.

I will add a commit for the previous point (the Xavier initialization).
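For context, a minimal sketch of what Xavier (Glorot) initialization of the layer parameters looks like in PyTorch 0.4; the dimensions below are placeholders, not the PR's actual values:

```python
import torch
import torch.nn as nn

in_features, out_features = 16, 8                         # placeholder dimensions

W = nn.Parameter(torch.zeros(in_features, out_features))  # linear transform weights
a = nn.Parameter(torch.zeros(1, 2 * out_features))        # attention vector
nn.init.xavier_uniform_(W.data)                           # Glorot/Xavier init, in place
nn.init.xavier_uniform_(a.data)
```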

@Diego999 (Owner)

Great, thank you!

One last question: what performance do you get?

@sh0416 (Contributor, Author) commented Sep 16, 2018

Training time per epoch: 0.07–0.09 sec

Test accuracy: 84% on average

Thanks for advising me :)

@sh0416 (Contributor, Author) commented Sep 16, 2018

I found that dropout is missing.

I will add it tomorrow; please wait until I implement the dropout feature.

@sh0416 (Contributor, Author) commented Sep 17, 2018

To sum up:

I ran 5 tests; the average test accuracy is 0.8376 with std 0.004827, and the best accuracy is 0.844.

Training time per epoch is 0.07–0.1 sec on a Titan Xp.

@Diego999 (Owner)

Thank you for all these changes!

Given your numbers, I think it's a good idea to have a sparse version and also keep the main one. So I propose adding another model, SpGAT, that uses SpGraphAttentionLayer instead. That way we keep both models: the "original" one and a memory-improved one with slight differences.

What do you think?
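Roughly the arrangement being proposed, sketched under the assumption that SpGAT simply mirrors GAT in models.py while swapping in SpGraphAttentionLayer (the merged code may differ in details):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from layers import SpGraphAttentionLayer  # sparse attention layer from this PR


class SpGAT(nn.Module):
    """Sparse-attention counterpart of GAT: same interface, sparse layers inside."""

    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        super(SpGAT, self).__init__()
        self.dropout = dropout
        # One sparse attention layer per head; head outputs are concatenated.
        self.attentions = [SpGraphAttentionLayer(nfeat, nhid, dropout=dropout,
                                                 alpha=alpha, concat=True)
                           for _ in range(nheads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)
        # Output attention layer mapping the concatenated heads to class scores.
        self.out_att = SpGraphAttentionLayer(nhid * nheads, nclass, dropout=dropout,
                                             alpha=alpha, concat=False)

    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))
        return F.log_softmax(x, dim=1)
```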

@sh0416 (Contributor, Author) commented Sep 17, 2018

Yeah, I also considered that approach.

I will change it ASAP. :)

Also, if you don't mind, I want to change your Xavier initialization to remove a warning. Is that OK with you?
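Presumably this refers to the PyTorch 0.4 deprecation warning for the non-in-place init functions; if so, the change is just the trailing underscore:

```python
import torch
import torch.nn as nn

w = torch.empty(8, 16)
# nn.init.xavier_uniform(w)   # old spelling, raises a UserWarning on PyTorch >= 0.4
nn.init.xavier_uniform_(w)    # in-place variant, no warning
```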

@sh0416 (Contributor, Author) commented Sep 17, 2018

OK, I think that's good. I added a --sparse flag to switch the sparse version on or off, which is fancy :)
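Something along these lines in train.py, sketched from the comment (the flag name comes from the comment; the surrounding variable names are placeholders for the dataset dimensions and hyperparameters):

```python
import argparse

from models import GAT, SpGAT  # dense and sparse variants

parser = argparse.ArgumentParser()
parser.add_argument('--sparse', action='store_true', default=False,
                    help='Use the sparse SpGAT model instead of the dense GAT.')
args = parser.parse_args()

# nfeat, nhid, nclass, dropout, alpha, nheads are defined elsewhere in the script.
Model = SpGAT if args.sparse else GAT
model = Model(nfeat=nfeat, nhid=nhid, nclass=nclass,
              dropout=dropout, alpha=alpha, nheads=nheads)
```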

@sh0416 (Contributor, Author) commented Sep 19, 2018

What are you waiting for?

@sh0416 (Contributor, Author) commented Oct 23, 2018

#11 does not occur in my environment. Maybe with some other dataset the denominator of the softmax equation leads to a NaN error.
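One way that denominator can blow up, as a toy illustration (an assumed failure mode, not a confirmed diagnosis of #11): a node with no edges gets a zero row-sum, and the 0/0 in the normalization becomes NaN.

```python
import torch

row_sum = torch.tensor([2.2, 0.0])         # softmax denominators; node 1 has no edges
h_prime = torch.tensor([[1.0, 2.0, 3.0],   # node 0: aggregated neighbour features
                        [0.0, 0.0, 0.0]])  # node 1: nothing to aggregate
print(h_prime / row_sum.unsqueeze(1))      # node 1's row is 0/0 -> NaN
# A common guard is a tiny epsilon (or adding self-loops) before dividing:
print(h_prime / (row_sum.unsqueeze(1) + 1e-16))
```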

@sh0416 (Contributor, Author) commented Oct 23, 2018

I found that backpropagation was not working well because of copying a dense vector into a sparse matrix. I have now fixed this.
Also, I visualized the model using graphviz. You can see the overall model in 'output/graph_visualize.pdf'.
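For reference, one way to produce such a PDF, assuming the third-party torchviz package builds the graph (the commit's actual visualization code may differ); model, features, and adj are the model and inputs from the training script:

```python
from torchviz import make_dot

output = model(features, adj)                                  # one forward pass
dot = make_dot(output, params=dict(model.named_parameters()))  # build the autograd graph
dot.render('output/graph_visualize')                           # writes output/graph_visualize.pdf
```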

@liuyijiang1994 commented Mar 24, 2019

Hi, I have the same problem as #11.
osX 12.12, python3.7, pytorch 0.4.1
and my code is:

nfeat, nhid, nclass, dropout, alpha, nheads = 10, 10, 3, 0.5, 0.2, 5
gat = SpGAT(nfeat, nhid, nclass, dropout, alpha, nheads)
x = torch.randn([5, 10])
adj = sp.coo_matrix(([1], ([0], [0])), shape=(2170, 2170), dtype=np.float32)
adj = torch.FloatTensor(np.array(adj.todense()))
print(adj.shape)
gat(x, adj)

and it happened:

Traceback (most recent call last):
  File "/home/liu/home/liu/pyGAT/t.py", line 13, in <module>
    gat(x, adj)
  File "/root/anaconda3/envs/liu-dl-3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/liu/home/liu/pyGAT/models.py", line 50, in forward
    x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
  File "/home/liu/home/liu/pyGAT/models.py", line 50, in <listcomp>
    x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
  File "/root/anaconda3/envs/liu-dl-3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/liu/home/liu/pyGAT/layers.py", line 135, in forward
    assert not torch.isnan(h_prime).any()
AssertionError

@CYBruce commented Dec 4, 2020

(Quoting @liuyijiang1994's comment above in full: the same reproduction code and AssertionError traceback.)

@liuyijiang1994 Have you solved this problem?
