diff --git a/README.md b/README.md
index 65eef95..b68b2e1 100644
--- a/README.md
+++ b/README.md
@@ -115,6 +115,7 @@ We support various spatiotemporal prediction methods and will provide benchmarks
Currently supported methods
- [x] [ConvLSTM](https://arxiv.org/abs/1506.04214) (NeurIPS'2015)
+ - [x] [PredNet](https://openreview.net/forum?id=B1ewdt9xe) (ICLR'2017)
- [x] [PredRNN](https://dl.acm.org/doi/abs/10.5555/3294771.3294855) (NeurIPS'2017)
- [x] [PredRNN++](https://arxiv.org/abs/1804.06300) (ICML'2018)
- [x] [E3D-LSTM](https://openreview.net/forum?id=B1lKS2AqtX) (ICLR'2018)
@@ -171,7 +172,7 @@ This project is released under the [Apache 2.0 license](LICENSE). See `LICENSE`
## Acknowledgement
-OpenSTL is an open-source project for STL algorithms created by researchers in **CAIRI AI Lab**. We encourage researchers interested in video and weather prediction to contribute to OpenSTL! We borrow the official implementations of [ConvLSTM](https://arxiv.org/abs/1506.04214), [PredRNN](https://dl.acm.org/doi/abs/10.5555/3294771.3294855) variants, [E3D-LSTM](https://openreview.net/forum?id=B1lKS2AqtX), [MAU](https://arxiv.org/abs/1811.07490), [CrevNet](https://openreview.net/forum?id=B1lKS2AqtX), [PhyDNet](https://arxiv.org/abs/2003.01460), [DMVFN](https://arxiv.org/abs/2303.09875).
+OpenSTL is an open-source project for STL algorithms created by researchers in **CAIRI AI Lab**. We encourage researchers interested in video and weather prediction to contribute to OpenSTL! We borrow the official implementations of [ConvLSTM](https://arxiv.org/abs/1506.04214), [PredNet](https://arxiv.org/abs/1605.08104), [PredRNN](https://dl.acm.org/doi/abs/10.5555/3294771.3294855) variants, [E3D-LSTM](https://openreview.net/forum?id=B1lKS2AqtX), [MAU](https://arxiv.org/abs/1811.07490), [CrevNet](https://openreview.net/forum?id=B1lKS2AqtX), [PhyDNet](https://arxiv.org/abs/2003.01460), and [DMVFN](https://arxiv.org/abs/2303.09875).
## Citation
diff --git a/configs/mmnist/PredNet.py b/configs/mmnist/PredNet.py
new file mode 100644
index 0000000..a4324bd
--- /dev/null
+++ b/configs/mmnist/PredNet.py
@@ -0,0 +1,12 @@
+method = 'PredNet'
+stack_sizes = (1, 32, 64, 128, 256)  # output channels of the A/Ahat convs per layer; the leading 1 is the number of input channels
+R_stack_sizes = stack_sizes  # output channels of the ConvLSTM gate convs (representation R) per layer; tied to stack_sizes here
+A_filt_sizes = (3, 3, 3, 3)  # kernel sizes of the A convs; the model asserts one entry fewer than the number of layers
+Ahat_filt_sizes = (3, 3, 3, 3, 3)  # kernel sizes of the Ahat (prediction) convs; one entry per layer
+R_filt_sizes = (3, 3, 3, 3, 3)  # kernel sizes of the ConvLSTM gate convs; one entry per layer
+pixel_max = 1.0  # upper clamp applied to the layer-0 prediction (the predicted frame)
+weight_mode = 'L_0'  # loss layer weighting: 'L_0' (layer 0 only) or 'L_all' (0.1 on the layers above layer 0)
+error_activation = 'relu'  # activation applied to the error units; the model accepts 'relu' or 'tanh'
+A_activation = 'relu'  # activation after the A convs and the Ahat convs above layer 0; the model's table has 'relu' and 'tanh'
+LSTM_activation = 'tanh'  # NOTE(review): stored on the model, but its step() calls tanh directly
+LSTM_inner_activation = 'hard_sigmoid'  # NOTE(review): stored on the model, but its step() calls hard_sigmoid directly
\ No newline at end of file
diff --git a/openstl/api/train.py b/openstl/api/train.py
index 0a4f70e..c258906 100644
--- a/openstl/api/train.py
+++ b/openstl/api/train.py
@@ -263,6 +263,8 @@ def display_method_info(self):
input_dummy = (_tmp_input, _tmp_flag)
elif self.args.method == 'dmvfn':
input_dummy = torch.ones(1, 3, C, H, W, requires_grad=True).to(self.device)
+ elif self.args.method == 'prednet':
+ input_dummy = torch.ones(1, 10, C, H, W, requires_grad=True).to(self.device)
else:
raise ValueError(f'Invalid method name {self.args.method}')
@@ -296,7 +298,7 @@ def train(self):
cur_lr = self.method.current_lr()
cur_lr = sum(cur_lr) / len(cur_lr)
with torch.no_grad():
- vali_loss = self.vali(self.vali_loader)
+ vali_loss = self.vali()
if self._rank == 0:
print_log('Epoch: {0}, Steps: {1} | Lr: {2:.7f} | Train Loss: {3:.7f} | Vali Loss: {4:.7f}\n'.format(
@@ -313,7 +315,7 @@ def train(self):
time.sleep(1) # wait for some hooks like loggers to finish
self.call_hook('after_run')
- def vali(self, vali_loader):
+ def vali(self):
"""A validation loop during training"""
self.call_hook('before_val_epoch')
results, eval_log = self.method.vali_one_epoch(self, self.vali_loader)
diff --git a/openstl/methods/prednet.py b/openstl/methods/prednet.py
index e99d09e..e1698ac 100644
--- a/openstl/methods/prednet.py
+++ b/openstl/methods/prednet.py
@@ -1,13 +1,12 @@
import time
import torch
import torch.nn as nn
+from tqdm import tqdm
import numpy as np
from timm.utils import AverageMeter
-from tqdm import tqdm
-
from openstl.models import PredNet_Model
-from openstl.utils import reduce_tensor
-from .base_method import Base_method
+from openstl.utils import (reduce_tensor, get_initial_states)
+from openstl.methods.base_method import Base_method
class PredNet(Base_method):
@@ -20,28 +19,123 @@ class PredNet(Base_method):
def __init__(self, args, device, steps_per_epoch):
Base_method.__init__(self, args, device, steps_per_epoch)
- self.model = self._build_model(self.config)
+ self.model = self._build_model(self.args)
self.model_optim, self.scheduler, self.by_epoch = self._init_optimizer(steps_per_epoch)
self.criterion = nn.MSELoss()
-
- self.constraints = self._get_constraints()
+ self.train_loss = TrainLossCalculator(num_layer=len(self.args.stack_sizes), timestep=self.args.pre_seq_length +
+ self.args.aft_seq_length, weight_mode=self.args.weight_mode, device=self.device)
def _build_model(self, args):
- return PredNet_Model(args, output_mode='error').to(self.device)
+ return PredNet_Model(args.stack_sizes, args.R_stack_sizes,
+ args.A_filt_sizes, args.Ahat_filt_sizes,
+ args.R_filt_sizes, args.pixel_max, args)
- def train_one_epoch(self, runner, train_loader, **kwargs):
- """Train the model with train_loader.
+    def _predict(self, batch_x, batch_y, **kwargs):
+        """Run the model over the observed and target frames and return its future predictions.
+
+        ``batch_x`` and ``batch_y`` are concatenated along dim 1 and fed to the
+        model with ``extrapolation=True`` together with freshly built zero
+        states. The per-step predictions from index ``batch_x.shape[1]`` onward
+        are stacked along dim 1 and returned.
+        """
+        num_observed = batch_x.shape[1]
+        frames = torch.cat([batch_x, batch_y], dim=1)
+        init_states = get_initial_states(
+            frames.shape, -2, -1, len(self.args.stack_sizes),
+            self.args.R_stack_sizes, self.args.stack_sizes,
+            -3, self.args.device)
+        predictions, _ = self.model(frames, init_states, extrapolation=True)
+        future_predictions = predictions[num_observed:]
+        return torch.stack(future_predictions, dim=1)
- Args:
- runner: the trainer of methods.
- train_loader: dataloader of train.
- """
- raise NotImplementedError
+    def train_one_epoch(self, runner, train_loader, epoch, num_updates, eta=None, **kwargs):
+        """Train the model for one epoch over ``train_loader``.
+
+        Args:
+            runner: trainer object; its ``before_train_iter`` / ``after_train_iter``
+                hooks are called around every step and its ``_iter`` counter is
+                advanced once per batch.
+            train_loader: iterable yielding ``(batch_x, batch_y)`` tensor pairs.
+            epoch: current epoch index, passed to the scheduler when it steps
+                per epoch.
+            num_updates: running count of optimizer updates; incremented once
+                per batch.
+            eta: returned unchanged.
+
+        Returns:
+            Tuple ``(num_updates, losses_m, eta)`` where ``losses_m`` is the
+            ``AverageMeter`` holding the per-batch losses.
+
+        Raises:
+            ValueError: if the loss is inf or nan while a loss scaler is in use.
+        """
+        data_time_m = AverageMeter()
+        losses_m = AverageMeter()
+        self.model.train()
+        if self.by_epoch:
+            self.scheduler.step(epoch)
+        train_pbar = tqdm(train_loader) if self.rank == 0 else train_loader
-    def _predict(self, batch_x, batch_y, **kwargs):
-        """Forward the model.
+
+        end = time.time()
+        for batch_x, batch_y in train_pbar:
+            data_time_m.update(time.time() - end)
+            self.model_optim.zero_grad()
+
+            batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
+            runner.call_hook('before_train_iter')
+
+            with self.amp_autocast():
+                # The model is trained on the whole sequence (observed + target
+                # frames) and returns one error tensor per timestep.
+                frames = torch.cat([batch_x, batch_y], dim=1)
+                states = get_initial_states(frames.shape, -2, -1, len(self.args.stack_sizes),
+                                            self.args.R_stack_sizes, self.args.stack_sizes,
+                                            -3, self.args.device)
+
+                _, error_list = self.model(frames, states, extrapolation=False)
+                loss = self.train_loss.calculate_loss(error_list)
+
+            if not self.dist:
+                losses_m.update(loss.item(), batch_x.size(0))
+
+            if self.loss_scaler is not None:
+                if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
+                    raise ValueError(
+                        "Inf or nan loss value. Please use fp32 training!")
+                # The scaler receives the optimizer and is expected to run
+                # backward, clipping and the optimizer step itself (as timm's
+                # NativeScaler does), so no explicit step follows on this path.
+                self.loss_scaler(
+                    loss, self.model_optim,
+                    clip_grad=self.args.clip_grad, clip_mode=self.args.clip_mode,
+                    parameters=self.model.parameters())
+            else:
+                loss.backward()
+                self.clip_grads(self.model.parameters())
+                self.model_optim.step()
+
+            # Only synchronize when CUDA exists; the call raises on CPU-only hosts.
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+            num_updates += 1
+
+            if self.dist:
+                losses_m.update(reduce_tensor(loss), batch_x.size(0))
+
+            if not self.by_epoch:
+                self.scheduler.step()
+            runner.call_hook('after_train_iter')
+            runner._iter += 1
+
+            if self.rank == 0:
+                log_buffer = 'train loss: {:.4f}'.format(loss.item())
+                log_buffer += ' | data time: {:.4f}'.format(data_time_m.avg)
+                train_pbar.set_description(log_buffer)
+
+            end = time.time()  # end for
+
+        if hasattr(self.model_optim, 'sync_lookahead'):
+            self.model_optim.sync_lookahead()
+
+        return num_updates, losses_m, eta
+
+
+class TrainLossCalculator:
+    """Weighted PredNet training loss over layers and timesteps.
+
+    Each per-timestep error tensor is weighted per layer (last dimension),
+    summed, weighted per timestep and finally divided by the batch size.
+    The first timestep always gets weight 0; the remaining ``timestep - 1``
+    steps share the weight equally.
+
+    Args:
+        num_layer: number of PredNet layers (length of the last error dim).
+        timestep: total number of timesteps in a training sequence (>= 2).
+        weight_mode: ``'L_0'`` puts all weight on layer 0; ``'L_all'`` gives
+            layer 0 weight 1 and every other layer weight 0.1.
+        device: device on which the weight tensors are created.
+
+    Raises:
+        RuntimeError: if ``weight_mode`` is not ``'L_0'`` or ``'L_all'``.
+        ValueError: if ``timestep`` is smaller than 2.
+    """
+
+    def __init__(self, num_layer, timestep, weight_mode, device):
+        self.num_layers = num_layer
+        self.timestep = timestep
+        self.weight_mode = weight_mode
+        self.device = device
+
+        if self.weight_mode == 'L_0':
+            upper_layer_weight = 0.
+        elif self.weight_mode == 'L_all':
+            upper_layer_weight = 0.1
+        else:
+            raise RuntimeError('Unknown loss weighting mode! '
+                               'Please use `L_0` or `L_all`.')
+        if timestep < 2:
+            # 1 / (timestep - 1) below would otherwise fail with a bare
+            # ZeroDivisionError the first time a loss is computed.
+            raise ValueError('timestep must be at least 2, got %s' % timestep)
+
+        # float32 weights keep the loss in float32 instead of silently
+        # promoting it to float64 as NumPy-built weights would.
+        layer_weights = torch.full((num_layer,), upper_layer_weight,
+                                   dtype=torch.float32, device=self.device)
+        layer_weights[0] = 1.
+        self.layer_weights = layer_weights
+
+        # Constant across calls, so built once here rather than per batch.
+        time_loss_weights = torch.full((timestep,), 1. / (timestep - 1),
+                                       dtype=torch.float32, device=self.device)
+        time_loss_weights[0] = 0.
+        self.time_loss_weights = time_loss_weights
+
+    def calculate_loss(self, input):
+        """Reduce per-timestep error tensors to one scalar loss.
+
+        Args:
+            input: sequence of error tensors, one per timestep, whose first
+                dimension is the batch and whose last dimension indexes layers.
+
+        Returns:
+            Scalar float32 tensor: the layer- and time-weighted error sum
+            divided by the batch size.
+        """
+        errors = torch.stack(list(input), dim=0).float()  # (steps, batch, ..., layers)
+
+        # Weighted by layer (broadcast over the last dimension), then summed
+        # into one value per timestep.
+        per_step = (errors * self.layer_weights).flatten(start_dim=1).sum(dim=1)
+
+        # Weighted by timestep; extra entries on either side are ignored.
+        steps = min(per_step.shape[0], self.timestep)
+        total_error = (per_step[:steps] * self.time_loss_weights[:steps]).sum()
+        return total_error / errors.shape[1]  # errors.shape[1] = B
-        Args:
-            batch_x, batch_y: testing samples and groung truth.
-        """
-        raise NotImplementedError
diff --git a/openstl/models/prednet_model.py b/openstl/models/prednet_model.py
index 74fe876..bd0b664 100644
--- a/openstl/models/prednet_model.py
+++ b/openstl/models/prednet_model.py
@@ -1,203 +1,189 @@
import torch
-from torch import nn
-from torch.nn import functional as F
-
-from openstl.modules import PredNetConvLSTMCell
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from openstl.utils import get_initial_states
class PredNet_Model(nn.Module):
- r"""PredNet Model
-
- Implementation of `Deep Predictive Coding Networks for Video Prediction
- and Unsupervised Learning `_.
-
- """
-
- def __init__(self, configs, output_mode='error', **kwargs):
+ def __init__(self, stack_sizes, R_stack_sizes,
+ A_filt_sizes, Ahat_filt_sizes, R_filt_sizes,
+ pixel_max=1., args=None):
super(PredNet_Model, self).__init__()
- self.configs = configs
- self.in_shape = configs.in_shape
- _, _, H, W = configs.in_shape
-
- self.a_channels = getattr(configs, "A_channels", (3, 48, 96, 192))
- self.r_channels = getattr(configs, "R_channels", (3, 48, 96, 192))
- self.n_layers = len(self.r_channels)
- self.r_channels += (0, ) # for convenience
- self.output_mode = output_mode
- self.gating_mode = getattr(configs, "gating_mode", 'mul')
- self.extrap_start_time = getattr(configs, "extrap_start_time", None)
- self.peephole = getattr(configs, "peephole", False)
- self.lstm_tied_bias = getattr(configs, "lstm_tied_bias", False)
- self.p_max = getattr(configs, "p_max", 1.0)
-
- # Input validity checks
- default_output_modes = ['prediction', 'error', 'pred+err']
- layer_output_modes = [unit + str(l) for l in range(self.n_layers) for unit in ['R', 'E', 'A', 'Ahat']]
- default_gating_modes = ['mul', 'sub']
- assert output_mode in default_output_modes + layer_output_modes, 'Invalid output_mode: ' + str(output_mode)
- assert self.gating_mode in default_gating_modes, 'Invalid gating_mode: ' + str(self.gating_mode)
-
- if self.output_mode in layer_output_modes:
- self.output_layer_type = self.output_mode[:-1]
- self.output_layer_num = int(self.output_mode[-1])
- else:
- self.output_layer_type = None
- self.output_layer_num = None
-
- # h, w = self.input_size
-
- for i in range(self.n_layers):
- # A_channels multiplied by 2 because E_l concactenates pred-target and target-pred
- # Hidden states don't have same size due to upsampling
- # How does this handle i = L-1 (final layer) | appends a zero
-
- if self.gating_mode == 'mul':
- cell = PredNetConvLSTMCell((H, W), 2 * self.a_channels[i] + self.r_channels[i+1], self.r_channels[i],
- (3, 3), gating_mode='mul', peephole=self.peephole, tied_bias=self.lstm_tied_bias)
- elif self.gating_mode == 'sub':
- cell = PredNetConvLSTMCell((H, W), 2 * self.a_channels[i] + self.r_channels[i+1], self.r_channels[i],
- (3, 3), gating_mode='sub', peephole=self.peephole, tied_bias=self.lstm_tied_bias)
-
- setattr(self, 'cell{}'.format(i), cell)
- H = H // 2
- W = W // 2
-
- for i in range(self.n_layers):
- # Calculate predictions A_hat
- conv = nn.Sequential(nn.Conv2d(self.r_channels[i], self.a_channels[i], 3, padding=1), nn.ReLU())
- setattr(self, 'conv{}'.format(i), conv)
-
- self.upsample = nn.Upsample(scale_factor=2)
- self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
-
- for l in range(self.n_layers - 1):
- # Propagate error as next layer's target (line 16 of Lotter algo)
- # In channels = 2 * A_channels[l] because of pos/neg error concat
- # NOTE: Operation belongs to curr layer l and produces next layer state l+1
-
- update_A = nn.Sequential(nn.Conv2d(2* self.a_channels[l], self.a_channels[l+1], (3, 3), padding=1), self.maxpool)
- setattr(self, 'update_A{}'.format(l), update_A)
-
- self.criterion = nn.MSELoss()
-
- def set_output_mode(self, output_mode):
- self.output_mode = output_mode
-
- # Input validity checks
- default_output_modes = ['prediction', 'error', 'pred+err']
- layer_output_modes = [unit + str(l) for l in range(self.n_layers) for unit in ['R', 'E', 'A', 'Ahat']]
- assert output_mode in default_output_modes + layer_output_modes, 'Invalid output_mode: ' + str(output_mode)
-
- if self.output_mode in layer_output_modes:
- self.output_layer_type = self.output_mode[:-1]
- self.output_layer_num = int(self.output_mode[-1])
- else:
- self.output_layer_type = None
- self.output_layer_num = None
-
- def step(self, a, states):
- batch_size = a.size(0)
- R_layers = states[:self.n_layers]
- C_layers = states[self.n_layers:2*self.n_layers]
- E_layers = states[2*self.n_layers:3*self.n_layers]
-
- if self.extrap_start_time is not None:
- t = states[-1]
- if t >= self.extrap_start_time: # if past self.extra_start_time use previous prediction as input
- a = states[-2]
-
- # Update representation units
- for l in reversed(range(self.n_layers)):
- cell = getattr(self, 'cell{}'.format(l))
- r_tm1 = R_layers[l]
- c_tm1 = C_layers[l]
- e_tm1 = E_layers[l]
- if l == self.n_layers - 1:
- r, c = cell(e_tm1, (r_tm1, c_tm1))
- else:
- tmp = torch.cat((e_tm1, self.upsample(R_layers[l+1])), 1)
- r, c = cell(tmp, (r_tm1, c_tm1))
- R_layers[l] = r
- C_layers[l] = c
-
- # Perform error forward pass
- for l in range(self.n_layers):
- conv = getattr(self, 'conv{}'.format(l))
- a_hat = conv(R_layers[l])
+ self.args = args
+ self.stack_sizes = stack_sizes
+ self.num_layers = len(stack_sizes)
+ assert len(R_stack_sizes) == self.num_layers
+ self.R_stack_sizes = R_stack_sizes
+ assert len(A_filt_sizes) == self.num_layers - 1
+ self.A_filt_sizes = A_filt_sizes
+ assert len(Ahat_filt_sizes) == self.num_layers
+ self.Ahat_filt_sizes = Ahat_filt_sizes
+ assert len(R_filt_sizes) == self.num_layers
+ self.R_filt_sizes = R_filt_sizes
+
+ self.pixel_max = pixel_max
+ self.error_activation = args.error_activation # 'relu'
+ self.A_activation = args.A_activation # 'relu'
+ self.LSTM_activation = args.LSTM_activation # 'tanh'
+ self.LSTM_inner_activation = args.LSTM_inner_activation # 'hard_sigmoid'
+ self.channel_axis = -3
+ self.row_axis = -2
+ self.col_axis = -1
+
+ self.get_activationFunc = {
+ 'relu': nn.ReLU(),
+ 'tanh': nn.Tanh(),
+ }
+
+ self.build_layers()
+ self.init_weights()
+
+    def init_weights(self):
+        """Zero the bias of every ``nn.Conv2d`` contained in this module."""
+        for module in self.modules():
+            if isinstance(module, nn.Conv2d):
+                module.bias.data.zero_()
+
+    def hard_sigmoid(self, x, slope=0.2, shift=0.5):
+        """Return ``slope * x + shift`` clamped to the range [0, 1]."""
+        return torch.clamp((slope * x) + shift, 0, 1)
+
+    def batch_flatten(self, x):
+        """View ``x`` as 2-D, keeping dim 0 and merging every later dimension."""
+        trailing_dims = list(x.size())[1:]
+        features = int(np.prod(trailing_dims))
+        return x.view(-1, features)
+
+    def isNotTopestLayer(self, layerIndex):
+        '''Return True when layerIndex lies below the top layer (index num_layers - 1).'''
+        top_layer = self.num_layers - 1
+        return layerIndex < top_layer
+
+    def build_layers(self):
+        """Create the convolution stacks of every PredNet layer.
+
+        Registers six ``nn.ModuleList`` attributes on the module (also kept in
+        the ``self.conv_layers`` dict):
+
+        * ``i`` / ``f`` / ``c`` / ``o``: one ConvLSTM gate convolution per layer.
+          Its input is the concatenation of R_l, E_l and, below the top layer,
+          the upsampled R_{l+1}.
+        * ``Ahat``: a (conv, activation) pair per layer mapping R_l to the
+          prediction; layer 0 always uses ReLU.
+        * ``A``: a (conv, activation) pair for every layer except the top one,
+          mapping E_l to the target of layer l + 1.
+
+        All convolutions use stride 1 and ``(kernel - 1) // 2`` padding. The
+        module lists are moved to ``self.args.device``. ``self.upSample`` and
+        ``self.pool`` are created as well.
+        """
+        def conv(in_channels, out_channels, kernel_size):
+            return nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
+                             kernel_size=kernel_size, stride=(1, 1),
+                             padding=(kernel_size - 1) // 2)
+
+        # i: input, f: forget, c: cell, o: output
+        self.conv_layers = {item: []
+                            for item in ['i', 'f', 'c', 'o', 'A', 'Ahat']}
+        lstm_list = ['i', 'f', 'c', 'o']
+
+        # Sorted-key creation order is kept so that weight initialisation
+        # consumes the random generator in the same order as before.
+        for item in sorted(self.conv_layers.keys()):
+            layers = self.conv_layers[item]
+            for l in range(self.num_layers):
+                if item == 'Ahat':
+                    layers.append(conv(self.R_stack_sizes[l], self.stack_sizes[l],
+                                       self.Ahat_filt_sizes[l]))
+                    act = 'relu' if l == 0 else self.A_activation
+                    layers.append(self.get_activationFunc[act])
+                elif item == 'A':
+                    if self.isNotTopestLayer(l):
+                        # The A conv consumes E_l, the concatenation of the two
+                        # error halves of Ahat_l, i.e. 2 * stack_sizes[l]
+                        # channels (not 2 * R_stack_sizes[l]).
+                        layers.append(conv(self.stack_sizes[l] * 2, self.stack_sizes[l + 1],
+                                           self.A_filt_sizes[l]))
+                        layers.append(self.get_activationFunc[self.A_activation])
+                elif item in lstm_list:  # build R module
+                    in_channels = self.stack_sizes[l] * 2 + self.R_stack_sizes[l]
+                    if self.isNotTopestLayer(l):
+                        in_channels += self.R_stack_sizes[l + 1]
+                    layers.append(conv(in_channels, self.R_stack_sizes[l],
+                                       self.R_filt_sizes[l]))
+
+        for name, layerList in self.conv_layers.items():
+            self.conv_layers[name] = nn.ModuleList(
+                layerList).to(self.args.device)
+            setattr(self, name, self.conv_layers[name])
+
+        self.upSample = nn.Upsample(scale_factor=2, mode='nearest')
+        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
+
+ def step(self, A, states, extrapolation=False):
+ n = self.num_layers
+ R_current = states[:(n)]
+ C_current = states[(n):(2 * n)]
+ E_current = states[(2 * n):(3 * n)]
+
+ timestep = states[-1]
+ if extrapolation == True and timestep >= self.args.in_shape[0]:
+ A = states[-2]
+
+ R_list, C_list, E_list = [], [], []
+
+ for l in reversed(range(self.num_layers)):
+ inputs = [R_current[l], E_current[l]]
+ if self.isNotTopestLayer(l):
+ inputs.append(R_up)
+ inputs = torch.cat(inputs, dim=self.channel_axis)
+
+ in_gate = self.hard_sigmoid(self.conv_layers['i'][l](inputs))
+ forget_gate = self.hard_sigmoid(self.conv_layers['f'][l](inputs))
+ cell_gate = F.tanh(self.conv_layers['c'][l](inputs))
+ out_gate = self.hard_sigmoid(self.conv_layers['o'][l](inputs))
+ C_next = (forget_gate * C_current[l]) + (in_gate * cell_gate)
+ R_next = out_gate * F.tanh(C_next)
+
+ C_list.insert(0, C_next)
+ R_list.insert(0, R_next)
+
+ if l > 0:
+ R_up = self.upSample(R_next)
+
+ for l in range(self.num_layers):
+ Ahat = self.conv_layers['Ahat'][2 * l](R_list[l]) # ConvLayer
+ Ahat = self.conv_layers['Ahat'][2 *
+ l + 1](Ahat) # activation function
+
if l == 0:
- a_hat= torch.min(a_hat, torch.tensor(self.p_max).to(self.configs.device)) # alternative SatLU (Lotter)
- frame_prediction = a_hat
- pos = F.relu(a_hat - a)
- neg = F.relu(a - a_hat)
- e = torch.cat([pos, neg],1)
- E_layers[l] = e
-
- # Handling layer-specific outputs
- if self.output_layer_num == l:
- if self.output_layer_type == 'A':
- output = a
- elif self.output_layer_type == 'Ahat':
- output = a_hat
- elif self.output_layer_type == 'R':
- output = R_layers[l]
- elif self.output_layer_type == 'E':
- output = E_layers[l]
-
- if l < self.n_layers - 1: # updating A for next layer
- update_A = getattr(self, 'update_A{}'.format(l))
- a = update_A(e)
-
- if self.output_layer_type is None:
- if self.output_mode == 'prediction':
- output = frame_prediction
+ Ahat = torch.clamp(Ahat, max=self.pixel_max)
+ frame_prediction = Ahat
+
+ if self.error_activation.lower() == 'relu':
+ E_up = F.relu(Ahat - A)
+ E_down = F.relu(A - Ahat)
+ elif self.error_activation.lower() == 'tanh':
+ E_up = F.tanh(Ahat - A)
+ E_down = F.tanh(A - Ahat)
else:
- # Batch flatten (return 2D matrix) then mean over units
- # Finally, concatenate layers (batch, n_layers)
- mean_E_layers = torch.cat([torch.mean(e.view(batch_size, -1), axis=1, keepdim=True) for e in E_layers], axis=1)
- if self.output_mode == 'error':
- output = mean_E_layers
- else:
- output = torch.cat([frame_prediction.view(batch_size, -1), mean_E_layers], axis=1)
-
- states = R_layers + C_layers + E_layers
- if self.extrap_start_time is not None:
- states += [frame_prediction, t+1]
- return output, states
-
- def forward(self, input_tensor, **kwargs):
-
- R_layers = [None] * self.n_layers
- C_layers = [None] * self.n_layers
- E_layers = [None] * self.n_layers
-
- _, _, h, w = self.in_shape
- batch_size = input_tensor.size(0)
-
- # Initialize states
- for l in range(self.n_layers):
- R_layers[l] = torch.zeros(batch_size, self.r_channels[l], h, w, requires_grad=True).to(self.configs.device)
- C_layers[l] = torch.zeros(batch_size, self.r_channels[l], h, w, requires_grad=True).to(self.configs.device)
- E_layers[l] = torch.zeros(batch_size, 2*self.a_channels[l], h, w, requires_grad=True).to(self.configs.device)
- # Size of hidden state halves from each layer to the next
- h = h//2
- w = w//2
-
- states = R_layers + C_layers + E_layers
- # Initialize previous_prediction
- if self.extrap_start_time is not None:
- frame_prediction = torch.zeros_like(input_tensor[:,0], dtype=torch.float32).to(self.configs.device)
- states += [frame_prediction, -1] # [a, t]
-
- num_time_steps = input_tensor.size(1)
- total_output = [] # contains output sequence
- for t in range(num_time_steps):
- a = input_tensor[:,t].type(torch.FloatTensor).to(self.configs.device)
- output, states = self.step(a, states)
- total_output.append(output)
-
- ax = len(output.shape)
- # print(output.shape)
- total_output = [out.view(out.shape + (1,)) for out in total_output]
- total_output = torch.cat(total_output, axis=ax) # (batch, ..., nt)
-
- return total_output
+ raise (RuntimeError(
+ 'cannot obtain the activation function named %s' % self.error_activation))
+
+ E_list.append(torch.cat((E_up, E_down), dim=self.channel_axis))
+
+ if self.isNotTopestLayer(l):
+ A = self.conv_layers['A'][2 * l](E_list[l])
+ A = self.conv_layers['A'][2 * l + 1](A)
+ A = self.pool(A) # target for next layer
+
+ for l in range(self.num_layers):
+ layer_error = torch.mean(self.batch_flatten(
+ E_list[l]), dim=-1, keepdim=True)
+ all_error = layer_error if l == 0 else torch.cat(
+ (all_error, layer_error), dim=-1)
+
+ states = R_list + C_list + E_list
+ predict = frame_prediction
+ error = all_error
+ states += [frame_prediction, timestep + 1]
+ return predict, error, states
+
+    def forward(self, A0_withTimeStep, initial_states=None, extrapolation=False):
+        '''
+        Run the network over a whole sequence, one timestep at a time.
+
+        A0_withTimeStep is the input from dataloader.
+        Its shape is: (batch_size, timesteps, Channel, Height, Width).
+
+        initial_states is the state list consumed by `step`; when it is None,
+        zero states matching the shape of A0_withTimeStep are created.
+        extrapolation is forwarded unchanged to `step`.
+
+        Returns (predict_list, error_list): two lists holding the prediction
+        and the error returned by `step` for every timestep.
+        '''
+        if initial_states is None:
+            # Size the default states from the actual input instead of a fixed
+            # (1, 10, 1, 64, 64) shape, so any batch size and resolution works.
+            initial_states = get_initial_states(
+                A0_withTimeStep.shape, self.row_axis, self.col_axis,
+                self.num_layers, self.R_stack_sizes, self.stack_sizes,
+                self.channel_axis, self.args.device)
+        # Time-major layout: indexing dim 0 yields one frame batch per step.
+        A0_withTimeStep = A0_withTimeStep.transpose(0, 1)
+        num_timesteps = A0_withTimeStep.shape[0]
+
+        hidden_states = initial_states
+        predict_list, error_list = [], []
+        for t in range(num_timesteps):
+            A0 = A0_withTimeStep[t, ...]
+            predict, error, hidden_states = self.step(
+                A0, hidden_states, extrapolation)
+            predict_list.append(predict)
+            error_list.append(error)
+        return predict_list, error_list
diff --git a/openstl/utils/__init__.py b/openstl/utils/__init__.py
index b6e8da4..46415c7 100644
--- a/openstl/utils/__init__.py
+++ b/openstl/utils/__init__.py
@@ -10,8 +10,9 @@
from .parser import create_parser
from .predrnn_utils import (reserve_schedule_sampling_exp, schedule_sampling, reshape_patch,
reshape_patch_back)
-from .progressbar import ProgressBar, Timer
from .dmvfn_utils import LapLoss, VGGPerceptualLoss
+from .prednet_utils import get_initial_states
+from .progressbar import ProgressBar, Timer
@@ -23,6 +24,7 @@
'get_dataset', 'count_parameters', 'measure_throughput', 'load_config', 'update_config', 'weights_to_cpu',
'init_dist', 'init_random_seed', 'get_dist_info', 'reduce_tensor',
'reserve_schedule_sampling_exp', 'schedule_sampling', 'reshape_patch', 'reshape_patch_back',
+ 'LapLoss', 'VGGPerceptualLoss',
+ 'get_initial_states',
'ProgressBar', 'Timer',
- 'LapLoss', 'VGGPerceptualLoss'
]
\ No newline at end of file
diff --git a/openstl/utils/prednet_utils.py b/openstl/utils/prednet_utils.py
new file mode 100644
index 0000000..bff7f95
--- /dev/null
+++ b/openstl/utils/prednet_utils.py
@@ -0,0 +1,47 @@
+import numpy as np
+import torch
+
+
+def get_initial_states(input_shape, row_axis, col_axis, num_layers,
+                       R_stack_sizes, stack_sizes, channel_axis,
+                       device):
+    """Build the all-zero initial states for a PredNet rollout.
+
+    Args:
+        input_shape: shape of the input sequence,
+            (batch_size, timeSteps, Channel, Height, Width).
+        row_axis: index of the height entry in ``input_shape``.
+        col_axis: index of the width entry in ``input_shape``.
+        num_layers: number of PredNet layers.
+        R_stack_sizes: channels of the R and C states per layer.
+        stack_sizes: channels of the prediction per layer; the E state of a
+            layer has twice as many channels.
+        channel_axis: unused; kept so existing callers keep working.
+        device: device on which the states are created.
+
+    Returns:
+        A list with, in order: the R state of every layer, the C state of
+        every layer, the E state of every layer, one layer-0 Ahat state, and
+        a one-element int tensor holding the current timestep (0). Layer ``l``
+        states have spatial size ``(Height // 2**l, Width // 2**l)``. All
+        states except the timestep are float32 zeros with
+        ``requires_grad=True``.
+    """
+    batch_size = input_shape[0]
+    init_height = input_shape[row_axis]
+    init_width = input_shape[col_axis]
+
+    def zero_state(channels, layer):
+        # Allocate directly on the target device; no NumPy round trip.
+        downsample_factor = 2 ** layer
+        return torch.zeros(
+            (batch_size, channels,
+             init_height // downsample_factor, init_width // downsample_factor),
+            dtype=torch.float32, device=device, requires_grad=True)
+
+    layers = range(num_layers)
+    # R is `representation`, C is Cell state in ConvLSTM, E is `error`.
+    initial_states = [zero_state(R_stack_sizes[l], l) for l in layers]
+    initial_states += [zero_state(R_stack_sizes[l], l) for l in layers]
+    initial_states += [zero_state(stack_sizes[l] * 2, l) for l in layers]
+    # pass prediction in states so can use as actual for t+1 when extrapolating
+    initial_states.append(zero_state(stack_sizes[0], 0))
+    # the last state will correspond to the current timestep
+    initial_states.append(torch.zeros(1, dtype=torch.int, device=device))
+    return initial_states
\ No newline at end of file