toshas committed on
Commit 3151687 · 1 Parent(s): 555f93d

allow inference without clicks

add a nice example with click measurements
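Concretely, `process()` in app.py no longer raises "At least two valid measurements are required": when neither clicked points nor an uploaded sparse depth file yield valid measurements, guidance is simply disabled and the run proceeds. A condensed, runnable sketch of the accompanying visualization-range behavior; the helper name and the numeric values here are illustrative, not taken from the diff:

    # Sketch of the new color-range logic in process(): the range always comes from
    # the prediction and is only tightened by measurement bounds when guidance exists.
    def depth_vis_range(pred_min, pred_max, sparse_bounds=None):
        min_both, max_both = pred_min, pred_max
        if sparse_bounds is not None:
            sparse_min, sparse_max = sparse_bounds
            min_both = min(sparse_min, min_both)
            max_both = min(sparse_max, max_both)  # mirrors the diff, which also uses min() for the upper bound
        return min_both, max_both

    print(depth_vis_range(0.3, 9.7))              # no measurements: (0.3, 9.7)
    print(depth_vis_range(0.3, 9.7, (2.0, 3.0)))  # with measurements: (0.3, 3.0)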

Files changed (2)
  1. app.py +72 -31
  2. marigold_dc.py +15 -11
app.py CHANGED
@@ -16,11 +16,8 @@
 # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
 # More information about the method can be found at https://marigoldmonodepth.github.io
 # --------------------------------------------------------------------------
-# TODO: min 2 measurements exception does not work now
 # TODO: 16bit depth map download
 # TODO: change to gradio-dualvision (update it with the Examples thumbs first)
-# TODO: good examples where measurements help
-# TODO: examples for measurements with points saved

 import os
 import PIL
@@ -42,7 +39,7 @@ DEFAULT_denoise_steps = 10
 DEFAULT_lr_latent = 0.05
 DEFAULT_lr_scale_shift = 0.005

-
+TILE_CHAR = "██"
 TAB10_COLORS = [
     (31, 119, 180), # blue
     (255, 127, 14), # orange
@@ -72,15 +69,15 @@ def get_wrapped_color(index):
     return adjust_brightness(base_color, factor)


-def on_click(img: Image.Image, state_orig_img: gr.State, evt: gr.SelectData, table):
+def process_click_data(img: Image.Image, state_orig_img: gr.State, table, x: int, y: int, value: str = ""):
     if isinstance(img, str):
         img = Image.open(img)
     if state_orig_img is None:
         state_orig_img = img.copy()
-    x, y = evt.index
+    if isinstance(table, pandas.DataFrame):
+        table = table.values.tolist()
     color = get_wrapped_color(len(table))
     color_hex = '#%02x%02x%02x' % color
-    tile_char = "██"

     img = img.convert("RGB")
     draw = ImageDraw.Draw(img)
@@ -89,7 +86,15 @@ def on_click(img: Image.Image, state_orig_img: gr.State, evt: gr.SelectData, table):
     draw.ellipse((x - r, y - r, x + r, y + r), fill=color, outline=color)
     draw.ellipse((x - r, y - r, x + r, y + r), fill=None, outline=(255, 255, 255), width=max(1, r//4))

-    table = table.values.tolist() + [[tile_char, "", x, y, color_hex]]
+    if not isinstance(table, list):
+        table = table.values.tolist()
+    table = table + [[TILE_CHAR, value, x, y, color_hex]]
+    return img, state_orig_img, table
+
+
+def on_click(img: Image.Image, state_orig_img: gr.State, evt: gr.SelectData, table):
+    x, y = evt.index
+    img, state_orig_img, table = process_click_data(img, state_orig_img, table, x, y)
     return img, state_orig_img, gr.Dataframe(table, visible=True)


@@ -203,9 +208,15 @@ def process(
             sparse_depth[~sparse_depth_valid_mask] = 0
             kernel_size = 10
         else:
-            raise ValueError("At least two valid measurements are required")
+            sparse_depth = None
+            sparse_depth_min = 0
+            sparse_depth_max = 1
+            kernel_size = 5
     else:
-        raise ValueError("At least two valid measurements are required")
+        sparse_depth = None
+        sparse_depth_min = 0
+        sparse_depth_max = 1
+        kernel_size = 5

     width, height = image.size
     max_dim = max(width, height)
@@ -230,17 +241,23 @@ def process(
             dry_run=DRY_RUN,
         )
     ):
-        min_both = min(sparse_depth_min, pred.min().item())
-        max_both = min(sparse_depth_max, pred.max().item())
+        min_both = pred.min().item()
+        max_both = pred.max().item()
+        if sparse_depth is not None:
+            min_both = min(sparse_depth_min, min_both)
+            max_both = min(sparse_depth_max, max_both)
         metrics.append(rmse)
         steps.append(step)

         vis_pred = pipe.image_processor.visualize_depth(pred, val_min=min_both, val_max=max_both)[0]

-        vis_sparse = pipe.image_processor.visualize_depth(sparse_depth, val_min=min_both, val_max=max_both)[0]
-        vis_sparse = np.array(vis_sparse)
-        vis_sparse[sparse_depth <= 0] = (0, 0, 0)
-        vis_sparse = dilate_rgb_image(vis_sparse, kernel_size=kernel_size)
+        if sparse_depth is not None:
+            vis_sparse = pipe.image_processor.visualize_depth(sparse_depth, val_min=min_both, val_max=max_both)[0]
+            vis_sparse = np.array(vis_sparse)
+            vis_sparse[sparse_depth <= 0] = (0, 0, 0)
+            vis_sparse = dilate_rgb_image(vis_sparse, kernel_size=kernel_size)
+        else:
+            vis_sparse = np.full_like(vis_pred, 0)
         vis_sparse = Image.fromarray(vis_sparse)

         plot = generate_rmse_plot(steps, metrics, denoise_steps)
@@ -366,7 +383,7 @@ with gr.Blocks(
             visible=False,
         )
         input_image = gr.Image(
-            label="Input image",
+            label="Input image (click to enter depth)",
            type="filepath",
            interactive=True,
        )
@@ -501,30 +518,52 @@ with gr.Blocks(
        )

    def examples_depth_lidar_fn(path_thumb):
-        real_url = lambda fname: f"https://huggingface.co/spaces/prs-eth/marigold-dc/resolve/main/files/{fname}"
+        real_url = lambda fname: f"https://huggingface.co/spaces/obukhovai/marigold-dc-metric/resolve/main/files/{fname}"
        l_thumb = os.path.basename(path_thumb)
        d_thumb = os.path.dirname(path_thumb)
-        l_image, l_sparse = {
-            "thumb_kitti_1.jpg": ["kitti_1.png", "kitti_1.npy"],
-            "thumb_kitti_2.jpg": ["kitti_2.png", "kitti_2.npy"],
-            "thumb_teaser_10.jpg": ["teaser.png", "teaser_10.npy"],
-            "thumb_teaser_100.jpg": ["teaser.png", "teaser_100.npy"],
-            "thumb_teaser_1000.jpg": ["teaser.png", "teaser_1000.npy"],
+        l_image, l_sparse, clicks = {
+            "thumb_matterhorn_clicks.jpg": ["matterhorn.png", None, [
+                [TILE_CHAR, "2", 495, 1573, '#%02x%02x%02x' % get_wrapped_color(0)],
+                [TILE_CHAR, "3", 1062, 1550, '#%02x%02x%02x' % get_wrapped_color(1)],
+            ]],
+            "thumb_kitti_1.jpg": ["kitti_1.png", "kitti_1.npy", []],
+            "thumb_kitti_2.jpg": ["kitti_2.png", "kitti_2.npy", []],
+            "thumb_teaser_10.jpg": ["teaser.png", "teaser_10.npy", []],
+            "thumb_teaser_100.jpg": ["teaser.png", "teaser_100.npy", []],
+            "thumb_teaser_1000.jpg": ["teaser.png", "teaser_1000.npy", []],
        }[l_thumb]
-        u_image, u_sparse = real_url(l_image), real_url(l_sparse)
+
+        u_image = real_url(l_image)
        l_down_image = os.path.join(d_thumb, l_image)
-        l_down_sparse = os.path.join(d_thumb, l_sparse)
-        for url, down_path in ((u_image, l_down_image), (u_sparse, l_down_sparse)):
-            response = requests.get(url)
+        response = requests.get(u_image)
+        response.raise_for_status()
+        with open(l_down_image, "wb") as f:
+            f.write(response.content)
+
+        table_visible = len(clicks) > 0
+        l_down_sparse = None
+        if l_sparse is not None:
+            u_sparse = real_url(l_sparse)
+            l_down_sparse = os.path.join(d_thumb, l_sparse)
+            response = requests.get(u_sparse)
            response.raise_for_status()
-            with open(down_path, "wb") as f:
+            with open(l_down_sparse, "wb") as f:
                f.write(response.content)
-        for outputs in process(l_down_image, None, [], l_down_sparse):
-            yield l_down_image, l_down_sparse, *outputs
+
+        state_orig_img = None
+        table = []
+        if len(clicks) > 0:
+            for click in clicks:
+                _, value, x, y, _ = click
+                l_down_image, state_orig_img, table = process_click_data(l_down_image, state_orig_img, table, x, y, value)
+
+        for outputs in process(l_down_image, state_orig_img, clicks, l_down_sparse):
+            yield l_down_image, l_down_sparse, state_orig_img, gr.Dataframe(table, visible=table_visible), *outputs

    examples = gr.Examples(
        fn=examples_depth_lidar_fn,
        examples=[
+            "files/thumb_matterhorn_clicks.jpg",
            "files/thumb_kitti_1.jpg",
            "files/thumb_kitti_2.jpg",
            "files/thumb_teaser_10.jpg",
@@ -537,6 +576,8 @@ with gr.Blocks(
        outputs=[
            input_image,
            input_sparse,
+            state_orig_img,
+            table,
            output_slider,
            plot,
        ],
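The Matterhorn example pre-seeds the clicks table with the same row layout that process_click_data appends for interactive clicks: [TILE_CHAR, value, x, y, color_hex]. A self-contained sketch of how such rows are assembled; get_wrapped_color is simplified here to plain palette cycling (the real helper in app.py also adjusts brightness when the palette wraps), and make_click_row is a hypothetical helper for illustration:

    TILE_CHAR = "██"
    TAB10_COLORS = [
        (31, 119, 180),  # blue
        (255, 127, 14),  # orange
    ]

    def get_wrapped_color(index):
        # Simplified: app.py additionally adjusts brightness on palette wrap-around.
        return TAB10_COLORS[index % len(TAB10_COLORS)]

    def make_click_row(table, x, y, value=""):
        # One row per click, as appended by process_click_data: color swatch,
        # measurement value, pixel coordinates, and the marker color as hex.
        color_hex = '#%02x%02x%02x' % get_wrapped_color(len(table))
        return [TILE_CHAR, value, x, y, color_hex]

    table = []
    table.append(make_click_row(table, 495, 1573, "2"))   # first Matterhorn click
    table.append(make_click_row(table, 1062, 1550, "3"))  # second Matterhorn click
    print(table)  # [['██', '2', 495, 1573, '#1f77b4'], ['██', '3', 1062, 1550, '#ff7f0e']]

The coordinates and values match the rows hard-coded for thumb_matterhorn_clicks.jpg in examples_depth_lidar_fn; the exact hex colors depend on get_wrapped_color's brightness adjustment.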
marigold_dc.py CHANGED
@@ -80,6 +80,8 @@ class MarigoldDepthCompletionPipeline(MarigoldDepthPipeline):
         sparse_depth = sparse_depth.to(device)
         sparse_mask = sparse_depth > 0
         sparse_depth = sparse_depth[sparse_mask]
+        sparse_depth_min = sparse_depth.min() if sparse_depth.numel() > 0 else 0
+        sparse_depth_max = sparse_depth.max() if sparse_depth.numel() > 0 else 1

         # Set up optimization targets
         pred_latent = torch.nn.Parameter(pred_latent, requires_grad=True)
@@ -89,14 +91,14 @@ class MarigoldDepthCompletionPipeline(MarigoldDepthPipeline):
             sparse_range = 1.0
         else:
             scale = torch.nn.Parameter(torch.ones(1, device=device), requires_grad=True)
-            sparse_range = (sparse_depth.max() - sparse_depth.min()).item()
+            sparse_range = (sparse_depth_max - sparse_depth_min).item()

         if override_shift:
             shift = np.sqrt(override_shift)
             sparse_lower = 1.0
         else:
             shift = torch.nn.Parameter(torch.ones(1, device=device), requires_grad=True)
-            sparse_lower = (sparse_depth.min()).item()
+            sparse_lower = (sparse_depth_min).item()

         def affine_to_metric(depth):
             return (scale**2) * sparse_range * depth + (shift**2) * sparse_lower
@@ -156,17 +158,19 @@ class MarigoldDepthCompletionPipeline(MarigoldDepthPipeline):
             # Preview the final output depth, compute loss with guidance, backprop
             pred_original_sample = step_output.pred_original_sample
             current_metric_estimate = latent_to_metric(pred_original_sample)
-            loss, rmse = loss_l1l2(current_metric_estimate[sparse_mask], sparse_depth)
-            loss.backward()

-            # Scale gradients up
-            with torch.no_grad():
-                pred_epsilon_norm = torch.linalg.norm(pred_epsilon).item()
-                depth_latent_grad_norm = torch.linalg.norm(pred_latent.grad).item()
-                scaling_factor = pred_epsilon_norm / max(depth_latent_grad_norm, 1e-8)
-                pred_latent.grad *= scaling_factor
+            if sparse_depth.numel() > 0:
+                loss, rmse = loss_l1l2(current_metric_estimate[sparse_mask], sparse_depth)
+                loss.backward()
+
+                # Scale gradients up
+                with torch.no_grad():
+                    pred_epsilon_norm = torch.linalg.norm(pred_epsilon).item()
+                    depth_latent_grad_norm = torch.linalg.norm(pred_latent.grad).item()
+                    scaling_factor = pred_epsilon_norm / max(depth_latent_grad_norm, 1e-8)
+                    pred_latent.grad *= scaling_factor

-            optimizer.step()
+                optimizer.step()

             with torch.no_grad():
                 pred_latent.data = self.scheduler.step(noise, t, pred_latent, generator=generator).prev_sample
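For reference, the affine-to-metric mapping that these guards feed into anchors a squared scale and shift to the measurement range: metric = scale² · sparse_range · depth + shift² · sparse_lower. A minimal standalone sketch with two measurement values (2 and 3, as in the Matterhorn example); the tensor contents and the print call are illustrative only, not part of the pipeline:

    import torch

    sparse_depth = torch.tensor([2.0, 3.0])  # two valid measurements

    # Same guards as in the diff; the fallbacks only matter when no measurements exist.
    sparse_depth_min = sparse_depth.min() if sparse_depth.numel() > 0 else 0
    sparse_depth_max = sparse_depth.max() if sparse_depth.numel() > 0 else 1

    scale = torch.ones(1)  # optimized as torch.nn.Parameter in the pipeline
    shift = torch.ones(1)
    sparse_range = (sparse_depth_max - sparse_depth_min).item()  # 1.0
    sparse_lower = sparse_depth_min.item()                       # 2.0

    def affine_to_metric(depth):
        # Squaring keeps the learned scale and shift non-negative.
        return (scale**2) * sparse_range * depth + (shift**2) * sparse_lower

    print(affine_to_metric(torch.tensor([0.0, 0.5, 1.0])))  # tensor([2.0000, 2.5000, 3.0000])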