diff --git a/openstl/models/prednet_model.py b/openstl/models/prednet_model.py index bd0b664..34f7e2d 100644 --- a/openstl/models/prednet_model.py +++ b/openstl/models/prednet_model.py @@ -173,7 +173,8 @@ def forward(self, A0_withTimeStep, initial_states=None, extrapolation=False): ''' if initial_states is None: - initial_states = get_initial_states((1, 10, 1, 64, 64), + T, C, H, W = self.args.in_shape + initial_states = get_initial_states((1, T, C, H, W)), self.row_axis, self.col_axis, self.num_layers, self.R_stack_sizes, self.stack_sizes, self.channel_axis, self.args.device) A0_withTimeStep = A0_withTimeStep.transpose(0, 1) num_timesteps = A0_withTimeStep.shape[0]