Merge pull request #1 from tontonkun8/pr/41
Pr/41
tontonkun8 committed Oct 16, 2023
2 parents a8a5827 + 5fa6b75 commit 656e36f
Showing 2 changed files with 17 additions and 156 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -55,7 +55,7 @@


 ### TODO
-- [ ] Make a interactive Gradio demo.
+- [ ] Make an interactive Gradio demo.
 - [ ] Make a Colab demo.
 - [x] ~~Update features for memory-efficient inference.~~

171 changes: 16 additions & 155 deletions scripts/evaluate_flow_completion.py
@@ -1,184 +1,44 @@
# -*- coding: utf-8 -*-
import sys
sys.path.append(".")

import os
import argparse
import warnings

import cv2
import numpy as np
from PIL import Image

import torch
from torch.utils.data import DataLoader

# Local imports
from core.dataset import TestDataset
from model.modules.flow_comp_raft import RAFT_bi
from model.recurrent_flow_completion import RecurrentFlowCompleteNet

from RAFT.utils.flow_viz_pt import flow_to_image

import cvbase
import imageio
from time import time

warnings.filterwarnings("ignore")

def create_dir(dir):
    """Creates a directory if it does not exist."""
    if not os.path.exists(dir):
        os.makedirs(dir)

def save_flows(output, videoFlowF, videoFlowB):
    # create_dir(os.path.join(output, 'forward_flo'))
    # create_dir(os.path.join(output, 'backward_flo'))
    create_dir(os.path.join(output, 'forward_png'))
    create_dir(os.path.join(output, 'backward_png'))
    num_frames = videoFlowF.shape[-1]
    for i in range(num_frames):
        forward_flow = videoFlowF[..., i]
        backward_flow = videoFlowB[..., i]
        forward_flow_vis = cvbase.flow2rgb(forward_flow)
        backward_flow_vis = cvbase.flow2rgb(backward_flow)
        # cvbase.write_flow(forward_flow, os.path.join(output, 'forward_flo', '{:05d}.flo'.format(i)))
        # cvbase.write_flow(backward_flow, os.path.join(output, 'backward_flo', '{:05d}.flo'.format(i)))
        forward_flow_vis = (forward_flow_vis * 255.0).astype(np.uint8)
        backward_flow_vis = (backward_flow_vis * 255.0).astype(np.uint8)
        imageio.imwrite(os.path.join(output, 'forward_png', f'{i:05d}.png'), forward_flow_vis)
        imageio.imwrite(os.path.join(output, 'backward_png', f'{i:05d}.png'), backward_flow_vis)

def tensor2np(array):
    array = torch.stack(array, dim=-1).squeeze(0).permute(1, 2, 0, 3).cpu().numpy()
    return array
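# Shape note (inferred from the permute above, not in the original source):
# tensor2np stacks a length-T list of (1, 2, H, W) tensors into one array of
# shape (H, W, 2, T), so videoFlowF[..., i] in save_flows is one (H, W, 2) flow.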

def main_worker(args):
    # set up datasets and data loader
    args.size = (args.width, args.height)
    test_dataset = TestDataset(vars(args))
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.num_workers)

    # set up models
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    fix_raft = RAFT_bi(args.raft_model_path, device)

    fix_flow_complete = RecurrentFlowCompleteNet(args.fc_model_path)
    for p in fix_flow_complete.parameters():
        p.requires_grad = False
    fix_flow_complete.to(device)
    fix_flow_complete.eval()

    total_frame_epe = []
    time_all = []

    print('Start evaluation...')
    # create results directory
    result_path = os.path.join('results_flow', f'{args.dataset}')
    if not os.path.exists(result_path):
        os.makedirs(result_path)

    eval_summary = open(os.path.join(result_path, f"{args.dataset}_metrics.txt"), "w")

    for index, items in enumerate(test_loader):
        frames, masks, flows_f, flows_b, video_name, frames_PIL = items
        local_masks = masks.float().to(device)

        video_length = frames.size(1)

        if args.load_flow:
            gt_flows_bi = (flows_f.to(device), flows_b.to(device))
        else:
            # compute RAFT flows over long videos in chunks of up to 60 frames
            short_len = 60
            if frames.size(1) > short_len:
                gt_flows_f_list, gt_flows_b_list = [], []
                for f in range(0, video_length, short_len):
                    end_f = min(video_length, f + short_len)
                    if f == 0:
                        flows_f, flows_b = fix_raft(frames[:, f:end_f], iters=args.raft_iter)
                    else:
                        flows_f, flows_b = fix_raft(frames[:, f - 1:end_f], iters=args.raft_iter)

                    gt_flows_f_list.append(flows_f)
                    gt_flows_b_list.append(flows_b)
                gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
                gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
                gt_flows_bi = (gt_flows_f, gt_flows_b)
            else:
                gt_flows_bi = fix_raft(frames, iters=20)
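        # Note (added for clarity): a RAFT call over K frames returns K - 1
        # frame-to-frame flows, so each chunk after the first starts one frame
        # early (f - 1) to also cover the transition across the chunk boundary.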

        torch.cuda.synchronize()
        time_start = time()

        # (disabled) chunked flow completion with a stride of 30 frames:
        # flow_length = flows_f.size(1)
        # f_stride = 30
        # pred_flows_f = []
        # pred_flows_b = []
        # suffix = flow_length % f_stride
        # last = flow_length // f_stride
        # for f in range(0, flow_length, f_stride):
        #     gt_flows_bi_i = (gt_flows_bi[0][:, f:f+f_stride], gt_flows_bi[1][:, f:f+f_stride])
        #     pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi_i, local_masks[:, f:f+f_stride+1])
        #     pred_flows_f_i, pred_flows_b_i = fix_flow_complete.combine_flow(gt_flows_bi_i, pred_flows_bi, local_masks[:, f:f+f_stride+1])
        #     pred_flows_f.append(pred_flows_f_i)
        #     pred_flows_b.append(pred_flows_b_i)
        # pred_flows_f = torch.cat(pred_flows_f, dim=1)
        # pred_flows_b = torch.cat(pred_flows_b, dim=1)
        # pred_flows_bi = (pred_flows_f, pred_flows_b)

        pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi, local_masks)
        pred_flows_bi = fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, local_masks)

        torch.cuda.synchronize()
        time_i = time() - time_start
        time_i = time_i * 1.0 / frames.size(1)

        time_all = time_all + [time_i] * frames.size(1)

        cur_video_epe = []

        # end-point error: mean L2 distance between GT and completed flow vectors
        epe1 = torch.mean(torch.sum((flows_f - pred_flows_bi[0].cpu())**2, dim=2).sqrt())
        epe2 = torch.mean(torch.sum((flows_b - pred_flows_bi[1].cpu())**2, dim=2).sqrt())

        cur_video_epe.append(epe1.numpy())
        cur_video_epe.append(epe2.numpy())

        total_frame_epe = total_frame_epe + [epe1.numpy()] * flows_f.size(1)
        total_frame_epe = total_frame_epe + [epe2.numpy()] * flows_f.size(1)

        cur_epe = sum(cur_video_epe) / len(cur_video_epe)
        avg_time = sum(time_all) / len(time_all)
        print(
            f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | EPE: {cur_epe:.4f} | Time: {avg_time:.4f}'
        )
        eval_summary.write(
            f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | EPE: {cur_epe:.4f} | Time: {avg_time:.4f}\n'
        )

        # save flow visualizations for evaluating warping errors
        if args.save_results:
            forward_flows = pred_flows_bi[0].cpu().permute(1, 0, 2, 3, 4)
            backward_flows = pred_flows_bi[1].cpu().permute(1, 0, 2, 3, 4)
            # forward_flows = flows_f.cpu().permute(1, 0, 2, 3, 4)
            # backward_flows = flows_b.cpu().permute(1, 0, 2, 3, 4)
            videoFlowF = list(forward_flows)
            videoFlowB = list(backward_flows)

            videoFlowF = tensor2np(videoFlowF)
            videoFlowB = tensor2np(videoFlowB)

            save_frame_path = os.path.join(result_path, video_name[0])
            save_flows(save_frame_path, videoFlowF, videoFlowB)

    avg_frame_epe = sum(total_frame_epe) / len(total_frame_epe)

    print(f'Finish evaluation... Average Frame EPE: {avg_frame_epe:.4f} | Time: {avg_time:.4f}')
    eval_summary.write(f'Finish evaluation... Average Frame EPE: {avg_frame_epe:.4f} | Time: {avg_time:.4f}\n')
    eval_summary.close()
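# A minimal standalone restatement of the metric computed above (added for
# clarity; not part of the original script). Flows are assumed to be shaped
# (B, T, 2, H, W), with dim=2 holding the two flow components.
def endpoint_error(gt_flow, pred_flow):
    """Mean end-point error: average per-pixel L2 distance between flows."""
    return torch.mean(torch.sum((gt_flow - pred_flow) ** 2, dim=2).sqrt())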

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--height', type=int, default=240)
@@ -195,3 +55,4 @@ def main_worker(args):
    parser.add_argument('--num_workers', default=4, type=int)
    args = parser.parse_args()
    main_worker(args)
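A typical invocation might look like the line below. This is only a sketch: the
argparse block is partially collapsed above, so apart from --height and
--num_workers the flag names are inferred from the attributes main_worker reads
(args.width, args.dataset, args.raft_model_path, args.fc_model_path,
args.save_results), and the dataset name and checkpoint paths are placeholders.

    python scripts/evaluate_flow_completion.py --dataset davis --height 240 --width 432 --raft_model_path weights/raft-things.pth --fc_model_path weights/recurrent_flow_completion.pth --save_results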
