diff --git a/README.md b/README.md
index 65eef95..b68b2e1 100644
--- a/README.md
+++ b/README.md
@@ -115,6 +115,7 @@ We support various spatiotemporal prediction methods and will provide benchmarks
 Currently supported methods
 
 - [x] [ConvLSTM](https://arxiv.org/abs/1506.04214) (NeurIPS'2015)
+- [x] [PredNet](https://openreview.net/forum?id=B1ewdt9xe) (ICLR'2017)
 - [x] [PredRNN](https://dl.acm.org/doi/abs/10.5555/3294771.3294855) (NeurIPS'2017)
 - [x] [PredRNN++](https://arxiv.org/abs/1804.06300) (ICML'2018)
 - [x] [E3D-LSTM](https://openreview.net/forum?id=B1lKS2AqtX) (ICLR'2018)
@@ -171,7 +172,7 @@ This project is released under the [Apache 2.0 license](LICENSE). See `LICENSE`
 
 ## Acknowledgement
 
-OpenSTL is an open-source project for STL algorithms created by researchers in **CAIRI AI Lab**. We encourage researchers interested in video and weather prediction to contribute to OpenSTL! We borrow the official implementations of [ConvLSTM](https://arxiv.org/abs/1506.04214), [PredRNN](https://dl.acm.org/doi/abs/10.5555/3294771.3294855) variants, [E3D-LSTM](https://openreview.net/forum?id=B1lKS2AqtX), [MAU](https://arxiv.org/abs/1811.07490), [CrevNet](https://openreview.net/forum?id=B1lKS2AqtX), [PhyDNet](https://arxiv.org/abs/2003.01460), [DMVFN](https://arxiv.org/abs/2303.09875).
+OpenSTL is an open-source project for STL algorithms created by researchers in **CAIRI AI Lab**. We encourage researchers interested in video and weather prediction to contribute to OpenSTL! We borrow the official implementations of [ConvLSTM](https://arxiv.org/abs/1506.04214), [PredNet](https://arxiv.org/abs/1605.08104), [PredRNN](https://dl.acm.org/doi/abs/10.5555/3294771.3294855) variants, [E3D-LSTM](https://openreview.net/forum?id=B1lKS2AqtX), [MAU](https://arxiv.org/abs/1811.07490), [CrevNet](https://openreview.net/forum?id=B1lKS2AqtX), [PhyDNet](https://arxiv.org/abs/2003.01460), and [DMVFN](https://arxiv.org/abs/2303.09875).
 
 ## Citation
diff --git a/configs/mmnist/PredNet.py b/configs/mmnist/PredNet.py
new file mode 100644
index 0000000..a4324bd
--- /dev/null
+++ b/configs/mmnist/PredNet.py
@@ -0,0 +1,12 @@
+method = 'PredNet'
+stack_sizes = (1, 32, 64, 128, 256)  # 1 is the number of input channels
+R_stack_sizes = stack_sizes
+A_filt_sizes = (3, 3, 3, 3)
+Ahat_filt_sizes = (3, 3, 3, 3, 3)
+R_filt_sizes = (3, 3, 3, 3, 3)
+pixel_max = 1.0
+weight_mode = 'L_0'
+error_activation = 'relu'
+A_activation = 'relu'
+LSTM_activation = 'tanh'
+LSTM_inner_activation = 'hard_sigmoid'
\ No newline at end of file
diff --git a/openstl/api/train.py b/openstl/api/train.py
index 0a4f70e..c258906 100644
--- a/openstl/api/train.py
+++ b/openstl/api/train.py
@@ -263,6 +263,8 @@ def display_method_info(self):
             input_dummy = (_tmp_input, _tmp_flag)
         elif self.args.method == 'dmvfn':
             input_dummy = torch.ones(1, 3, C, H, W, requires_grad=True).to(self.device)
+        elif self.args.method == 'prednet':
+            input_dummy = torch.ones(1, 10, C, H, W, requires_grad=True).to(self.device)
         else:
             raise ValueError(f'Invalid method name {self.args.method}')
 
@@ -296,7 +298,7 @@ def train(self):
                 cur_lr = self.method.current_lr()
                 cur_lr = sum(cur_lr) / len(cur_lr)
                 with torch.no_grad():
-                    vali_loss = self.vali(self.vali_loader)
+                    vali_loss = self.vali()
 
                 if self._rank == 0:
                     print_log('Epoch: {0}, Steps: {1} | Lr: {2:.7f} | Train Loss: {3:.7f} | Vali Loss: {4:.7f}\n'.format(
@@ -313,7 +315,7 @@ def train(self):
             time.sleep(1)  # wait for some hooks like loggers to finish
         self.call_hook('after_run')
 
-    def vali(self, vali_loader):
+    def vali(self):
         """A validation loop during training"""
         self.call_hook('before_val_epoch')
         results, eval_log = self.method.vali_one_epoch(self, self.vali_loader)
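The tuples in `configs/mmnist/PredNet.py` encode one entry per level of the PredNet hierarchy, which is why `A_filt_sizes` has one fewer entry than the others: the target `A` of layer `l+1` is computed from the error of layer `l`, so there is one `A` convolution per layer transition rather than per layer. As a quick sanity check, here is a minimal standalone sketch (plain Python, not part of the PR) of the shape invariants that `PredNet_Model.__init__` asserts:

```python
# Mirror of the values in configs/mmnist/PredNet.py.
stack_sizes = (1, 32, 64, 128, 256)   # channels of A/Ahat per layer; layer 0 matches the 1-channel input
R_stack_sizes = stack_sizes           # channels of the recurrent (R) units per layer
A_filt_sizes = (3, 3, 3, 3)           # one A conv per layer transition
Ahat_filt_sizes = (3, 3, 3, 3, 3)     # one prediction (Ahat) conv per layer
R_filt_sizes = (3, 3, 3, 3, 3)        # one ConvLSTM kernel size per layer

num_layers = len(stack_sizes)
assert len(R_stack_sizes) == num_layers
assert len(A_filt_sizes) == num_layers - 1  # transitions, not layers
assert len(Ahat_filt_sizes) == num_layers
assert len(R_filt_sizes) == num_layers
```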
diff --git a/openstl/methods/prednet.py b/openstl/methods/prednet.py
index e99d09e..e1698ac 100644
--- a/openstl/methods/prednet.py
+++ b/openstl/methods/prednet.py
@@ -1,13 +1,12 @@
 import time
 import torch
 import torch.nn as nn
+from tqdm import tqdm
 import numpy as np
 from timm.utils import AverageMeter
-from tqdm import tqdm
-
 from openstl.models import PredNet_Model
-from openstl.utils import reduce_tensor
-from .base_method import Base_method
+from openstl.utils import (reduce_tensor, get_initial_states)
+from openstl.methods.base_method import Base_method
 
 
 class PredNet(Base_method):
@@ -20,28 +19,123 @@ class PredNet(Base_method):
 
     def __init__(self, args, device, steps_per_epoch):
         Base_method.__init__(self, args, device, steps_per_epoch)
-        self.model = self._build_model(self.config)
+        self.model = self._build_model(self.args)
         self.model_optim, self.scheduler, self.by_epoch = self._init_optimizer(steps_per_epoch)
         self.criterion = nn.MSELoss()
-
-        self.constraints = self._get_constraints()
+        self.train_loss = TrainLossCalculator(num_layer=len(self.args.stack_sizes),
+                                              timestep=self.args.pre_seq_length + self.args.aft_seq_length,
+                                              weight_mode=self.args.weight_mode, device=self.device)
 
     def _build_model(self, args):
-        return PredNet_Model(args, output_mode='error').to(self.device)
+        return PredNet_Model(args.stack_sizes, args.R_stack_sizes,
+                             args.A_filt_sizes, args.Ahat_filt_sizes,
+                             args.R_filt_sizes, args.pixel_max, args)
 
-    def train_one_epoch(self, runner, train_loader, **kwargs):
-        """Train the model with train_loader.
+    def _predict(self, batch_x, batch_y, **kwargs):
+        input = torch.cat([batch_x, batch_y], dim=1)
+        states = get_initial_states(input.shape, -2, -1, len(self.args.stack_sizes),
+                                    self.args.R_stack_sizes, self.args.stack_sizes,
+                                    -3, self.args.device)
+        predict_list, _ = self.model(input, states, extrapolation=True)
+        pred_y = torch.stack(predict_list[batch_x.shape[1]:], dim=1)
+        return pred_y
 
-        Args:
-            runner: the trainer of methods.
-            train_loader: dataloader of train.
-        """
-        raise NotImplementedError
+    def train_one_epoch(self, runner, train_loader, epoch, num_updates, eta=None, **kwargs):
+        """Train the model with train_loader."""
+        data_time_m = AverageMeter()
+        losses_m = AverageMeter()
+        self.model.train()
+        if self.by_epoch:
+            self.scheduler.step(epoch)
+        train_pbar = tqdm(train_loader) if self.rank == 0 else train_loader
 
-    def _predict(self, batch_x, batch_y, **kwargs):
-        """Forward the model.
+        end = time.time()
+        for batch_x, batch_y in train_pbar:
+            data_time_m.update(time.time() - end)
+            self.model_optim.zero_grad()
+
+            batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
+            runner.call_hook('before_train_iter')
+
+            with self.amp_autocast():
+                input = torch.cat([batch_x, batch_y], dim=1)
+                states = get_initial_states(input.shape, -2, -1, len(self.args.stack_sizes),
+                                            self.args.R_stack_sizes, self.args.stack_sizes,
+                                            -3, self.args.device)
+
+                _, error_list = self.model(input, states, extrapolation=False)
+                loss = self.train_loss.calculate_loss(error_list)
+
+            if not self.dist:
+                losses_m.update(loss.item(), batch_x.size(0))
+
+            if self.loss_scaler is not None:
+                if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
+                    raise ValueError("Inf or nan loss value. Please use fp32 training!")
+                self.loss_scaler(
+                    loss, self.model_optim,
+                    clip_grad=self.args.clip_grad, clip_mode=self.args.clip_mode,
+                    parameters=self.model.parameters())
+            else:
+                loss.backward()
+                self.clip_grads(self.model.parameters())
+
+            self.model_optim.step()
+            torch.cuda.synchronize()
+            num_updates += 1
+
+            if self.dist:
+                losses_m.update(reduce_tensor(loss), batch_x.size(0))
+
+            if not self.by_epoch:
+                self.scheduler.step()
+            runner.call_hook('after_train_iter')
+            runner._iter += 1
+
+            if self.rank == 0:
+                log_buffer = 'train loss: {:.4f}'.format(loss.item())
+                log_buffer += ' | data time: {:.4f}'.format(data_time_m.avg)
+                train_pbar.set_description(log_buffer)
+
+            end = time.time()  # end of the batch loop
+
+        if hasattr(self.model_optim, 'sync_lookahead'):
+            self.model_optim.sync_lookahead()
+
+        return num_updates, losses_m, eta
+
+
+class TrainLossCalculator:
+    def __init__(self, num_layer, timestep, weight_mode, device):
+        self.num_layers = num_layer
+        self.timestep = timestep
+        self.weight_mode = weight_mode
+        self.device = device
+
+        if self.weight_mode == 'L_0':
+            layer_weights = np.array([0. for _ in range(num_layer)])
+            layer_weights[0] = 1.
+        elif self.weight_mode == 'L_all':
+            layer_weights = np.array([0.1 for _ in range(num_layer)])
+            layer_weights[0] = 1.
+        else:
+            raise (RuntimeError('Unknown loss weighting mode! '
+                                'Please use `L_0` or `L_all`.'))
+        self.layer_weights = torch.from_numpy(layer_weights).to(self.device)
+
+    def calculate_loss(self, input):
+        # Weight the per-layer errors (broadcast over the batch dimension)
+        error_list = [batch_numLayer_error * self.layer_weights for
+                      batch_numLayer_error in input]
+        error_list = [torch.sum(error_at_t) for error_at_t in error_list]
+
+        # Weight by timestep: the first frame gets weight 0, the rest are uniform
+        time_loss_weights = torch.cat([torch.tensor([0.], device=self.device),
+                                       torch.full((self.timestep - 1,),
+                                                  1. / (self.timestep - 1), device=self.device)])
 
-        Args:
-            batch_x, batch_y: testing samples and groung truth.
-        """
-        raise NotImplementedError
+        total_error = error_list[0] * time_loss_weights[0]
+        for err, time_weight in zip(error_list[1:], time_loss_weights[1:]):
+            total_error += err * time_weight
+        total_error /= input[0].shape[0]  # input[0].shape[0] = B
+        return total_error
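`TrainLossCalculator` implements the training objective from the PredNet paper: the per-timestep, per-layer mean errors are combined as `loss = (1/B) * sum_t w_t * sum_l lambda_l * E[t][l]`, where `lambda_l` is 1 for the bottom layer and either 0 (`weight_mode='L_0'`) or 0.1 (`weight_mode='L_all'`) for the higher layers, and the time weights `w_t` are 0 for the first frame and uniform afterwards. A condensed, self-contained sketch of the same computation (illustrative only; `prednet_loss` is a hypothetical helper, not part of the PR):

```python
import torch

def prednet_loss(error_list, layer_weights):
    # error_list: length-T list of (B, num_layers) mean-error tensors, as
    # returned by PredNet_Model; layer_weights: (num_layers,) lambda_l values.
    per_t = [torch.sum(e * layer_weights) for e in error_list]  # broadcast over the batch dim
    T = len(error_list)
    time_w = torch.cat([torch.zeros(1), torch.full((T - 1,), 1. / (T - 1))])  # w_0 = 0
    total = sum(err * w for err, w in zip(per_t, time_w))
    return total / error_list[0].shape[0]  # average over the batch

B, L, T = 4, 5, 20
errors = [torch.rand(B, L) for _ in range(T)]
print(prednet_loss(errors, torch.tensor([1., 0., 0., 0., 0.])))   # weight_mode='L_0'
print(prednet_loss(errors, torch.tensor([1., .1, .1, .1, .1])))   # weight_mode='L_all'
```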
diff --git a/openstl/models/prednet_model.py b/openstl/models/prednet_model.py
index 74fe876..bd0b664 100644
--- a/openstl/models/prednet_model.py
+++ b/openstl/models/prednet_model.py
@@ -1,203 +1,189 @@
 import torch
-from torch import nn
-from torch.nn import functional as F
-
-from openstl.modules import PredNetConvLSTMCell
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from openstl.utils import get_initial_states
 
 
 class PredNet_Model(nn.Module):
-    r"""PredNet Model
-
-    Implementation of `Deep Predictive Coding Networks for Video Prediction
-    and Unsupervised Learning <https://arxiv.org/abs/1605.08104>`_.
-
-    """
-
-    def __init__(self, configs, output_mode='error', **kwargs):
+    def __init__(self, stack_sizes, R_stack_sizes,
+                 A_filt_sizes, Ahat_filt_sizes, R_filt_sizes,
+                 pixel_max=1., args=None):
         super(PredNet_Model, self).__init__()
-        self.configs = configs
-        self.in_shape = configs.in_shape
-        _, _, H, W = configs.in_shape
-
-        self.a_channels = getattr(configs, "A_channels", (3, 48, 96, 192))
-        self.r_channels = getattr(configs, "R_channels", (3, 48, 96, 192))
-        self.n_layers = len(self.r_channels)
-        self.r_channels += (0, )  # for convenience
-        self.output_mode = output_mode
-        self.gating_mode = getattr(configs, "gating_mode", 'mul')
-        self.extrap_start_time = getattr(configs, "extrap_start_time", None)
-        self.peephole = getattr(configs, "peephole", False)
-        self.lstm_tied_bias = getattr(configs, "lstm_tied_bias", False)
-        self.p_max = getattr(configs, "p_max", 1.0)
-
-        # Input validity checks
-        default_output_modes = ['prediction', 'error', 'pred+err']
-        layer_output_modes = [unit + str(l) for l in range(self.n_layers) for unit in ['R', 'E', 'A', 'Ahat']]
-        default_gating_modes = ['mul', 'sub']
-        assert output_mode in default_output_modes + layer_output_modes, 'Invalid output_mode: ' + str(output_mode)
-        assert self.gating_mode in default_gating_modes, 'Invalid gating_mode: ' + str(self.gating_mode)
-
-        if self.output_mode in layer_output_modes:
-            self.output_layer_type = self.output_mode[:-1]
-            self.output_layer_num = int(self.output_mode[-1])
-        else:
-            self.output_layer_type = None
-            self.output_layer_num = None
-
-        # h, w = self.input_size
-
-        for i in range(self.n_layers):
-            # A_channels multiplied by 2 because E_l concactenates pred-target and target-pred
-            # Hidden states don't have same size due to upsampling
-            # How does this handle i = L-1 (final layer) | appends a zero
-
-            if self.gating_mode == 'mul':
-                cell = PredNetConvLSTMCell((H, W), 2 * self.a_channels[i] + self.r_channels[i+1], self.r_channels[i],
-                                           (3, 3), gating_mode='mul', peephole=self.peephole, tied_bias=self.lstm_tied_bias)
-            elif self.gating_mode == 'sub':
-                cell = PredNetConvLSTMCell((H, W), 2 * self.a_channels[i] + self.r_channels[i+1], self.r_channels[i],
-                                           (3, 3), gating_mode='sub', peephole=self.peephole, tied_bias=self.lstm_tied_bias)
-
-            setattr(self, 'cell{}'.format(i), cell)
-            H = H // 2
-            W = W // 2
-
-        for i in range(self.n_layers):
-            # Calculate predictions A_hat
-            conv = nn.Sequential(nn.Conv2d(self.r_channels[i], self.a_channels[i], 3, padding=1), nn.ReLU())
-            setattr(self, 'conv{}'.format(i), conv)
-
-        self.upsample = nn.Upsample(scale_factor=2)
-        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
-
-        for l in range(self.n_layers - 1):
-            # Propagate error as next layer's target (line 16 of Lotter algo)
-            # In channels = 2 * A_channels[l] because of pos/neg error concat
-            # NOTE: Operation belongs to curr layer l and produces next layer state l+1
-
-            update_A = nn.Sequential(nn.Conv2d(2* self.a_channels[l], self.a_channels[l+1], (3, 3), padding=1), self.maxpool)
-            setattr(self, 'update_A{}'.format(l), update_A)
-
-        self.criterion = nn.MSELoss()
-
-    def set_output_mode(self, output_mode):
-        self.output_mode = output_mode
-
-        # Input validity checks
-        default_output_modes = ['prediction', 'error', 'pred+err']
-        layer_output_modes = [unit + str(l) for l in range(self.n_layers) for unit in ['R', 'E', 'A', 'Ahat']]
-        assert output_mode in default_output_modes + layer_output_modes, 'Invalid output_mode: ' + str(output_mode)
-
-        if self.output_mode in layer_output_modes:
-            self.output_layer_type = self.output_mode[:-1]
-            self.output_layer_num = int(self.output_mode[-1])
-        else:
-            self.output_layer_type = None
-            self.output_layer_num = None
-
-    def step(self, a, states):
-        batch_size = a.size(0)
-        R_layers = states[:self.n_layers]
-        C_layers = states[self.n_layers:2*self.n_layers]
-        E_layers = states[2*self.n_layers:3*self.n_layers]
-
-        if self.extrap_start_time is not None:
-            t = states[-1]
-            if t >= self.extrap_start_time:  # if past self.extra_start_time use previous prediction as input
-                a = states[-2]
-
-        # Update representation units
-        for l in reversed(range(self.n_layers)):
-            cell = getattr(self, 'cell{}'.format(l))
-            r_tm1 = R_layers[l]
-            c_tm1 = C_layers[l]
-            e_tm1 = E_layers[l]
-            if l == self.n_layers - 1:
-                r, c = cell(e_tm1, (r_tm1, c_tm1))
-            else:
-                tmp = torch.cat((e_tm1, self.upsample(R_layers[l+1])), 1)
-                r, c = cell(tmp, (r_tm1, c_tm1))
-            R_layers[l] = r
-            C_layers[l] = c
-
-        # Perform error forward pass
-        for l in range(self.n_layers):
-            conv = getattr(self, 'conv{}'.format(l))
-            a_hat = conv(R_layers[l])
+        self.args = args
+        self.stack_sizes = stack_sizes
+        self.num_layers = len(stack_sizes)
+        assert len(R_stack_sizes) == self.num_layers
+        self.R_stack_sizes = R_stack_sizes
+        assert len(A_filt_sizes) == self.num_layers - 1
+        self.A_filt_sizes = A_filt_sizes
+        assert len(Ahat_filt_sizes) == self.num_layers
+        self.Ahat_filt_sizes = Ahat_filt_sizes
+        assert len(R_filt_sizes) == self.num_layers
+        self.R_filt_sizes = R_filt_sizes
+
+        self.pixel_max = pixel_max
+        self.error_activation = args.error_activation  # 'relu'
+        self.A_activation = args.A_activation  # 'relu'
+        self.LSTM_activation = args.LSTM_activation  # 'tanh'
+        self.LSTM_inner_activation = args.LSTM_inner_activation  # 'hard_sigmoid'
+        self.channel_axis = -3
+        self.row_axis = -2
+        self.col_axis = -1
+
+        self.get_activationFunc = {
+            'relu': nn.ReLU(),
+            'tanh': nn.Tanh(),
+        }
+
+        self.build_layers()
+        self.init_weights()
+
+    def init_weights(self):
+        def init_layer_weights(layer):
+            if isinstance(layer, nn.Conv2d):
+                layer.bias.data.zero_()
+        self.apply(init_layer_weights)
+
+    def hard_sigmoid(self, x, slope=0.2, shift=0.5):
+        x = (slope * x) + shift
+        x = torch.clamp(x, 0, 1)
+        return x
+
+    def batch_flatten(self, x):
+        shape = [*x.size()]
+        dim = np.prod(shape[1:])
+        return x.view(-1, int(dim))
+
+    def isNotTopestLayer(self, layerIndex):
+        '''Return True if layerIndex is below the topmost layer.'''
+        return True if layerIndex < self.num_layers - 1 else False
+
+    def build_layers(self):
+        # i: input gate, f: forget gate, c: cell update, o: output gate
+        self.conv_layers = {item: [] for item in ['i', 'f', 'c', 'o', 'A', 'Ahat']}
+        lstm_list = ['i', 'f', 'c', 'o']
+
+        for item in sorted(self.conv_layers.keys()):
+            for l in range(self.num_layers):
+                if item == 'Ahat':
+                    in_channels = self.R_stack_sizes[l]
+                    self.conv_layers['Ahat'].append(nn.Conv2d(
+                        in_channels=in_channels, out_channels=self.stack_sizes[l],
+                        kernel_size=self.Ahat_filt_sizes[l], stride=(1, 1),
+                        padding=int((self.Ahat_filt_sizes[l] - 1) / 2)))
+                    act = 'relu' if l == 0 else self.A_activation
+                    self.conv_layers['Ahat'].append(self.get_activationFunc[act])
+                elif item == 'A':
+                    if self.isNotTopestLayer(l):
+                        in_channels = self.R_stack_sizes[l] * 2
+                        self.conv_layers['A'].append(nn.Conv2d(
+                            in_channels=in_channels, out_channels=self.stack_sizes[l + 1],
+                            kernel_size=self.A_filt_sizes[l], stride=(1, 1),
+                            padding=int((self.A_filt_sizes[l] - 1) / 2)))
+                        self.conv_layers['A'].append(self.get_activationFunc[self.A_activation])
+                elif item in lstm_list:  # build the R (ConvLSTM) modules
+                    in_channels = self.stack_sizes[l] * 2 + self.R_stack_sizes[l]
+                    if self.isNotTopestLayer(l):
+                        in_channels += self.R_stack_sizes[l + 1]
+                    self.conv_layers[item].append(nn.Conv2d(
+                        in_channels=in_channels, out_channels=self.R_stack_sizes[l],
+                        kernel_size=self.R_filt_sizes[l], stride=(1, 1),
+                        padding=int((self.R_filt_sizes[l] - 1) / 2)))
+
+        for name, layerList in self.conv_layers.items():
+            self.conv_layers[name] = nn.ModuleList(layerList).to(self.args.device)
+            setattr(self, name, self.conv_layers[name])
+
+        self.upSample = nn.Upsample(scale_factor=2, mode='nearest')
+        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
+
+    def step(self, A, states, extrapolation=False):
+        n = self.num_layers
+        R_current = states[:n]
+        C_current = states[n:2 * n]
+        E_current = states[2 * n:3 * n]
+
+        timestep = states[-1]
+        if extrapolation and timestep >= self.args.in_shape[0]:
+            A = states[-2]  # past the observed frames: reuse the previous prediction as input
+
+        R_list, C_list, E_list = [], [], []
+
+        # Top-down pass: update the representation (ConvLSTM) units
+        for l in reversed(range(self.num_layers)):
+            inputs = [R_current[l], E_current[l]]
+            if self.isNotTopestLayer(l):
+                inputs.append(R_up)
+            inputs = torch.cat(inputs, dim=self.channel_axis)
+
+            in_gate = self.hard_sigmoid(self.conv_layers['i'][l](inputs))
+            forget_gate = self.hard_sigmoid(self.conv_layers['f'][l](inputs))
+            cell_gate = F.tanh(self.conv_layers['c'][l](inputs))
+            out_gate = self.hard_sigmoid(self.conv_layers['o'][l](inputs))
+            C_next = (forget_gate * C_current[l]) + (in_gate * cell_gate)
+            R_next = out_gate * F.tanh(C_next)
+
+            C_list.insert(0, C_next)
+            R_list.insert(0, R_next)
+
+            if l > 0:
+                R_up = self.upSample(R_next)
+
+        # Bottom-up pass: make predictions and propagate errors
+        for l in range(self.num_layers):
+            Ahat = self.conv_layers['Ahat'][2 * l](R_list[l])  # conv layer
+            Ahat = self.conv_layers['Ahat'][2 * l + 1](Ahat)  # activation function
+            if l == 0:
-                a_hat= torch.min(a_hat, torch.tensor(self.p_max).to(self.configs.device))  # alternative SatLU (Lotter)
-                frame_prediction = a_hat
-            pos = F.relu(a_hat - a)
-            neg = F.relu(a - a_hat)
-            e = torch.cat([pos, neg],1)
-            E_layers[l] = e
-
-            # Handling layer-specific outputs
-            if self.output_layer_num == l:
-                if self.output_layer_type == 'A':
-                    output = a
-                elif self.output_layer_type == 'Ahat':
-                    output = a_hat
-                elif self.output_layer_type == 'R':
-                    output = R_layers[l]
-                elif self.output_layer_type == 'E':
-                    output = E_layers[l]
-
-            if l < self.n_layers - 1:  # updating A for next layer
-                update_A = getattr(self, 'update_A{}'.format(l))
-                a = update_A(e)
-
-        if self.output_layer_type is None:
-            if self.output_mode == 'prediction':
-                output = frame_prediction
+                Ahat = torch.clamp(Ahat, max=self.pixel_max)
+                frame_prediction = Ahat
+
+            if self.error_activation.lower() == 'relu':
+                E_up = F.relu(Ahat - A)
+                E_down = F.relu(A - Ahat)
+            elif self.error_activation.lower() == 'tanh':
+                E_up = F.tanh(Ahat - A)
+                E_down = F.tanh(A - Ahat)
+            else:
-            else:
-                # Batch flatten (return 2D matrix) then mean over units
-                # Finally, concatenate layers (batch, n_layers)
-                mean_E_layers = torch.cat([torch.mean(e.view(batch_size, -1), axis=1, keepdim=True) for e in E_layers], axis=1)
-                if self.output_mode == 'error':
-                    output = mean_E_layers
-                else:
-                    output = torch.cat([frame_prediction.view(batch_size, -1), mean_E_layers], axis=1)
-
-        states = R_layers + C_layers + E_layers
-        if self.extrap_start_time is not None:
-            states += [frame_prediction, t+1]
-        return output, states
-
-    def forward(self, input_tensor, **kwargs):
-
-        R_layers = [None] * self.n_layers
-        C_layers = [None] * self.n_layers
-        E_layers = [None] * self.n_layers
-
-        _, _, h, w = self.in_shape
-        batch_size = input_tensor.size(0)
-
-        # Initialize states
-        for l in range(self.n_layers):
-            R_layers[l] = torch.zeros(batch_size, self.r_channels[l], h, w, requires_grad=True).to(self.configs.device)
-            C_layers[l] = torch.zeros(batch_size, self.r_channels[l], h, w, requires_grad=True).to(self.configs.device)
-            E_layers[l] = torch.zeros(batch_size, 2*self.a_channels[l], h, w, requires_grad=True).to(self.configs.device)
-            # Size of hidden state halves from each layer to the next
-            h = h//2
-            w = w//2
-
-        states = R_layers + C_layers + E_layers
-        # Initialize previous_prediction
-        if self.extrap_start_time is not None:
-            frame_prediction = torch.zeros_like(input_tensor[:,0], dtype=torch.float32).to(self.configs.device)
-            states += [frame_prediction, -1]  # [a, t]
-
-        num_time_steps = input_tensor.size(1)
-        total_output = []  # contains output sequence
-        for t in range(num_time_steps):
-            a = input_tensor[:,t].type(torch.FloatTensor).to(self.configs.device)
-            output, states = self.step(a, states)
-            total_output.append(output)
-
-        ax = len(output.shape)
-        # print(output.shape)
-        total_output = [out.view(out.shape + (1,)) for out in total_output]
-        total_output = torch.cat(total_output, axis=ax)  # (batch, ..., nt)
-
-        return total_output
+                raise (RuntimeError(
+                    'cannot obtain the activation function named %s' % self.error_activation))
+
+            E_list.append(torch.cat((E_up, E_down), dim=self.channel_axis))
+
+            if self.isNotTopestLayer(l):
+                A = self.conv_layers['A'][2 * l](E_list[l])
+                A = self.conv_layers['A'][2 * l + 1](A)
+                A = self.pool(A)  # target for the next layer
+
+        for l in range(self.num_layers):
+            layer_error = torch.mean(self.batch_flatten(E_list[l]), dim=-1, keepdim=True)
+            all_error = layer_error if l == 0 else torch.cat((all_error, layer_error), dim=-1)
+
+        states = R_list + C_list + E_list
+        predict = frame_prediction
+        error = all_error
+        states += [frame_prediction, timestep + 1]
+        return predict, error, states
+
+    def forward(self, A0_withTimeStep, initial_states=None, extrapolation=False):
+        '''
+        A0_withTimeStep is the input from the dataloader.
+        Its shape is (batch_size, timesteps, channels, height, width).
+        '''
+        if initial_states is None:
+            # default shape assumes a Moving MNIST batch of one 10-frame, 1x64x64 sequence
+            initial_states = get_initial_states((1, 10, 1, 64, 64),
+                                                self.row_axis, self.col_axis, self.num_layers,
+                                                self.R_stack_sizes, self.stack_sizes,
+                                                self.channel_axis, self.args.device)
+        A0_withTimeStep = A0_withTimeStep.transpose(0, 1)
+        num_timesteps = A0_withTimeStep.shape[0]
+
+        hidden_states = initial_states
+        predict_list, error_list = [], []
+        for t in range(num_timesteps):
+            A0 = A0_withTimeStep[t, ...]
+            predict, error, hidden_states = self.step(A0, hidden_states, extrapolation)
+            predict_list.append(predict)
+            error_list.append(error)
+        return predict_list, error_list
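For orientation, `step` carries three stacks of per-layer states: `R_l` and `C_l` (the ConvLSTM hidden and cell states, with `R_stack_sizes[l]` channels) and `E_l` (the concatenated positive and negative error maps, hence `2 * stack_sizes[l]` channels), with the spatial resolution halved at each level by the max-pooling in the `A` path. A small sketch of the resulting shapes under the Moving-MNIST config above (batch size 16 is an assumption for illustration):

```python
# Per-layer state shapes for a (B, 1, 64, 64) input frame.
B, H, W = 16, 64, 64
stack_sizes = (1, 32, 64, 128, 256)
R_stack_sizes = stack_sizes
for l, (a, r) in enumerate(zip(stack_sizes, R_stack_sizes)):
    h, w = H >> l, W >> l  # MaxPool2d halves the resolution at each level
    print(f"layer {l}: R/C = ({B}, {r}, {h}, {w}), E = ({B}, {2 * a}, {h}, {w})")
```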
diff --git a/openstl/utils/__init__.py b/openstl/utils/__init__.py
index b6e8da4..46415c7 100644
--- a/openstl/utils/__init__.py
+++ b/openstl/utils/__init__.py
@@ -10,8 +10,9 @@
 from .parser import create_parser
 from .predrnn_utils import (reserve_schedule_sampling_exp, schedule_sampling,
                             reshape_patch, reshape_patch_back)
-from .progressbar import ProgressBar, Timer
 from .dmvfn_utils import LapLoss, VGGPerceptualLoss
+from .prednet_utils import get_initial_states
+from .progressbar import ProgressBar, Timer
 
@@ -23,6 +24,7 @@
     'get_dataset', 'count_parameters', 'measure_throughput', 'load_config', 'update_config',
     'weights_to_cpu', 'init_dist', 'init_random_seed', 'get_dist_info', 'reduce_tensor',
     'reserve_schedule_sampling_exp', 'schedule_sampling', 'reshape_patch', 'reshape_patch_back',
+    'LapLoss', 'VGGPerceptualLoss',
+    'get_initial_states',
     'ProgressBar', 'Timer',
-    'LapLoss', 'VGGPerceptualLoss'
 ]
\ No newline at end of file
diff --git a/openstl/utils/prednet_utils.py b/openstl/utils/prednet_utils.py
new file mode 100644
index 0000000..bff7f95
--- /dev/null
+++ b/openstl/utils/prednet_utils.py
@@ -0,0 +1,47 @@
+import numpy as np
+import torch
+
+
+def get_initial_states(input_shape, row_axis, col_axis, num_layers,
+                       R_stack_sizes, stack_sizes, channel_axis,
+                       device):
+    # input_shape: (batch_size, timesteps, channels, height, width)
+    init_height = input_shape[row_axis]
+    init_width = input_shape[col_axis]
+
+    base_initial_state = np.zeros(input_shape)
+    non_channel_axis = -1
+    for _ in range(2):
+        base_initial_state = np.sum(base_initial_state, axis=non_channel_axis)
+    base_initial_state = np.sum(base_initial_state, axis=1)  # (batch_size, channels)
+
+    initial_states = []
+    states_to_pass = ['R', 'C', 'E']  # R: representation, C: ConvLSTM cell state, E: error
+    num_layer_to_pass = {stp: num_layers for stp in states_to_pass}
+    states_to_pass.append('Ahat')  # pass the prediction in the states so it can serve as the input at t+1 when extrapolating
+    num_layer_to_pass['Ahat'] = 1
+
+    for stp in states_to_pass:
+        for l in range(num_layer_to_pass[stp]):
+            downsample_factor = 2 ** l
+            row = init_height // downsample_factor
+            col = init_width // downsample_factor
+            if stp in ['R', 'C']:
+                stack_size = R_stack_sizes[l]
+            elif stp == 'E':
+                stack_size = stack_sizes[l] * 2
+            elif stp == 'Ahat':
+                stack_size = stack_sizes[l]
+
+            output_size = stack_size * row * col  # flattened size
+            reducer = np.zeros((input_shape[channel_axis], output_size))  # (channels, output_size)
+            initial_state = np.dot(base_initial_state, reducer)  # (batch_size, output_size)
+
+            output_shape = (-1, stack_size, row, col)
+            initial_state = torch.from_numpy(np.reshape(initial_state, output_shape)).float().to(
+                device).requires_grad_()  # requires_grad=True
+            initial_states += [initial_state]
+
+    initial_states += [torch.zeros(1, dtype=torch.int).to(device)]  # the last state tracks the current timestep
+    return initial_states
\ No newline at end of file
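Note that `base_initial_state` is all zeros and so is `reducer`, so the `np.sum`/`np.dot` round-trip above always yields zero tensors; the pattern appears to be carried over from the original Keras implementation. An equivalent, more direct construction (a sketch only; `get_initial_states_direct` is a hypothetical replacement, not part of the PR):

```python
import torch

def get_initial_states_direct(batch_size, height, width, num_layers,
                              R_stack_sizes, stack_sizes, device):
    states = []
    for kind in ('R', 'C', 'E'):  # same ordering as get_initial_states
        for l in range(num_layers):
            h, w = height // (2 ** l), width // (2 ** l)
            c = R_stack_sizes[l] if kind in ('R', 'C') else 2 * stack_sizes[l]
            states.append(torch.zeros(batch_size, c, h, w, device=device, requires_grad=True))
    # 'Ahat' slot: the previous frame prediction, reused as input when extrapolating
    states.append(torch.zeros(batch_size, stack_sizes[0], height, width,
                              device=device, requires_grad=True))
    states.append(torch.zeros(1, dtype=torch.int, device=device))  # timestep counter
    return states
```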