Skip to content

Commit

Permalink
Fixed bug having dense connection in message passing block disabled b…
Browse files Browse the repository at this point in the history
…y default
  • Loading branch information
elliottower committed Apr 21, 2021
1 parent af0d40c commit f019a70
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions ocpmodels/models/graphtransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(
bias=False, # Whether to add bias to linear layers
depth=3, # Number of message passing steps
dropout=0.0, # Dropout probability
activation='ReLU', # Choices: ReLU', 'LeakyReLU', 'PReLU', 'tanh', 'SELU', 'ELU' #TODO: import other activation functions/enforce these?
activation='ReLU', # Choices: ReLU', 'LeakyReLU', 'PReLU', 'tanh', 'SELU', 'ELU'
weight_decay=0.0, # Weight decay
num_attn_head=4, # Number of attention heads per MTBlock
num_mt_block=1, # Number of MTBlocks
Expand Down Expand Up @@ -581,7 +581,7 @@ def convert_input(args, data):
a2a = a2a2
b2a = b2a2

# Set reverse bond mapping to all zeros, as our graph is mostly undirected TODO: check if this limits training capcity
# Set reverse bond mapping to all zeros, as our graph is mostly undirected
b2revb = torch.zeros((data.edge_index.shape[1],)).long()

batch = (f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a)
Expand Down Expand Up @@ -1360,7 +1360,7 @@ def forward(self,
for _ in range(ndepth - 1):
if self.undirected:
# two directions should be the same
message = (message + message[b2revb]) / 2 # TODO: figure out error with message being out of bounds (using undirected)
message = (message + message[b2revb]) / 2
nei_message = select_neighbor_and_aggregate(message, a2nei)
a_message = nei_message
if self.attached_fea:
Expand All @@ -1376,9 +1376,7 @@ def forward(self,
else:
message = a_message
message = self.W_h(message)
# BUG here, by default MPNEncoder use the dense connection in the message passing step.
# The correct form should if not self.dense
if self.dense:
if not self.dense:
message = self.act_func(message) # num_bonds x hidden_size
else:
message = self.act_func(input + message)
Expand Down

0 comments on commit f019a70

Please sign in to comment.