import logging
import warnings

import diffusers
import numpy as np
import torch
from diffusers import MarigoldDepthPipeline

warnings.simplefilter(action="ignore", category=FutureWarning)
diffusers.utils.logging.disable_progress_bar()
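

# Marigold-DC style depth completion: the affine-invariant Marigold depth
# estimator is steered during denoising so that the decoded metric depth
# agrees with a sparse set of known depth values. At every diffusion step,
# the depth latent and a global scale/shift are optimized against the
# sparse guidance before the regular denoising update is applied.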
class MarigoldDepthCompletionPipeline(MarigoldDepthPipeline):
    def __call__(
        self,
        image,
        sparse_depth,
        num_inference_steps=50,
        processing_resolution=0,
        seed=2024,
        lr_scale_shift=0.005,
        lr_latent=0.05,
        override_shift=None,
        override_scale=None,
        dry_run=False,
    ):
        # Resolving variables
        device = self._execution_device
        generator = torch.Generator(device=device).manual_seed(seed)

        if dry_run:
            # Emit placeholder predictions without touching the model
            logging.warning("Dry run mode")
            for i in range(num_inference_steps):
                yield np.array(image)[:, :, 0].astype(float), float(np.log(i + 1))
            return

        # Check inputs.
        if num_inference_steps is None:
            raise ValueError("Invalid num_inference_steps")
        if sparse_depth is not None and (
            not isinstance(sparse_depth, np.ndarray) or sparse_depth.ndim != 2
        ):
            raise ValueError(
                "Sparse depth should be a 2D numpy ndarray with zeros at missing positions"
            )

        with torch.no_grad():
            # Prepare empty text conditioning (Marigold conditions the U-Net
            # on an empty-prompt embedding)
            if self.empty_text_embedding is None:
                prompt = ""
                text_inputs = self.tokenizer(
                    prompt,
                    padding="do_not_pad",
                    max_length=self.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                text_input_ids = text_inputs.input_ids.to(device)
                self.empty_text_embedding = self.text_encoder(text_input_ids)[0]  # [1,2,1024]

        # Preprocess the input image
        image, padding, original_resolution = self.image_processor.preprocess(
            image,
            processing_resolution=processing_resolution,
            device=device,
            dtype=self.dtype,
        )  # [N,3,PPH,PPW]

        if sparse_depth is not None and sparse_depth.shape != original_resolution:
            raise ValueError(
                f"Sparse depth dimensions ({sparse_depth.shape}) must match those of "
                f"the image ({original_resolution})"
            )

        with torch.no_grad():
            # Encode the input image into latent space and initialize the
            # depth prediction latent
            image_latent, pred_latent = self.prepare_latents(
                image, None, generator, 1, 1
            )  # [N*E,4,h,w], [N*E,4,h,w]
        del image

        # Preprocess sparse depth: keep only the observed (non-zero) values
        if sparse_depth is not None:
            sparse_depth = torch.from_numpy(sparse_depth)[None, None].float()
            sparse_depth = sparse_depth.to(device)
            sparse_mask = sparse_depth > 0
            sparse_depth = sparse_depth[sparse_mask]
            sparse_depth_min = sparse_depth.min() if sparse_depth.numel() > 0 else 0
            sparse_depth_max = sparse_depth.max() if sparse_depth.numel() > 0 else 1
        else:
            sparse_depth_min = 0
            sparse_depth_max = 1

        # Set up optimization targets: the depth latent, plus a global scale
        # and shift (stored as square roots so the effective values stay
        # non-negative)
        pred_latent = torch.nn.Parameter(pred_latent, requires_grad=True)

        if override_scale is not None:
            scale = np.sqrt(override_scale)
            sparse_range = 1.0
        else:
            scale = torch.nn.Parameter(torch.ones(1, device=device), requires_grad=True)
            sparse_range = sparse_depth_max - sparse_depth_min
            if torch.is_tensor(sparse_range):
                sparse_range = sparse_range.item()

        if override_shift is not None:
            shift = np.sqrt(override_shift)
            sparse_lower = 1.0
        else:
            shift = torch.nn.Parameter(torch.ones(1, device=device), requires_grad=True)
            sparse_lower = sparse_depth_min
            if torch.is_tensor(sparse_lower):
                sparse_lower = sparse_lower.item()

        def affine_to_metric(depth):
            # Map the affine-invariant prediction into metric units; squaring
            # scale and shift keeps both factors non-negative
            return (scale**2) * sparse_range * depth + (shift**2) * sparse_lower

        def latent_to_metric(latent):
            affine_invariant_prediction = self.decode_prediction(latent)  # [E,1,PPH,PPW]
            affine_invariant_prediction = affine_invariant_prediction.to(torch.float32)
            prediction = affine_to_metric(affine_invariant_prediction)
            prediction = self.image_processor.unpad_image(prediction, padding)  # [E,1,PH,PW]
            prediction = self.image_processor.resize_antialias(
                prediction, original_resolution, "bilinear", is_aa=False
            )  # [1,1,H,W]
            return prediction

        def loss_l1l2(pred, target):
            # Combined L1 + L2 guidance loss; also return the RMSE for logging
            out_l1 = torch.nn.functional.l1_loss(pred, target)
            out_l2 = torch.nn.functional.mse_loss(pred, target)
            return out_l1 + out_l2, out_l2.sqrt()

        # Always optimize the latent; optimize scale and shift only when they
        # are not overridden
        optimizer_params = [{"params": [pred_latent], "lr": lr_latent}]
        if override_shift is None:
            optimizer_params.append({"params": [shift], "lr": lr_scale_shift})
        if override_scale is None:
            optimizer_params.append({"params": [scale], "lr": lr_scale_shift})
        optimizer = torch.optim.Adam(optimizer_params)

        # Process the denoising loop
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        for t in self.progress_bar(
            self.scheduler.timesteps, desc=f"Marigold-DC steps ({str(device)})..."
        ):
            optimizer.zero_grad()

            # Predict with the U-Net on the concatenated image and depth latents
            batch_latent = torch.cat([image_latent, pred_latent], dim=1)  # [1,8,h,w]
            noise = self.unet(
                batch_latent,
                t,
                encoder_hidden_states=self.empty_text_embedding,
                return_dict=False,
            )[0]  # [1,4,h,w]

            # Compute pred_epsilon to later rescale the depth latent gradient
            # (epsilon reconstructed from the model's v-prediction output)
            with torch.no_grad():
                alpha_prod_t = self.scheduler.alphas_cumprod[t]
                beta_prod_t = 1 - alpha_prod_t
                pred_epsilon = (alpha_prod_t**0.5) * noise + (beta_prod_t**0.5) * pred_latent

            step_output = self.scheduler.step(noise, t, pred_latent, generator=generator)

            # Preview the final output depth, compute the guidance loss, backprop
            pred_original_sample = step_output.pred_original_sample
            current_metric_estimate = latent_to_metric(pred_original_sample)

            rmse = 0.0
            if sparse_depth is not None and sparse_depth.numel() > 0:
                loss, rmse = loss_l1l2(current_metric_estimate[sparse_mask], sparse_depth)
                rmse = rmse.item()
                loss.backward()

                # Rescale the latent gradient to match the noise magnitude
                with torch.no_grad():
                    pred_epsilon_norm = torch.linalg.norm(pred_epsilon).item()
                    depth_latent_grad_norm = torch.linalg.norm(pred_latent.grad).item()
                    scaling_factor = pred_epsilon_norm / max(depth_latent_grad_norm, 1e-8)
                    pred_latent.grad *= scaling_factor

                optimizer.step()

            # Execute the regular denoising diffusion update on the latent
            with torch.no_grad():
                pred_latent.data = self.scheduler.step(
                    noise, t, pred_latent, generator=generator
                ).prev_sample

            yield current_metric_estimate, rmse

            del pred_original_sample, current_metric_estimate, step_output, pred_epsilon, noise
            torch.cuda.empty_cache()

        # Offload all models
        self.maybe_free_model_hooks()
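

# A minimal usage sketch (illustrative, not part of the pipeline above). It
# assumes the public "prs-eth/marigold-depth-v1-0" checkpoint and hypothetical
# input files "scene.png" and "sparse_depth.npy"; adapt paths and settings.
if __name__ == "__main__":
    from PIL import Image

    pipe = MarigoldDepthCompletionPipeline.from_pretrained(
        "prs-eth/marigold-depth-v1-0"
    ).to("cuda" if torch.cuda.is_available() else "cpu")

    rgb = Image.open("scene.png").convert("RGB")
    sparse = np.load("sparse_depth.npy")  # 2D array, zeros at missing positions

    # The pipeline is a generator: each step yields the current metric depth
    # estimate ([1,1,H,W] tensor) and the RMSE against the sparse guidance;
    # the last yielded estimate is the final completion.
    for depth, rmse in pipe(rgb, sparse, num_inference_steps=50):
        print(f"rmse={rmse:.4f}")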