Commit

Added support for dynamic message passing through training and dynamic_depth arguments
elliottower committed Apr 16, 2021
1 parent 81263ac commit 3513c6c
Showing 1 changed file with 7 additions and 1 deletion.
ocpmodels/models/graphtransformer.py (7 additions, 1 deletion)
@@ -117,7 +117,9 @@ def __init__(
         features_only=False, # Use only the additional features in an FFN, no graph network
         cuda=True, # Use CUDA acceleration
         debug=False, # Print debugging info
-        debug_name='Null'
+        debug_name='Null', # Name for saving debugging data object to disk #TODO: remove in the future if not needed
+        training=False, # Training flag true enables dynamic depth in message passing
+        dynamic_depth="none" # Method of calculating dynamic depth value. Possible choices: "none", "uniform" and "truncnorm"
     ):
         # OCP parameters
         self.num_targets = num_targets
@@ -152,6 +154,8 @@ def __init__(
         args.features_size = 0 # Temporary for troubleshooting
         args.debug = debug
         args.debug_name = debug_name
+        args.training = training
+        args.dynamic_depth = dynamic_depth
         self.args = args # Hack to call in forward pass

         super(GraphTransformer, self).__init__()
@@ -1283,6 +1287,8 @@ def __init__(self, args: Namespace,
        w_h_input_size = self.hidden_size
        # Shared weight matrix across depths (default)
        self.W_h = nn.Linear(w_h_input_size, self.hidden_size, bias=self.bias)
+       self.training = args.training
+       self.dynamic_depth = args.dynamic_depth

    def forward(self,
                init_messages,

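Note: the diff above only wires the new training and dynamic_depth arguments through to the message-passing encoder; the depth-sampling logic itself is not shown in this hunk. The following is a minimal sketch of how such a dynamic depth could be drawn from the documented options ("none", "uniform", "truncnorm"). The helper name sample_depth, the parameter base_depth, and the sampling bounds are assumptions for illustration, not the repository's actual implementation.

import numpy as np
from scipy.stats import truncnorm

def sample_depth(base_depth: int, training: bool, dynamic_depth: str = "none") -> int:
    """Return the number of message-passing iterations for one forward pass.

    Hypothetical sketch: the exact ranges and distribution parameters are
    assumptions, not taken from graphtransformer.py.
    """
    if not training or dynamic_depth == "none":
        # Fixed depth at evaluation time or when dynamic depth is disabled.
        return base_depth
    if dynamic_depth == "uniform":
        # Draw uniformly around the configured depth, here +/- 1 step (assumed range).
        return int(np.random.randint(base_depth - 1, base_depth + 2))
    if dynamic_depth == "truncnorm":
        # Truncated normal centered on base_depth, clipped to an assumed +/- 2 range.
        low, high, scale = base_depth - 2, base_depth + 2, 1.0
        a, b = (low - base_depth) / scale, (high - base_depth) / scale
        return int(round(truncnorm.rvs(a, b, loc=base_depth, scale=scale)))
    raise ValueError(f"Unknown dynamic_depth option: {dynamic_depth}")

In this reading, the training flag gates the randomness so that inference always runs at the configured depth, while training sees a small spread of depths per batch.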