Commit: add prednet

chengtan9907 committed on May 24, 2023
1 parent 5d4743b commit 2162466
Showing 7 changed files with 366 additions and 222 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -115,6 +115,7 @@ We support various spatiotemporal prediction methods and will provide benchmarks
 <summary>Currently supported methods</summary>
 
 - [x] [ConvLSTM](https://arxiv.org/abs/1506.04214) (NeurIPS'2015)
+- [x] [PredNet](https://openreview.net/forum?id=B1ewdt9xe) (ICLR'2017)
 - [x] [PredRNN](https://dl.acm.org/doi/abs/10.5555/3294771.3294855) (NeurIPS'2017)
 - [x] [PredRNN++](https://arxiv.org/abs/1804.06300) (ICML'2018)
 - [x] [E3D-LSTM](https://openreview.net/forum?id=B1lKS2AqtX) (ICLR'2019)
@@ -171,7 +172,7 @@ This project is released under the [Apache 2.0 license](LICENSE). See `LICENSE`
 
 ## Acknowledgement
 
-OpenSTL is an open-source project for STL algorithms created by researchers in **CAIRI AI Lab**. We encourage researchers interested in video and weather prediction to contribute to OpenSTL! We borrow the official implementations of [ConvLSTM](https://arxiv.org/abs/1506.04214), [PredRNN](https://dl.acm.org/doi/abs/10.5555/3294771.3294855) variants, [E3D-LSTM](https://openreview.net/forum?id=B1lKS2AqtX), [MAU](https://arxiv.org/abs/1811.07490), [CrevNet](https://openreview.net/forum?id=B1lKS2AqtX), [PhyDNet](https://arxiv.org/abs/2003.01460), [DMVFN](https://arxiv.org/abs/2303.09875).
+OpenSTL is an open-source project for STL algorithms created by researchers in **CAIRI AI Lab**. We encourage researchers interested in video and weather prediction to contribute to OpenSTL! We borrow the official implementations of [ConvLSTM](https://arxiv.org/abs/1506.04214), [PredNet](https://arxiv.org/abs/1605.08104), [PredRNN](https://dl.acm.org/doi/abs/10.5555/3294771.3294855) variants, [E3D-LSTM](https://openreview.net/forum?id=B1lKS2AqtX), [MAU](https://arxiv.org/abs/1811.07490), [CrevNet](https://openreview.net/forum?id=B1lKS2AqtX), [PhyDNet](https://arxiv.org/abs/2003.01460), and [DMVFN](https://arxiv.org/abs/2303.09875).
 
 ## Citation
 
12 changes: 12 additions & 0 deletions configs/mmnist/PredNet.py
@@ -0,0 +1,12 @@
+method = 'PredNet'
+stack_sizes = (1, 32, 64, 128, 256)  # 1 is the number of input channels
+R_stack_sizes = stack_sizes
+A_filt_sizes = (3, 3, 3, 3)
+Ahat_filt_sizes = (3, 3, 3, 3, 3)
+R_filt_sizes = (3, 3, 3, 3, 3)
+pixel_max = 1.0
+weight_mode = 'L_0'
+error_activation = 'relu'
+A_activation = 'relu'
+LSTM_activation = 'tanh'
+LSTM_inner_activation = 'hard_sigmoid'
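For reference, the filter-size tuples must stay consistent with `stack_sizes`: `Ahat_filt_sizes` and `R_filt_sizes` need one entry per layer, while `A_filt_sizes` needs one entry per transition between adjacent layers, which is why it is one element shorter. Below is a minimal sketch of loading and sanity-checking such a config as plain Python; the `load_config` helper is illustrative only, not part of OpenSTL's API, and the exact training entry point (something like `python tools/train.py -d mmnist -c configs/mmnist/PredNet.py`) is an assumption.

```python
# Minimal sketch: execute a plain-Python config and check PredNet shape invariants.
# `load_config` is a hypothetical helper for illustration, not an OpenSTL API.
import types

def load_config(path):
    """Run a config file and expose its top-level names as attributes."""
    namespace = {}
    with open(path) as f:
        exec(f.read(), namespace)
    return types.SimpleNamespace(**{k: v for k, v in namespace.items()
                                    if not k.startswith('__')})

cfg = load_config('configs/mmnist/PredNet.py')
num_layers = len(cfg.stack_sizes)               # 5 layers in this config
assert len(cfg.R_stack_sizes) == num_layers     # one recurrent width per layer
assert len(cfg.Ahat_filt_sizes) == num_layers   # one prediction conv per layer
assert len(cfg.R_filt_sizes) == num_layers      # one ConvLSTM kernel per layer
assert len(cfg.A_filt_sizes) == num_layers - 1  # target convs sit between layers
```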
6 changes: 4 additions & 2 deletions openstl/api/train.py
@@ -263,6 +263,8 @@ def display_method_info(self):
             input_dummy = (_tmp_input, _tmp_flag)
         elif self.args.method == 'dmvfn':
             input_dummy = torch.ones(1, 3, C, H, W, requires_grad=True).to(self.device)
+        elif self.args.method == 'prednet':
+            input_dummy = torch.ones(1, 10, C, H, W, requires_grad=True).to(self.device)
         else:
             raise ValueError(f'Invalid method name {self.args.method}')
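The dummy tensor added here exists only so the model-info profiler can trace a single forward pass; for `prednet` it is a batch of one 10-frame clip, presumably matching the default `pre_seq_length` on Moving MNIST. A hedged sketch of the general idea, assuming fvcore's `FlopCountAnalysis` (whether `display_method_info` actually uses fvcore is an assumption):

```python
# Sketch: counting FLOPs of a model by tracing a dummy (B, T, C, H, W) input.
# fvcore here is an assumption for illustration; OpenSTL's profiler may differ.
import torch
import torch.nn as nn
from fvcore.nn import FlopCountAnalysis

class TinyVideoNet(nn.Module):
    """Stand-in model: averages over time, then applies a 3x3 conv."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def forward(self, x):  # x: (B, T, C, H, W)
        return self.conv(x.mean(dim=1))

input_dummy = torch.ones(1, 10, 1, 64, 64)  # 10 frames, like prednet's dummy
print(FlopCountAnalysis(TinyVideoNet(), input_dummy).total())
```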

@@ -296,7 +298,7 @@ def train(self):
                 cur_lr = self.method.current_lr()
                 cur_lr = sum(cur_lr) / len(cur_lr)
                 with torch.no_grad():
-                    vali_loss = self.vali(self.vali_loader)
+                    vali_loss = self.vali()
 
             if self._rank == 0:
                 print_log('Epoch: {0}, Steps: {1} | Lr: {2:.7f} | Train Loss: {3:.7f} | Vali Loss: {4:.7f}\n'.format(
@@ -313,7 +315,7 @@ def train(self):
         time.sleep(1)  # wait for some hooks like loggers to finish
         self.call_hook('after_run')
 
-    def vali(self, vali_loader):
+    def vali(self):
         """A validation loop during training"""
         self.call_hook('before_val_epoch')
         results, eval_log = self.method.vali_one_epoch(self, self.vali_loader)
136 changes: 115 additions & 21 deletions openstl/methods/prednet.py
@@ -1,13 +1,12 @@
 import time
 import torch
 import torch.nn as nn
-from tqdm import tqdm
 import numpy as np
 from timm.utils import AverageMeter
+from tqdm import tqdm
 
 from openstl.models import PredNet_Model
-from openstl.utils import reduce_tensor
-from .base_method import Base_method
+from openstl.utils import (reduce_tensor, get_initial_states)
+from openstl.methods.base_method import Base_method
 
 
 class PredNet(Base_method):
@@ -20,28 +19,123 @@ class PredNet(Base_method):

     def __init__(self, args, device, steps_per_epoch):
         Base_method.__init__(self, args, device, steps_per_epoch)
-        self.model = self._build_model(self.config)
+        self.model = self._build_model(self.args)
         self.model_optim, self.scheduler, self.by_epoch = self._init_optimizer(steps_per_epoch)
         self.criterion = nn.MSELoss()
 
-        self.constraints = self._get_constraints()
+        self.train_loss = TrainLossCalculator(num_layer=len(self.args.stack_sizes),
+                                              timestep=self.args.pre_seq_length + self.args.aft_seq_length,
+                                              weight_mode=self.args.weight_mode, device=self.device)
 
     def _build_model(self, args):
-        return PredNet_Model(args, output_mode='error').to(self.device)
+        return PredNet_Model(args.stack_sizes, args.R_stack_sizes,
+                             args.A_filt_sizes, args.Ahat_filt_sizes,
+                             args.R_filt_sizes, args.pixel_max, args)
 
-    def train_one_epoch(self, runner, train_loader, **kwargs):
-        """Train the model with train_loader.
-
-        Args:
-            runner: the trainer of methods.
-            train_loader: dataloader of train.
-        """
-        raise NotImplementedError
-
-    def _predict(self, batch_x, batch_y, **kwargs):
-        """Forward the model.
-
-        Args:
-            batch_x, batch_y: testing samples and ground truth.
-        """
-        raise NotImplementedError
+    def _predict(self, batch_x, batch_y, **kwargs):
+        """Forward the model and keep only the extrapolated frames."""
+        input = torch.cat([batch_x, batch_y], dim=1)
+        states = get_initial_states(input.shape, -2, -1, len(self.args.stack_sizes),
+                                    self.args.R_stack_sizes, self.args.stack_sizes,
+                                    -3, self.args.device)
+        predict_list, _ = self.model(input, states, extrapolation=True)
+        pred_y = torch.stack(predict_list[batch_x.shape[1]:], dim=1)
+        return pred_y
+
+    def train_one_epoch(self, runner, train_loader, epoch, num_updates, eta=None, **kwargs):
+        """Train the model with train_loader."""
+        data_time_m = AverageMeter()
+        losses_m = AverageMeter()
+        self.model.train()
+        if self.by_epoch:
+            self.scheduler.step(epoch)
+        train_pbar = tqdm(train_loader) if self.rank == 0 else train_loader
+
+        end = time.time()
+        for batch_x, batch_y in train_pbar:
+            data_time_m.update(time.time() - end)
+            self.model_optim.zero_grad()
+
+            batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
+            runner.call_hook('before_train_iter')
+
+            with self.amp_autocast():
+                input = torch.cat([batch_x, batch_y], dim=1)
+                states = get_initial_states(input.shape, -2, -1, len(self.args.stack_sizes),
+                                            self.args.R_stack_sizes, self.args.stack_sizes,
+                                            -3, self.args.device)
+                _, error_list = self.model(input, states, extrapolation=False)
+                loss = self.train_loss.calculate_loss(error_list)
+
+            if not self.dist:
+                losses_m.update(loss.item(), batch_x.size(0))
+
+            if self.loss_scaler is not None:
+                if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
+                    raise ValueError("Inf or nan loss value. Please use fp32 training!")
+                self.loss_scaler(
+                    loss, self.model_optim,
+                    clip_grad=self.args.clip_grad, clip_mode=self.args.clip_mode,
+                    parameters=self.model.parameters())
+            else:
+                loss.backward()
+                self.clip_grads(self.model.parameters())
+
+            self.model_optim.step()
+            torch.cuda.synchronize()
+            num_updates += 1
+
+            if self.dist:
+                losses_m.update(reduce_tensor(loss), batch_x.size(0))
+
+            if not self.by_epoch:
+                self.scheduler.step()
+            runner.call_hook('after_train_iter')
+            runner._iter += 1
+
+            if self.rank == 0:
+                log_buffer = 'train loss: {:.4f}'.format(loss.item())
+                log_buffer += ' | data time: {:.4f}'.format(data_time_m.avg)
+                train_pbar.set_description(log_buffer)
+
+            end = time.time()  # end for
+
+        if hasattr(self.model_optim, 'sync_lookahead'):
+            self.model_optim.sync_lookahead()
+
+        return num_updates, losses_m, eta
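Worth noting how the new `_predict` consumes the model output: the whole clip (context plus horizon) is passed once with `extrapolation=True`, the model yields one predicted frame per timestep, and only the frames after `batch_x.shape[1]` (the context length) are kept. A toy sketch of that slicing, independent of the actual PredNet model:

```python
# Toy sketch of the _predict slicing: drop context-frame predictions,
# keep only the extrapolated horizon. Shapes are illustrative.
import torch

pre_seq, aft_seq, B = 10, 10, 2
predict_list = [torch.zeros(B, 1, 64, 64)            # one (B, C, H, W) frame
                for _ in range(pre_seq + aft_seq)]   # per timestep
pred_y = torch.stack(predict_list[pre_seq:], dim=1)  # -> (B, aft_seq, C, H, W)
assert pred_y.shape == (B, aft_seq, 1, 64, 64)
```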


+class TrainLossCalculator:
+    def __init__(self, num_layer, timestep, weight_mode, device):
+        self.num_layers = num_layer
+        self.timestep = timestep
+        self.weight_mode = weight_mode
+        self.device = device
+
+        if self.weight_mode == 'L_0':
+            layer_weights = np.array([0. for _ in range(num_layer)])
+            layer_weights[0] = 1.
+        elif self.weight_mode == 'L_all':
+            layer_weights = np.array([0.1 for _ in range(num_layer)])
+            layer_weights[0] = 1.
+        else:
+            raise RuntimeError('Unknown loss weighting mode! '
+                               'Please use `L_0` or `L_all`.')
+        self.layer_weights = torch.from_numpy(layer_weights).to(self.device)
+
+    def calculate_loss(self, input):
+        # Weighted by layer (broadcasts layer_weights over the batch dim)
+        error_list = [batch_numLayer_error * self.layer_weights for
+                      batch_numLayer_error in input]
+        error_list = [torch.sum(error_at_t) for error_at_t in error_list]
+
+        # Weighted by timestep: frame 0 gets weight 0, the rest share 1/(T-1)
+        time_loss_weights = torch.cat([torch.tensor([0.], device=self.device),
+                                       torch.full((self.timestep - 1,),
+                                                  1. / (self.timestep - 1), device=self.device)])
+
+        total_error = error_list[0] * time_loss_weights[0]
+        for err, time_weight in zip(error_list[1:], time_loss_weights[1:]):
+            total_error += err * time_weight
+        total_error /= input[0].shape[0]  # input[0].shape[0] = batch size B
+        return total_error
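As a concrete check of the weighting scheme: with `weight_mode='L_0'` only the bottom layer's error contributes, frame 0 gets zero weight (PredNet cannot predict the first frame), and each of the remaining T-1 frames contributes with weight 1/(T-1). A usage sketch on dummy error tensors, assuming the `TrainLossCalculator` above is in scope; the shapes are made up for illustration:

```python
# Usage sketch: TrainLossCalculator on dummy per-timestep errors.
# Each list element is (B, num_layers), one entry per frame.
import torch

B, num_layers, T = 4, 5, 20
calc = TrainLossCalculator(num_layer=num_layers, timestep=T,
                           weight_mode='L_0', device='cpu')
error_list = [torch.rand(B, num_layers).double() for _ in range(T)]
loss = calc.calculate_loss(error_list)

# Reference: layer-0 error summed over the batch, averaged over frames 1..T-1,
# then divided by the batch size -- matches calculate_loss in L_0 mode.
ref = torch.stack([e[:, 0].sum() for e in error_list[1:]]).mean() / B
assert torch.isclose(loss, ref)
```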