Added gui support #31

Merged
merged 1 commit on Jul 29, 2022
added gui to ngp-pl
bolopenguin committed Jul 28, 2022
commit 08b0148647cc4ab199d61aaf2234d52db56d993c
1 change: 1 addition & 0 deletions .gitignore
@@ -3,6 +3,7 @@
ckpts/
logs/
results/
+output/
*.png
*.mp4

2 changes: 1 addition & 1 deletion datasets/ray_utils.py
@@ -39,7 +39,7 @@ def get_ray_directions(H, W, K, random=False, return_uv=False, flatten=True):

    if return_uv:
        return directions, grid
-    return directions
+    return directions.cuda()


@torch.cuda.amp.autocast(dtype=torch.float32)

3 changes: 2 additions & 1 deletion requirements.txt
@@ -9,4 +9,5 @@ imageio-ffmpeg
jupyter
scipy
pymcubes
-trimesh
+trimesh
+dearpygui

234 changes: 234 additions & 0 deletions show_gui.py
@@ -0,0 +1,234 @@
import torch
from opt import get_opts

import numpy as np

from einops import rearrange
import dearpygui.dearpygui as dpg
from scipy.spatial.transform import Rotation as R

from train import NeRFSystem

from datasets import dataset_dict
from datasets.ray_utils import get_ray_directions, get_rays

import warnings
from argparse import Namespace

warnings.filterwarnings("ignore")


class OrbitCamera:
    def __init__(self, W, H, r=5, fovy=50):
        self.W = W
        self.H = H
        self.radius = r  # camera distance from center
        self.fovy = fovy  # in degrees
        self.center = np.array([0, 0, 0], dtype=np.float32)
        self.rot = R.from_quat([0, 1, 0, 0])
        self.up = np.array([0, 1, 0], dtype=np.float32)

    # pose
    @property
    def pose(self):
        # first move camera to radius
        res = np.eye(4, dtype=np.float32)
        res[2, 3] -= self.radius
        # rotate
        rot = np.eye(4, dtype=np.float32)
        rot[:3, :3] = self.rot.as_matrix()
        res = rot @ res
        # translate
        res[:3, 3] -= self.center
        return res

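    # pinhole intrinsics from the vertical field of view:
    # tan(fovy / 2) = (H / 2) / focal, with the principal point at the image center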
    # intrinsics
    @property
    def intrinsics(self):
        focal = self.H / (2 * np.tan(np.radians(self.fovy) / 2))
        return np.array([focal, focal, self.W // 2, self.H // 2])

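    # a drag of (dx, dy) pixels becomes a small orbit: 0.1 degree per pixel
    # about the world-up and camera-side axes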
    def orbit(self, dx, dy):
        # rotate along camera up/side axis!
        side = self.rot.as_matrix()[:3, 0]
        rotvec_x = self.up * np.radians(-0.1 * dx)
        rotvec_y = side * np.radians(-0.1 * dy)
        self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot

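    # mouse wheel: each tick scales the orbit radius by a factor of 1.1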
    def scale(self, delta):
        self.radius *= 1.1 ** (-delta)

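    # middle-button drag: translate the orbit center in the current camera frame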
    def pan(self, dx, dy, dz=0):
        self.center += 0.001 * self.rot.as_matrix()[:3, :3] @ np.array([dx, dy, dz])

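# A rough sketch of how the GUI below drives this camera (illustrative values,
# not part of the committed code):
#   cam = OrbitCamera(800, 600, r=2.5, fovy=50)
#   cam.orbit(30, 0)   # 30 px horizontal drag -> ~3 degree orbit about the up axis
#   cam.scale(1)       # one wheel tick -> radius divided by 1.1 (zoom in)
#   c2w = cam.pose     # 4x4 camera-to-world matrix handed to the renderer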

class NeRFGUI:
    def __init__(self, renderer, H=1080, W=1440, radius=2.5, fovy=50):
        self.renderer = renderer
        self.H = H
        self.W = W
        self.radius = radius
        self.fovy = fovy

        self.cam = OrbitCamera(self.W, self.H, r=self.radius, fovy=self.fovy)

        self.render_buffer = np.ones((self.W, self.H, 3), dtype=np.float32)

        dpg.create_context()
        self.register_dpg()

    def __del__(self):
        dpg.destroy_context()

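    # render one frame at the current camera pose and write it straight into
    # the registered texture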
    def render_nerf(self):
        dpg.set_value("_texture", self.renderer.render_one_pose(self.cam.pose))

    def register_dpg(self):

        ## register texture ##
        with dpg.texture_registry(show=False):
            dpg.add_raw_texture(
                self.W,
                self.H,
                self.render_buffer,
                format=dpg.mvFormat_Float_rgb,
                tag="_texture",
            )

        ## register window ##
        # the rendered image, as the primary window
        with dpg.window(tag="_primary_window", width=self.W, height=self.H):
            dpg.add_image("_texture")
        dpg.set_primary_window("_primary_window", True)

        ## control window ##
        with dpg.window(label="Control", tag="_control_window", width=500, height=150):
            # Pose info
            with dpg.collapsing_header(label="Info", default_open=True):
                # pose
                dpg.add_separator()
                dpg.add_text("Camera Pose:")
                dpg.add_text(str(self.cam.pose), tag="_log_pose")

        ## register camera handler ##
        def callback_camera_drag_rotate(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return
            dx = app_data[1]
            dy = app_data[2]
            self.cam.orbit(dx, dy)
            self.need_update = True
            dpg.set_value("_log_pose", str(self.cam.pose))

        def callback_camera_wheel_scale(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return
            delta = app_data
            self.cam.scale(delta)
            self.need_update = True
            dpg.set_value("_log_pose", str(self.cam.pose))

        def callback_camera_drag_pan(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return
            dx = app_data[1]
            dy = app_data[2]
            self.cam.pan(dx, dy)
            self.need_update = True
            dpg.set_value("_log_pose", str(self.cam.pose))

        with dpg.handler_registry():
            dpg.add_mouse_drag_handler(
                button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate
            )
            dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
            dpg.add_mouse_drag_handler(
                button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan
            )

        ## Window name ##
        dpg.create_viewport(
            title="ngp-pl", width=self.W, height=self.H, resizable=False
        )

        ## Avoid scroll bar in the window ##
        with dpg.theme() as theme_no_padding:
            with dpg.theme_component(dpg.mvAll):
                dpg.add_theme_style(
                    dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core
                )
                dpg.add_theme_style(
                    dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core
                )
                dpg.add_theme_style(
                    dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core
                )
        dpg.bind_item_theme("_primary_window", theme_no_padding)

        ## Launch the gui ##
        dpg.setup_dearpygui()
        dpg.show_viewport()

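    # main loop: re-render the NeRF every frame until the viewport is closed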
    def render(self):

        while dpg.is_dearpygui_running():
            self.render_nerf()
            dpg.render_dearpygui_frame()


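# Thin wrapper around a trained NeRFSystem checkpoint: it restores the model
# weights, precomputes per-pixel ray directions, and renders one RGB frame per
# requested camera pose.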
class RenderGui:
    def __init__(self, ckpt_path, intrinsics, H, W, shift, scale) -> None:
        self.ckp = torch.load(ckpt_path)
        self.intrinsics = intrinsics
        self.H = H
        self.W = W
        self.shift = shift
        self.scale = scale

        del self.ckp["state_dict"]["poses"]
        del self.ckp["state_dict"]["directions"]
        self.ckp["hyper_parameters"]["eval_lpips"] = False

        # Load checkpoint
        self.system = NeRFSystem(Namespace(**self.ckp["hyper_parameters"])).cuda()
        self.system.load_state_dict(self.ckp["state_dict"])

        # Ray directions
        self.directions = get_ray_directions(self.H, self.W, self.intrinsics)

    def render_one_pose(self, pose):
        rays_o, rays_d = self.get_rays(pose)
        results = self.system(rays_o, rays_d, split="render")

        rgb_pred = rearrange(results["rgb"].cpu().numpy(), "(h w) c -> h w c", h=self.H)

        return rgb_pred

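    # map the GUI pose into the dataset's normalized frame (the same shift and
    # scale applied to the training poses) before generating rays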
    def get_rays(self, pose):
        pose = pose[:3]
        pose[:, 3] -= self.shift
        pose[:, 3] /= self.scale

        rays_o, rays_d = get_rays(self.directions, torch.cuda.FloatTensor(pose))

        return rays_o, rays_d


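# Entry point: the validation dataset is instantiated only to recover the
# scene's intrinsics, image size and pose normalization (shift/scale).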
if __name__ == "__main__":
    hparams = get_opts()
    dataset = dataset_dict[hparams.dataset_name]
    kwargs = {
        "root_dir": hparams.root_dir,
        "downsample": hparams.downsample,
    }
    dataset = dataset(split="val", **kwargs)

    shift = dataset.shift
    scale = dataset.scale

    intrinsics = dataset.K
    w, h = dataset.img_wh

    render_gui = RenderGui(hparams.ckpt_path, intrinsics, h, w, shift, scale)
    gui = NeRFGUI(render_gui, h, w)
    gui.render()
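A typical launch, assuming get_opts exposes the hyperparameters used above as
command-line flags (as elsewhere in ngp-pl):

    python show_gui.py --root_dir <path/to/scene> --dataset_name <name> \
        --ckpt_path <path/to/checkpoint.ckpt>
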
42 changes: 21 additions & 21 deletions train.py
@@ -61,12 +61,12 @@ def __init__(self, hparams):
        self.train_psnr = PeakSignalNoiseRatio(data_range=1)
        self.val_psnr = PeakSignalNoiseRatio(data_range=1)
        self.val_ssim = StructuralSimilarityIndexMeasure(data_range=1)
-        if hparams.eval_lpips:
+        if self.hparams.eval_lpips:
            self.val_lpips = LearnedPerceptualImagePatchSimilarity('vgg')
            for p in self.val_lpips.net.parameters():
                p.requires_grad = False

-        self.model = NGP(scale=hparams.scale)
+        self.model = NGP(scale=self.hparams.scale)
        G = self.model.grid_size
        self.model.register_buffer('density_grid',
                                   torch.zeros(self.model.cascades, G**3))
@@ -75,17 +75,17 @@

    def forward(self, rays_o, rays_d, split):
        kwargs = {'test_time': split!='train'}
-        if hparams.dataset_name in ['colmap', 'nerfpp']:
+        if self.hparams.dataset_name in ['colmap', 'nerfpp']:
            kwargs['exp_step_factor'] = 1/256

        return render(self.model, rays_o, rays_d, **kwargs)

    def setup(self, stage):
-        dataset = dataset_dict[hparams.dataset_name]
-        kwargs = {'root_dir': hparams.root_dir,
-                  'downsample': hparams.downsample}
-        self.train_dataset = dataset(split=hparams.split, **kwargs)
-        self.train_dataset.batch_size = hparams.batch_size
+        dataset = dataset_dict[self.hparams.dataset_name]
+        kwargs = {'root_dir': self.hparams.root_dir,
+                  'downsample': self.hparams.downsample}
+        self.train_dataset = dataset(split=self.hparams.split, **kwargs)
+        self.train_dataset.batch_size = self.hparams.batch_size

        self.test_dataset = dataset(split='test', **kwargs)

@@ -94,7 +94,7 @@ def configure_optimizers(self):
        self.register_buffer('directions', self.train_dataset.directions.to(self.device))
        self.register_buffer('poses', self.train_dataset.poses.to(self.device))

-        if hparams.optimize_ext:
+        if self.hparams.optimize_ext:
            N = len(self.train_dataset.poses)
            self.register_parameter('dR',
                nn.Parameter(torch.zeros(N, 3, device=self.device)))
@@ -106,16 +106,16 @@
            if n not in ['dR', 'dT']: net_params += [p]

        opts = []
-        self.net_opt = FusedAdam(net_params, hparams.lr, eps=1e-15)
+        self.net_opt = FusedAdam(net_params, self.hparams.lr, eps=1e-15)
        opts += [self.net_opt]
-        if hparams.optimize_ext:
+        if self.hparams.optimize_ext:
            # learning rate is hard-coded
            pose_r_opt = FusedAdam([self.dR], 1e-6)
            pose_t_opt = FusedAdam([self.dT], 1e-6)
            opts += [pose_r_opt, pose_t_opt]
        net_sch = CosineAnnealingLR(self.net_opt,
-                                    hparams.num_epochs,
-                                    hparams.lr/30)
+                                    self.hparams.num_epochs,
+                                    self.hparams.lr/30)

        return opts, [net_sch]

@@ -141,10 +141,10 @@ def training_step(self, batch, batch_nb, *args):
        if self.global_step%self.S == 0:
            self.model.update_density_grid(0.01*MAX_SAMPLES/3**0.5,
                                           warmup=self.global_step<256,
-                                           erode=hparams.dataset_name=='colmap')
+                                           erode=self.hparams.dataset_name=='colmap')

        poses = self.poses[batch['img_idxs']] # (B, 3, 4)
-        if hparams.optimize_ext:
+        if self.hparams.optimize_ext:
            dR = axisangle_to_R(self.dR[batch['img_idxs']]) # (B, 3, 3)
            poses[..., :3] = dR @ poses[..., :3]
            dT = self.dT[batch['img_idxs']] # (B, 3)
@@ -167,13 +167,13 @@ def training_step(self, batch, batch_nb, *args):

    def on_validation_start(self):
        torch.cuda.empty_cache()
-        if not hparams.no_save_test:
-            self.val_dir = f'results/{hparams.dataset_name}/{hparams.exp_name}'
+        if not self.hparams.no_save_test:
+            self.val_dir = f'results/{self.hparams.dataset_name}/{self.hparams.exp_name}'
            os.makedirs(self.val_dir, exist_ok=True)

    def validation_step(self, batch, batch_nb):
        rgb_gt = batch['rgb']
-        if hparams.optimize_ext:
+        if self.hparams.optimize_ext:
            dR = axisangle_to_R(self.dR[batch['img_idxs']]) # (B, 3, 3)
            batch['pose'][..., :3] = dR @ batch['pose'][..., :3]
            dT = self.dT[batch['img_idxs']] # (3)
@@ -193,13 +193,13 @@ def validation_step(self, batch, batch_nb):
        self.val_ssim(rgb_pred, rgb_gt)
        logs['ssim'] = self.val_ssim.compute()
        self.val_ssim.reset()
-        if hparams.eval_lpips:
+        if self.hparams.eval_lpips:
            self.val_lpips(torch.clip(rgb_pred*2-1, -1, 1),
                           torch.clip(rgb_gt*2-1, -1, 1))
            logs['lpips'] = self.val_lpips.compute()
            self.val_lpips.reset()

-        if not hparams.no_save_test: # save test image to disk
+        if not self.hparams.no_save_test: # save test image to disk
            idx = batch['img_idxs']
            rgb_pred = rearrange(results['rgb'].cpu().numpy(), '(h w) c -> h w c', h=h)
            rgb_pred = (rgb_pred*255).astype(np.uint8)
@@ -218,7 +218,7 @@ def validation_epoch_end(self, outputs):
        mean_ssim = all_gather_ddp_if_available(ssims).mean()
        self.log('test/ssim', mean_ssim)

-        if hparams.eval_lpips:
+        if self.hparams.eval_lpips:
            lpipss = torch.stack([x['lpips'] for x in outputs])
            mean_lpips = all_gather_ddp_if_available(lpipss).mean()
            self.log('test/lpips_vgg', mean_lpips)
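For context: PyTorch Lightning stores the arguments passed to
save_hyperparameters() on self.hparams and also serializes them into the
checkpoint's "hyper_parameters" entry, which is exactly what RenderGui in
show_gui.py reads back. A minimal sketch of the pattern (assuming NeRFSystem
calls save_hyperparameters, as this diff implies):

    from pytorch_lightning import LightningModule

    class NeRFSystem(LightningModule):
        def __init__(self, hparams):
            super().__init__()
            self.save_hyperparameters(hparams)  # now available as self.hparams

Referencing self.hparams instead of the module-level hparams is what lets
NeRFSystem be constructed from a checkpoint outside of train.py.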