Skip to content

Commit

Permalink
Merge pull request chengtan9907#5 from Lupin1998/master
Browse files Browse the repository at this point in the history
update data and metaformers
  • Loading branch information
chengtan9907 committed Feb 15, 2023
2 parents b14d30c + 88627a5 commit 1312316
Show file tree
Hide file tree
Showing 76 changed files with 1,717 additions and 234 deletions.
137 changes: 133 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,134 @@
__pycache__
/results
/figs
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
apex/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

.vscode
.ipynb_checkpoints
.idea

# custom
*.pkl
*.pkl.json
*.log.json
*.jpg
bash
data
data_set
output/
work_dirs/
workspace/
tools/exp_bash/
pretrains

# Pytorch
*.pth

*.swp
.DS_Store
*.json

results
figs
.vscode
.ipynb_checkpoints
2 changes: 0 additions & 2 deletions API/__init__.py

This file was deleted.

6 changes: 0 additions & 6 deletions API/dataloader.py

This file was deleted.

18 changes: 11 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ This repository contains the implementation code for paper:

**SimVP: Towards Simple yet Powerful Spatiotemporal Predictive learning**
[Cheng Tan](https://westlake-drug-discovery.github.io/cheng_tan.html), [Zhangyang Gao](https://westlake-drug-discovery.github.io/zhangyang_gao.html), [Stan Z. Li](https://scholar.google.com/citations?user=Y-nyLGIAAAAJ&hl).

## Introduction

This is the journal version of our previous conference work (SimVP: Simpler yet Better Video Prediction, In CVPR 2022).
Expand All @@ -24,18 +25,21 @@ Quantitative results of different methods on the Moving MNIST dataset ($10 \righ
</p>

## Dependencies
* torch
* scikit-image=0.16.2
* numpy
* argparse
* numpy
* hickle
* scikit-image=0.16.2
* torch
* timm
* tqdm

## Overview

* `API/` contains dataloaders and metrics.
* `main.py` is the executable python file with possible arguments.
* `model.py` contains the SimVP model.
* `exp.py` is the core file for training, validating, and testing pipelines.
* `api/` contains dataloaders and metrics.
* `methods/` contains training methods for various video prediction methods.
* `models/` contains the main network architectures of various video prediction methods.
* `modules/` contains network modules and layers.
* `main.py` is the executable python file with possible arguments for training, validating, and testing pipelines.

## Install

Expand Down
9 changes: 9 additions & 0 deletions api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .dataloader_moving_mnist import MovingMNIST
from .dataloader_kitticaltech import KittiCaltechDataset
from .dataloader import load_data
from .metrics import metric
from .recorder import Recorder

__all__ = [
'MovingMNIST', 'KittiCaltechDataset', 'load_data', 'metric', 'Recorder',
]
13 changes: 13 additions & 0 deletions api/dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@


def load_data(dataname, batch_size, val_batch_size, num_workers, data_root, **kwargs):
pre_seq_length = kwargs.get('', 10)
aft_seq_length = kwargs.get('', 10)
if dataname == 'mmnist':
from .dataloader_moving_mnist import load_data
return load_data(batch_size, val_batch_size, num_workers, data_root, pre_seq_length, aft_seq_length)
elif dataname == 'kitticaltech':
from .dataloader_kitticaltech import load_data
return load_data(batch_size, val_batch_size, num_workers, data_root, pre_seq_length, aft_seq_length)
else:
raise ValueError(f'Dataname {dataname} is unsupported')
137 changes: 137 additions & 0 deletions api/dataloader_kitticaltech.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import logging
logger = logging.getLogger(__name__)

import os
import os.path as osp
import cv2
import numpy as np
from torch.utils.data import Dataset
import torch
try:
import hickle as hkl
except ImportError:
hkl = None
from skimage.transform import resize

# cite the `process_im` code from PredNet, Thanks!
# https://github.com/coxlab/prednet/blob/master/process_kitti.py
def process_im(im, desired_sz):
target_ds = float(desired_sz[0])/im.shape[0]
im = resize(im, (desired_sz[0], int(np.round(target_ds * im.shape[1]))), preserve_range=True)
d = int((im.shape[1] - desired_sz[1]) / 2)
im = im[:, d:d+desired_sz[1]]
return im


class KittiCaltechDataset(Dataset):
def __init__(self, datas, indices, pre_seq_length, aft_seq_length, require_back=False):
super(KittiCaltechDataset, self).__init__()
self.datas = datas.swapaxes(2, 3).swapaxes(1, 2)
self.indices = indices
self.pre_seq_length = pre_seq_length
self.aft_seq_length = aft_seq_length
self.require_back = require_back
self.mean = 0
self.std = 1

def __len__(self):
return len(self.indices)

def __getitem__(self, i):
batch_ind = self.indices[i]
begin = batch_ind
end1 = begin + self.pre_seq_length
end2 = end1 + self.aft_seq_length
data = torch.tensor(self.datas[begin:end1, ::]).float()
labels = torch.tensor(self.datas[end1:end2, ::]).float()
return data, labels


class DataProcess(object):
def __init__(self, input_param):
self.paths = input_param['paths']
self.seq_len = input_param['seq_length']

def load_data(self, mode='train'):
"""Loads the dataset.
Args:
paths: paths of train/test dataset.
mode: Training or testing.
Returns:
A dataset and indices of the sequence.
"""
if mode == 'train' or mode == 'val':
kitti_root = self.paths['kitti']
data = hkl.load(osp.join(kitti_root, 'X_' + mode + '.hkl'))
data = data.astype('float') / 255.0
fileidx = hkl.load(
osp.join(kitti_root, 'sources_' + mode + '.hkl'))

indices = []
index = len(fileidx) - 1
while index >= self.seq_len - 1:
if fileidx[index] == fileidx[index - self.seq_len + 1]:
indices.append(index - self.seq_len + 1)
index -= self.seq_len - 1
index -= 1

elif mode == 'test':
caltech_root = self.paths['caltech']
data = []
fileidx = []
for seq_id in os.listdir(caltech_root):
if osp.isdir(osp.join(caltech_root, seq_id)) is False:
continue
for item in os.listdir(osp.join(caltech_root, seq_id)):
cap = cv2.VideoCapture(
osp.join(caltech_root, seq_id, item))
cnt_frames = 0
while True:
ret, frame = cap.read()
if not ret:
break
cnt_frames += 1
if cnt_frames % 3 == 0:
frame = process_im(frame, (128, 160)) / 255.0
data.append(frame)
fileidx.append(seq_id + item)
data = np.asarray(data)

indices = []
index = len(fileidx) - 1
while index >= self.seq_len - 1:
if fileidx[index] == fileidx[index - self.seq_len + 1]:
indices.append(index - self.seq_len + 1)
index -= self.seq_len - 1
index -= 1

return data, indices


def load_data(batch_size, val_batch_size, num_workers, data_root, pre_seq_length=10, aft_seq_length=10):
input_param = {
'paths': {'kitti': osp.join(data_root, 'kitti_hkl'),
'caltech': osp.join(data_root, 'caltech')},
'seq_length': (pre_seq_length + aft_seq_length),
'input_data_type': 'float32',
}
input_handle = DataProcess(input_param)
train_data, train_idx = input_handle.load_data('train')
test_data, test_idx = input_handle.load_data('val')

# train_data = np.load(osp.join(data_root, 'kitticaltech_saved', 'train_data.npy'))
# train_idx = np.load(osp.join(data_root, 'kitticaltech_saved', 'train_idx.npy'))
# test_data = np.load(osp.join(data_root, 'kitticaltech_saved', 'test_data.npy'))
# test_idx = np.load(osp.join(data_root, 'kitticaltech_saved', 'test_idx.npy'))

train_set = KittiCaltechDataset(
train_data, train_idx, pre_seq_length, aft_seq_length)
test_set = KittiCaltechDataset(
test_data, test_idx, pre_seq_length, aft_seq_length)

dataloader_train = torch.utils.data.DataLoader(
train_set, batch_size=batch_size, shuffle=True, pin_memory=True, drop_last=True, num_workers=num_workers)
dataloader_test = torch.utils.data.DataLoader(
test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True, drop_last=True, num_workers=num_workers)

return dataloader_train, dataloader_test, dataloader_test
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,12 @@ def __len__(self):
return self.length


def load_data(batch_size, val_batch_size, num_workers, data_root):
def load_data(batch_size, val_batch_size, num_workers, data_root, pre_seq_length=10, aft_seq_length=10):

train_set = MovingMNIST(root=data_root, is_train=True,
n_frames_input=10, n_frames_output=10, num_objects=[2])
n_frames_input=pre_seq_length, n_frames_output=aft_seq_length, num_objects=[2])
test_set = MovingMNIST(root=data_root, is_train=False,
n_frames_input=10, n_frames_output=10, num_objects=[2])
n_frames_input=pre_seq_length, n_frames_output=aft_seq_length, num_objects=[2])

