import logging
import warnings

import diffusers
import numpy as np
import torch
from diffusers import MarigoldDepthPipeline

warnings.simplefilter(action="ignore", category=FutureWarning)
diffusers.utils.logging.disable_progress_bar()
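

# Marigold-DC: metric depth completion by test-time guidance of the pretrained
# Marigold depth diffusion model. During DDIM sampling, the depth latent and a
# global affine (scale/shift) mapping are optimized so that the decoded
# prediction agrees with the provided sparse metric depth points.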
class MarigoldDepthCompletionPipeline(MarigoldDepthPipeline):
    def __call__(
        self,
        image,
        sparse_depth,
        num_inference_steps=50,
        processing_resolution=0,
        seed=2024,
        lr_scale_shift=0.005,
        lr_latent=0.05,
        override_shift=None,
        override_scale=None,
        dry_run=False,
    ):
        # Resolving variables
        device = self._execution_device
        generator = torch.Generator(device=device).manual_seed(seed)
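
        # Dry-run mode streams a placeholder depth (the image's first channel)
        # and a synthetic RMSE curve, which is useful for exercising caller
        # code without running the model.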
        if dry_run:
            logging.warning("Dry run mode")
            for i in range(num_inference_steps):
                yield np.array(image)[:, :, 0].astype(float), float(np.log(i + 1))
            return

        # Check inputs.
        if num_inference_steps is None:
            raise ValueError("Invalid num_inference_steps")
        if sparse_depth is not None and (
            type(sparse_depth) is not np.ndarray or sparse_depth.ndim != 2
        ):
            raise ValueError(
                "Sparse depth should be a 2D numpy ndarray with zeros at missing positions"
            )

        with torch.no_grad():
            # Prepare the empty text conditioning: the UNet is text-conditioned,
            # so the empty-prompt embedding is encoded once and cached for reuse.
            if self.empty_text_embedding is None:
                prompt = ""
                text_inputs = self.tokenizer(
                    prompt,
                    padding="do_not_pad",
                    max_length=self.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                text_input_ids = text_inputs.input_ids.to(device)
                self.empty_text_embedding = self.text_encoder(text_input_ids)[0]  # [1,2,1024]

        # Preprocess input images
        image, padding, original_resolution = self.image_processor.preprocess(
            image,
            processing_resolution=processing_resolution,
            device=device,
            dtype=self.dtype,
        )  # [N,3,PPH,PPW]

        if sparse_depth is not None and sparse_depth.shape != original_resolution:
            raise ValueError(
                f"Sparse depth dimensions ({sparse_depth.shape}) must match "
                f"those of the image ({original_resolution})"
            )

        with torch.no_grad():
            # Encode the input image into latent space; pred_latent starts as
            # noise and is the variable optimized during sampling.
            image_latent, pred_latent = self.prepare_latents(
                image, None, generator, 1, 1
            )  # [N*E,4,h,w], [N*E,4,h,w]
            del image

        # Preprocess sparse depth
        if sparse_depth is not None:
            sparse_depth = torch.from_numpy(sparse_depth)[None, None].float()
            sparse_depth = sparse_depth.to(device)
            sparse_mask = sparse_depth > 0
            sparse_depth = sparse_depth[sparse_mask]
            sparse_depth_min = sparse_depth.min() if sparse_depth.numel() > 0 else 0
            sparse_depth_max = sparse_depth.max() if sparse_depth.numel() > 0 else 1
        else:
            sparse_depth_min = 0
            sparse_depth_max = 1

        # Set up optimization targets
        pred_latent = torch.nn.Parameter(pred_latent, requires_grad=True)

        if override_scale is not None:
            scale = np.sqrt(override_scale)
            sparse_range = 1.0
        else:
            scale = torch.nn.Parameter(torch.ones(1, device=device), requires_grad=True)
            sparse_range = sparse_depth_max - sparse_depth_min
            if torch.is_tensor(sparse_range):
                sparse_range = sparse_range.item()

        if override_shift is not None:
            shift = np.sqrt(override_shift)
            sparse_lower = 1.0
        else:
            shift = torch.nn.Parameter(torch.ones(1, device=device), requires_grad=True)
            sparse_lower = sparse_depth_min
            if torch.is_tensor(sparse_lower):
                sparse_lower = sparse_lower.item()
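
        # The affine-invariant prediction in [0,1] is mapped to metric depth via
        # a learned scale/shift; squaring keeps both coefficients non-negative,
        # and sparse_range/sparse_lower anchor them to the sparse points' range.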
        def affine_to_metric(depth):
            return (scale**2) * sparse_range * depth + (shift**2) * sparse_lower

        def latent_to_metric(latent):
            affine_invariant_prediction = self.decode_prediction(latent)  # [E,1,PPH,PPW]
            affine_invariant_prediction = affine_invariant_prediction.to(torch.float32)
            prediction = affine_to_metric(affine_invariant_prediction)
            prediction = self.image_processor.unpad_image(prediction, padding)  # [E,1,PH,PW]
            prediction = self.image_processor.resize_antialias(
                prediction, original_resolution, "bilinear", is_aa=False
            )  # [1,1,H,W]
            return prediction
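
        # Guidance loss: L1 + L2 against the sparse points; the RMSE is
        # returned separately for logging.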
        def loss_l1l2(input, target):
            out_l1 = torch.nn.functional.l1_loss(input, target)
            out_l2 = torch.nn.functional.mse_loss(input, target)
            out = out_l1 + out_l2
            return out, out_l2.sqrt()

        optimizer_params = [{"params": [pred_latent], "lr": lr_latent}]
        if override_shift is None:
            optimizer_params.append({"params": [shift], "lr": lr_scale_shift})
        if override_scale is None:
            optimizer_params.append({"params": [scale], "lr": lr_scale_shift})
        optimizer = torch.optim.Adam(optimizer_params)
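
        # Each denoising step below runs one UNet forward pass, previews the
        # clean depth implied by the current latent, takes one guided gradient
        # step on (pred_latent, scale, shift), then advances the scheduler.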
        # Process the denoising loop
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        for t in self.progress_bar(
            self.scheduler.timesteps, desc=f"Marigold-DC steps ({str(device)})..."
        ):
            optimizer.zero_grad()

            batch_latent = torch.cat([image_latent, pred_latent], dim=1)  # [1,8,h,w]
            noise = self.unet(
                batch_latent,
                t,
                encoder_hidden_states=self.empty_text_embedding,
                return_dict=False,
            )[0]  # [1,4,h,w]

            # Compute pred_epsilon to later rescale the depth latent gradient
            with torch.no_grad():
                alpha_prod_t = self.scheduler.alphas_cumprod[t]
                beta_prod_t = 1 - alpha_prod_t
                pred_epsilon = (alpha_prod_t**0.5) * noise + (
                    beta_prod_t**0.5
                ) * pred_latent

            step_output = self.scheduler.step(noise, t, pred_latent, generator=generator)

            # Preview the final output depth, compute the guidance loss, backprop
            pred_original_sample = step_output.pred_original_sample
            current_metric_estimate = latent_to_metric(pred_original_sample)

            rmse = 0.0
            if sparse_depth is not None and sparse_depth.numel() > 0:
                loss, rmse = loss_l1l2(current_metric_estimate[sparse_mask], sparse_depth)
                rmse = rmse.item()
                loss.backward()

                # Rescale the latent gradient so its norm matches that of the
                # predicted noise; this keeps the guidance strength independent
                # of the loss magnitude.
                with torch.no_grad():
                    pred_epsilon_norm = torch.linalg.norm(pred_epsilon).item()
                    depth_latent_grad_norm = torch.linalg.norm(pred_latent.grad).item()
                    scaling_factor = pred_epsilon_norm / max(depth_latent_grad_norm, 1e-8)
                    pred_latent.grad *= scaling_factor

                optimizer.step()

            # Execute the regular denoising step on the (optimized) latent
            with torch.no_grad():
                pred_latent.data = self.scheduler.step(
                    noise, t, pred_latent, generator=generator
                ).prev_sample

            yield current_metric_estimate, rmse

            del (
                pred_original_sample,
                current_metric_estimate,
                step_output,
                pred_epsilon,
                noise,
            )
            torch.cuda.empty_cache()

        # Offload all models
        self.maybe_free_model_hooks()
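

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions (not part of this file): the
# "prs-eth/marigold-depth-v1-0" checkpoint, and the input paths
# "example.jpg" / "sparse_depth.npy" are illustrative placeholders.
# The pipeline is a generator: it yields (metric depth, RMSE against the
# sparse points) after every denoising step, so progress can be streamed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = MarigoldDepthCompletionPipeline.from_pretrained(
        "prs-eth/marigold-depth-v1-0"
    ).to(device)

    image = Image.open("example.jpg")  # RGB input image
    sparse = np.load("sparse_depth.npy")  # [H,W] float array, zeros = missing

    depth = None
    for depth, rmse in pipe(image, sparse, num_inference_steps=50):
        logging.info(f"RMSE vs. sparse guidance: {rmse:.4f}")
    # `depth` now holds the final metric depth as a [1,1,H,W] tensor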