fix sparse_depth is None usecase

marigold_dc.py  CHANGED (+13 -9)
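Before this change, the input checks assumed sparse_depth was always a numpy array, so calling the pipeline with sparse_depth=None raised ValueError before inference could start. The fix guards each use of sparse_depth with a None check: the type/ndim validation and the resolution check against the image are skipped, the preprocessing into a masked tensor only runs when measurements are present (with sparse_depth_min and sparse_depth_max falling back to 0 and 1 otherwise), and the guidance loss inside the denoising loop is bypassed. With sparse_depth=None the pipeline now behaves as a plain depth estimator.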
    
@@ -37,7 +37,7 @@ class MarigoldDepthCompletionPipeline(MarigoldDepthPipeline):
         # Check inputs.
         if num_inference_steps is None:
             raise ValueError("Invalid num_inference_steps")
-        if type(sparse_depth) is not np.ndarray or sparse_depth.ndim != 2:
+        if sparse_depth is not None and (type(sparse_depth) is not np.ndarray or sparse_depth.ndim != 2):
             raise ValueError(
                 "Sparse depth should be a 2D numpy ndarray with zeros at missing positions"
             )
@@ -64,7 +64,7 @@ class MarigoldDepthCompletionPipeline(MarigoldDepthPipeline):
             dtype=self.dtype,
         )  # [N,3,PPH,PPW]
 
-        if sparse_depth.shape != original_resolution:
+        if sparse_depth is not None and sparse_depth.shape != original_resolution:
             raise ValueError(
                 f"Sparse depth dimensions ({sparse_depth.shape}) must match that of the image ({image.shape[-2:]})"
             )
@@ -76,12 +76,16 @@ class MarigoldDepthCompletionPipeline(MarigoldDepthPipeline):
         del image
 
         # Preprocess sparse depth
-        sparse_depth = torch.from_numpy(sparse_depth)[None, None].float()
-        sparse_depth = sparse_depth.to(device)
-        sparse_mask = sparse_depth > 0
-        sparse_depth = sparse_depth[sparse_mask]
-        sparse_depth_min = sparse_depth.min() if sparse_depth.numel() > 0 else 0
-        sparse_depth_max = sparse_depth.max() if sparse_depth.numel() > 0 else 1
+        if sparse_depth is not None:
+            sparse_depth = torch.from_numpy(sparse_depth)[None, None].float()
+            sparse_depth = sparse_depth.to(device)
+            sparse_mask = sparse_depth > 0
+            sparse_depth = sparse_depth[sparse_mask]
+            sparse_depth_min = sparse_depth.min() if sparse_depth.numel() > 0 else 0
+            sparse_depth_max = sparse_depth.max() if sparse_depth.numel() > 0 else 1
+        else:
+            sparse_depth_min = 0
+            sparse_depth_max = 1
 
         # Set up optimization targets
         pred_latent = torch.nn.Parameter(pred_latent, requires_grad=True)
@@ -159,7 +163,7 @@ class MarigoldDepthCompletionPipeline(MarigoldDepthPipeline):
             pred_original_sample = step_output.pred_original_sample
             current_metric_estimate = latent_to_metric(pred_original_sample)
 
-            if sparse_depth.numel() > 0:
+            if sparse_depth is not None and sparse_depth.numel() > 0:
                 loss, rmse = loss_l1l2(current_metric_estimate[sparse_mask], sparse_depth)
                 loss.backward()
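For reference, a minimal sketch of how the fixed pipeline could be invoked with and without sparse guidance. Only the MarigoldDepthCompletionPipeline class and the sparse_depth and num_inference_steps arguments appear in this diff; the checkpoint id, image loading, and the rest of the call are assumptions for illustration, not taken from the repository:

import numpy as np
from PIL import Image

from marigold_dc import MarigoldDepthCompletionPipeline

# Hypothetical checkpoint id; the commit does not show how the pipeline is loaded.
pipe = MarigoldDepthCompletionPipeline.from_pretrained("prs-eth/marigold-depth-v1-0")

# The diff's shape check requires the sparse map to match the image resolution,
# so a 640x480 image here pairs with a (480, 640) depth array.
image = Image.open("input.jpg")  # illustrative path

# Sparse depth: 2D float array with zeros at missing positions (per the ValueError text).
sparse = np.zeros((480, 640), dtype=np.float32)
sparse[240, 320] = 2.5  # one known metric measurement, for illustration

# Depth completion: sparse measurements steer the denoising loop via the guidance loss.
completed = pipe(image, sparse_depth=sparse, num_inference_steps=50)

# After this fix, omitting sparse depth no longer raises ValueError:
# validation, preprocessing, and the guidance loss are all skipped.
plain = pipe(image, sparse_depth=None, num_inference_steps=50)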
 
			