dataloader_train = torch.utils.data.DataLoader(
train_set, batch_size=batch_size, shuffle=True, pin_memory=True, drop_last=True, num_workers=num_workers)
Expand Down
6 changes: 5 additions & 1 deletion API/metrics.py → api/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,22 @@
def rescale(x):
return (x - x.max()) / (x.max() - x.min()) * 2 - 1


def MAE(pred, true):
return np.mean(np.abs(pred-true), axis=(0, 1)).sum()


def MSE(pred, true):
return np.mean((pred-true)**2, axis=(0, 1)).sum()


# cite the `PSNR` code from E3d-LSTM, Thanks!
# https://github.com/google/e3d_lstm/blob/master/src/trainer.py
def PSNR(pred, true):
mse = np.mean((np.uint8(pred * 255)-np.uint8(true * 255))**2)
return 20 * np.log10(255) - 10 * np.log10(mse)


def metric(pred, true, mean, std, return_ssim_psnr=False, clip_range=[0, 1]):
pred = pred*std + mean
true = true*std + mean
Expand All @@ -39,4 +43,4 @@ def metric(pred, true, mean, std, return_ssim_psnr=False, clip_range=[0, 1]):

return mae, mse, ssim, psnr
else:
return mae, mse
return mae, mse
Loading

0 comments on commit 1312316

Please sign in to comment.