toshas commited on
Commit
0d00040
·
1 Parent(s): 88964f9

fix sparse_depth is None usecase

Browse files
Files changed (1) hide show
  1. marigold_dc.py +13 -9
marigold_dc.py CHANGED
@@ -37,7 +37,7 @@ class MarigoldDepthCompletionPipeline(MarigoldDepthPipeline):
37
  # Check inputs.
38
  if num_inference_steps is None:
39
  raise ValueError("Invalid num_inference_steps")
40
- if type(sparse_depth) is not np.ndarray or sparse_depth.ndim != 2:
41
  raise ValueError(
42
  "Sparse depth should be a 2D numpy ndarray with zeros at missing positions"
43
  )
@@ -64,7 +64,7 @@ class MarigoldDepthCompletionPipeline(MarigoldDepthPipeline):
64
  dtype=self.dtype,
65
  ) # [N,3,PPH,PPW]
66
 
67
- if sparse_depth.shape != original_resolution:
68
  raise ValueError(
69
  f"Sparse depth dimensions ({sparse_depth.shape}) must match that of the image ({image.shape[-2:]})"
70
  )
@@ -76,12 +76,16 @@ class MarigoldDepthCompletionPipeline(MarigoldDepthPipeline):
76
  del image
77
 
78
  # Preprocess sparse depth
79
- sparse_depth = torch.from_numpy(sparse_depth)[None, None].float()
80
- sparse_depth = sparse_depth.to(device)
81
- sparse_mask = sparse_depth > 0
82
- sparse_depth = sparse_depth[sparse_mask]
83
- sparse_depth_min = sparse_depth.min() if sparse_depth.numel() > 0 else 0
84
- sparse_depth_max = sparse_depth.max() if sparse_depth.numel() > 0 else 1
 
 
 
 
85
 
86
  # Set up optimization targets
87
  pred_latent = torch.nn.Parameter(pred_latent, requires_grad=True)
@@ -159,7 +163,7 @@ class MarigoldDepthCompletionPipeline(MarigoldDepthPipeline):
159
  pred_original_sample = step_output.pred_original_sample
160
  current_metric_estimate = latent_to_metric(pred_original_sample)
161
 
162
- if sparse_depth.numel() > 0:
163
  loss, rmse = loss_l1l2(current_metric_estimate[sparse_mask], sparse_depth)
164
  loss.backward()
165
 
 
37
  # Check inputs.
38
  if num_inference_steps is None:
39
  raise ValueError("Invalid num_inference_steps")
40
+ if sparse_depth is not None and (type(sparse_depth) is not np.ndarray or sparse_depth.ndim != 2):
41
  raise ValueError(
42
  "Sparse depth should be a 2D numpy ndarray with zeros at missing positions"
43
  )
 
64
  dtype=self.dtype,
65
  ) # [N,3,PPH,PPW]
66
 
67
+ if sparse_depth is not None and sparse_depth.shape != original_resolution:
68
  raise ValueError(
69
  f"Sparse depth dimensions ({sparse_depth.shape}) must match that of the image ({image.shape[-2:]})"
70
  )
 
76
  del image
77
 
78
  # Preprocess sparse depth
79
+ if sparse_depth is not None:
80
+ sparse_depth = torch.from_numpy(sparse_depth)[None, None].float()
81
+ sparse_depth = sparse_depth.to(device)
82
+ sparse_mask = sparse_depth > 0
83
+ sparse_depth = sparse_depth[sparse_mask]
84
+ sparse_depth_min = sparse_depth.min() if sparse_depth.numel() > 0 else 0
85
+ sparse_depth_max = sparse_depth.max() if sparse_depth.numel() > 0 else 1
86
+ else:
87
+ sparse_depth_min = 0
88
+ sparse_depth_max = 1
89
 
90
  # Set up optimization targets
91
  pred_latent = torch.nn.Parameter(pred_latent, requires_grad=True)
 
163
  pred_original_sample = step_output.pred_original_sample
164
  current_metric_estimate = latent_to_metric(pred_original_sample)
165
 
166
+ if sparse_depth is not None and sparse_depth.numel() > 0:
167
  loss, rmse = loss_l1l2(current_metric_estimate[sparse_mask], sparse_depth)
168
  loss.backward()
169