Skip to content

Commit

Permalink
fix a minor bug in prednet
Browse files Browse the repository at this point in the history
  • Loading branch information
chengtan9907 committed Jun 3, 2023
1 parent 1c48ff0 commit 389a7d3
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion openstl/models/prednet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ def forward(self, A0_withTimeStep, initial_states=None, extrapolation=False):
'''
if initial_states is None:
initial_states = get_initial_states((1, 10, 1, 64, 64),
T, C, H, W = self.args.in_shape
initial_states = get_initial_states((1, T, C, H, W)),
self.row_axis, self.col_axis, self.num_layers, self.R_stack_sizes, self.stack_sizes, self.channel_axis, self.args.device)
A0_withTimeStep = A0_withTimeStep.transpose(0, 1)
num_timesteps = A0_withTimeStep.shape[0]
Expand Down

0 comments on commit 389a7d3

Please sign in to comment.