-
Notifications
You must be signed in to change notification settings - Fork 121
/
run.py
218 lines (182 loc) · 6.79 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
# Script for inference on (in-the-wild) images
# Author: Bingxin Ke
# Last modified: 2023-12-15
import argparse
import os
from glob import glob
import logging
import numpy as np
import torch
from PIL import Image
from tqdm.auto import tqdm
from marigold import MarigoldPipeline
from marigold.util.seed_all import seed_all
# File suffixes accepted as input images (compared against the lower-cased suffix).
EXTENSION_LIST = [".jpg", ".jpeg", ".png"]


if "__main__" == __name__:
    # Command-line entry point: run the Marigold pipeline on every image found
    # in --input_rgb_dir and write, per image, a .npy array, a 16-bit PNG and a
    # colorized PNG below --output_dir.
    logging.basicConfig(level=logging.INFO)

    # -------------------- Arguments --------------------
    parser = argparse.ArgumentParser(
        description="Run single-image depth estimation using Marigold."
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="Bingxin/Marigold",
        help="Checkpoint path or hub name.",
    )
    parser.add_argument(
        "--input_rgb_dir",
        type=str,
        required=True,
        help="Path to the input image folder.",
    )
    parser.add_argument(
        "--output_dir", type=str, required=True, help="Output directory."
    )

    # inference setting
    parser.add_argument(
        "--denoise_steps",
        type=int,
        default=10,
        help="Diffusion denoising steps, more stepts results in higher accuracy but slower inference speed.",
    )
    parser.add_argument(
        "--ensemble_size",
        type=int,
        default=10,
        help="Number of predictions to be ensembled, more inference gives better results but runs slower.",
    )

    # resolution setting
    parser.add_argument(
        "--processing_res",
        type=int,
        default=768,
        help="Maximum resolution of processing. 0 for using input image resolution. Default: 768.",
    )
    parser.add_argument(
        "--output_processing_res",
        action="store_true",
        help="When input is resized, out put depth at resized operating resolution. Default: False.",
    )

    # depth map colormap
    parser.add_argument(
        "--color_map",
        type=str,
        default="Spectral",
        help="Colormap used to render depth predictions.",
    )

    # other settings
    parser.add_argument("--seed", type=int, default=None, help="Random seed.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=0,
        help="Inference batch size. Default: 0 (will be set automatically).",
    )
    parser.add_argument(
        "--apple_silicon",
        action="store_true",
        help="Flag of running on Apple Silicon.",
    )

    args = parser.parse_args()

    checkpoint_path = args.checkpoint
    input_rgb_dir = args.input_rgb_dir
    output_dir = args.output_dir

    denoise_steps = args.denoise_steps
    ensemble_size = args.ensemble_size
    processing_res = args.processing_res
    # The CLI flag is the inverse of the pipeline argument: by default the
    # output is matched to the input resolution.
    match_input_res = not args.output_processing_res

    color_map = args.color_map
    seed = args.seed
    batch_size = args.batch_size
    apple_silicon = args.apple_silicon
    if apple_silicon and 0 == batch_size:
        batch_size = 1  # set default batchsize

    # -------------------- Preparation --------------------
    # Random seed: fall back to the current wall-clock time when none is given.
    if seed is None:
        import time

        seed = int(time.time())
    seed_all(seed)

    # Output directories (one sub-folder per output format)
    output_dir_color = os.path.join(output_dir, "depth_colored")
    output_dir_tif = os.path.join(output_dir, "depth_bw")
    output_dir_npy = os.path.join(output_dir, "depth_npy")
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(output_dir_color, exist_ok=True)
    os.makedirs(output_dir_tif, exist_ok=True)
    os.makedirs(output_dir_npy, exist_ok=True)
    logging.info(f"output dir: {output_dir}")

    # Device: MPS when --apple_silicon is set, otherwise CUDA; CPU as fallback.
    if apple_silicon:
        if torch.backends.mps.is_available() and torch.backends.mps.is_built():
            device = torch.device("mps:0")
        else:
            device = torch.device("cpu")
            logging.warning("MPS is not available. Running on CPU will be slow.")
    else:
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
            logging.warning("CUDA is not available. Running on CPU will be slow.")
    logging.info(f"device: {device}")

    # -------------------- Data --------------------
    # Non-recursive listing of the input folder, filtered by suffix and sorted
    # so that the processing order is deterministic.
    rgb_filename_list = glob(os.path.join(input_rgb_dir, "*"))
    rgb_filename_list = [
        f for f in rgb_filename_list if os.path.splitext(f)[1].lower() in EXTENSION_LIST
    ]
    rgb_filename_list = sorted(rgb_filename_list)
    n_images = len(rgb_filename_list)
    if n_images > 0:
        logging.info(f"Found {n_images} images")
    else:
        logging.error(f"No image found in '{input_rgb_dir}'")
        # Same exit status as before, without relying on the `site`-provided
        # `exit()` helper.
        raise SystemExit(1)

    # -------------------- Model --------------------
    pipe = MarigoldPipeline.from_pretrained(checkpoint_path)

    # xformers is an optional speed-up. Stay best-effort, but do not swallow
    # KeyboardInterrupt/SystemExit and leave a trace of why it was skipped.
    try:
        import xformers  # noqa: F401  (imported only to probe availability)

        pipe.enable_xformers_memory_efficient_attention()
    except Exception as exc:
        logging.info(f"xformers is not enabled, running without it: {exc}")

    pipe = pipe.to(device)

    # -------------------- Inference and saving --------------------
    with torch.no_grad():
        for rgb_path in tqdm(rgb_filename_list, desc="Estimating depth", leave=True):
            # Read input image; the context manager releases the file handle
            # as soon as the prediction for this image is done.
            with Image.open(rgb_path) as input_image:
                # Predict depth
                pipe_out = pipe(
                    input_image,
                    denoising_steps=denoise_steps,
                    ensemble_size=ensemble_size,
                    processing_res=processing_res,
                    match_input_res=match_input_res,
                    batch_size=batch_size,
                    color_map=color_map,
                    show_progress_bar=True,
                )

            depth_pred: np.ndarray = pipe_out.depth_np
            depth_colored: Image.Image = pipe_out.depth_colored

            # Save as npy
            rgb_name_base = os.path.splitext(os.path.basename(rgb_path))[0]
            pred_name_base = rgb_name_base + "_pred"
            npy_save_path = os.path.join(output_dir_npy, f"{pred_name_base}.npy")
            if os.path.exists(npy_save_path):
                logging.warning(f"Existing file: '{npy_save_path}' will be overwritten")
            np.save(npy_save_path, depth_pred)

            # Save as 16-bit uint png
            # NOTE(review): the scaling assumes depth_pred lies in [0, 1] - confirm
            # against the pipeline's output contract.
            depth_to_save = (depth_pred * 65535.0).astype(np.uint16)
            png_save_path = os.path.join(output_dir_tif, f"{pred_name_base}.png")
            if os.path.exists(png_save_path):
                logging.warning(f"Existing file: '{png_save_path}' will be overwritten")
            Image.fromarray(depth_to_save).save(png_save_path, mode="I;16")

            # Colorize
            colored_save_path = os.path.join(
                output_dir_color, f"{pred_name_base}_colored.png"
            )
            if os.path.exists(colored_save_path):
                logging.warning(f"Existing file: '{colored_save_path}' will be overwritten")
            depth_colored.save(colored_save_path)