import torch
from torch.nn import Parameter
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax

class CGAT(MessagePassing):
    r"""A graph attentional operator based on the `"Graph Attention
    Networks" <https://arxiv.org/abs/1710.10903>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},

    where, unlike in the original formulation, every node is projected to
    :math:`2 \cdot` :obj:`out_channels` features per head: the first half
    :math:`\mathbf{\Theta}\mathbf{x}` carries the message, while the second
    half :math:`\mathbf{\Theta}'\mathbf{x}` enters the attention score as a
    pairwise difference. The attention coefficients :math:`\alpha_{i,j}`
    are computed as

    .. math::
        \alpha_{i,j} =
        \frac{
        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j
        \, \Vert \, \mathbf{\Theta}'\mathbf{x}_i -
        \mathbf{\Theta}'\mathbf{x}_j]\right)\right)}
        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k
        \, \Vert \, \mathbf{\Theta}'\mathbf{x}_i -
        \mathbf{\Theta}'\mathbf{x}_k]\right)\right)}.

    A minimal usage sketch appears at the bottom of this file.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients, which exposes each node to a
            stochastically sampled neighborhood during training.
            (default: :obj:`0`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    def __init__(self, in_channels, out_channels, heads=1, concat=True,
                 negative_slope=0.2, dropout=0, bias=True, **kwargs):
        super(CGAT, self).__init__(aggr='add', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout

        # Each head produces 2 * out_channels features: the first half is
        # the message, the second half feeds the pairwise difference term
        # of the attention score.
        self.weight = Parameter(
            torch.Tensor(in_channels, heads * 2 * out_channels))
        self.weight_ij = Parameter(
            torch.Tensor(in_channels, heads * out_channels))
        self.att = Parameter(torch.Tensor(1, heads, 3 * out_channels))

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight)
        # weight_ij is currently unused by forward(), but initialize it so
        # the parameter never holds uninitialized memory.
        torch.nn.init.xavier_uniform_(self.weight_ij)
        torch.nn.init.xavier_uniform_(self.att)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, x, edge_index, edge_attr=None, size=None):
        """
        Args:
            x (Tensor): Node features of shape ``[num_nodes, in_channels]``,
                or a pair of such tensors for bipartite message passing.
            edge_index (LongTensor): Edge indices of shape ``[2, num_edges]``.
            edge_attr: Accepted for interface compatibility; currently unused.
            size: The output dimensions for bipartite message passing.
        """
        if size is None and torch.is_tensor(x):
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        if torch.is_tensor(x):
            x = torch.matmul(x, self.weight)
        else:
            x = (None if x[0] is None else torch.matmul(x[0], self.weight),
                 None if x[1] is None else torch.matmul(x[1], self.weight))

        return self.propagate(edge_index, x=x, size=size)

    def message(self, edge_index_i, x_i, x_j, size_i):
        # Compute attention coefficients.
        x_j = x_j.view(-1, self.heads, 2 * self.out_channels)
        if x_i is None:
            alpha = (x_j * self.att[:, :, self.out_channels:]).sum(dim=-1)
        else:
            x_i = x_i.view(-1, self.heads, 2 * self.out_channels)
            # Pairwise difference of the second halves of the projections.
            x_ij = x_i[:, :, self.out_channels:] - x_j[:, :, self.out_channels:]
            alpha = (torch.cat([x_i[:, :, :self.out_channels],
                                x_j[:, :, :self.out_channels],
                                x_ij], dim=-1) * self.att).sum(dim=-1)

        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, edge_index_i, size_i)

        # Sample attention coefficients stochastically.
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # Only the first half of the projection is propagated as the message.
        return x_j[:, :, :self.out_channels] * alpha.view(-1, self.heads, 1)

    def update(self, aggr_out):
        if self.concat:
            aggr_out = aggr_out.view(-1, self.heads * self.out_channels)
        else:
            aggr_out = aggr_out.mean(dim=1)

        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)
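

# A minimal smoke test, added as an illustrative sketch of how the layer
# might be used. The node count, feature sizes, and edge list below are
# assumptions for demonstration only, and the snippet presumes a
# torch_geometric version compatible with the PyG 1.x-style MessagePassing
# API this file is written against.
if __name__ == '__main__':
    num_nodes, in_channels, out_channels, heads = 4, 16, 8, 2
    x = torch.randn(num_nodes, in_channels)
    # A small directed cycle; self-loops are added inside forward().
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]], dtype=torch.long)

    conv = CGAT(in_channels, out_channels, heads=heads)
    out = conv(x, edge_index)

    # With concat=True (the default), head outputs are concatenated along
    # the feature dimension.
    assert out.size() == (num_nodes, heads * out_channels)
    print(conv, tuple(out.size()))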