# Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
# TODO: 16bit depth map download
# TODO: change to gradio-dualvision (update it with the Examples thumbs first)

import os

import PIL
import pandas
import requests
import spaces
import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from PIL import Image, ImageDraw
from scipy.ndimage import maximum_filter
from huggingface_hub import login

from marigold_dc import MarigoldDepthCompletionPipeline

DRY_RUN = os.environ.get("ACCELERATOR", "cpu") not in ("zero", "gpu")

DEFAULT_DENOISE_STEPS = 10
DEFAULT_LR_LATENT = 0.05
DEFAULT_LR_SCALE_SHIFT = 0.005

TILE_CHAR = "██"

TAB10_COLORS = [
    (31, 119, 180),   # blue
    (255, 127, 14),   # orange
    (44, 160, 44),    # green
    (214, 39, 40),    # red
    (148, 103, 189),  # purple
    (140, 86, 75),    # brown
    (227, 119, 194),  # pink
    (127, 127, 127),  # gray
    (188, 189, 34),   # olive
    (23, 190, 207),   # cyan
]


def adjust_brightness(color, factor):
    """Scale an RGB tuple by `factor`, clamping each channel to [0, 255]."""
    return tuple(max(0, min(255, int(c * factor))) for c in color)


def get_wrapped_color(index):
    """Cycle through TAB10, alternating darker/brighter shades on each palette wrap."""
    base_index = index % len(TAB10_COLORS)
    wrap_count = index // len(TAB10_COLORS)
    base_color = TAB10_COLORS[base_index]
    factor = 1.0 + 0.15 * ((wrap_count % 2) * 2 - 1) * (wrap_count // 2 + 1)
    return adjust_brightness(base_color, factor)


def process_click_data(img: Image.Image, state_orig_img, table, x: int, y: int, value: str = ""):
    """Draw a colored marker at (x, y) and append a row to the annotation table."""
    if isinstance(img, str):
        img = Image.open(img)
    if state_orig_img is None:
        state_orig_img = img.copy()
    if isinstance(table, pandas.DataFrame):
        table = table.values.tolist()
    color = get_wrapped_color(len(table))
    color_hex = "#%02x%02x%02x" % color
    img = img.convert("RGB")
    draw = ImageDraw.Draw(img)
    width, _ = img.size
    r = int(width * 0.015)
    draw.ellipse((x - r, y - r, x + r, y + r), fill=color, outline=color)
    draw.ellipse(
        (x - r, y - r, x + r, y + r),
        fill=None,
        outline=(255, 255, 255),
        width=max(1, r // 4),
    )
    table = table + [[TILE_CHAR, value, x, y, color_hex]]
    return img, state_orig_img, table


def on_click(img: Image.Image, state_orig_img, evt: gr.SelectData, table):
    x, y = evt.index
    img, state_orig_img, table = process_click_data(img, state_orig_img, table, x, y)
    return img, state_orig_img, gr.Dataframe(table, visible=True)
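# Illustrative note (comments only): each click appends a table row of the form
# [TILE_CHAR, depth_value, x, y, color_hex]. Marker colors cycle through TAB10,
# with brightness alternating on each wrap of the palette:
#   wrap 0 (indices 0-9):   factor 0.85 (slightly darkened)
#   wrap 1 (indices 10-19): factor 1.15 (slightly brightened)
#   wrap 2 (indices 20-29): factor 0.70, and so on.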
def dilate_rgb_image(image, kernel_size):
    r_channel, g_channel, b_channel = image[..., 0], image[..., 1], image[..., 2]
    r_dilated = maximum_filter(r_channel, size=kernel_size)
    g_dilated = maximum_filter(g_channel, size=kernel_size)
    b_dilated = maximum_filter(b_channel, size=kernel_size)
    dilated_image = np.stack([r_dilated, g_dilated, b_dilated], axis=-1)
    return dilated_image


def generate_rmse_plot(steps, metrics, denoise_steps):
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=steps,
            y=metrics,
            mode="lines+markers",
            line=dict(color="#af2928"),
            name="RMSE",
        )
    )
    x_dtick = 1 if denoise_steps < 20 else 5
    fig.update_layout(
        autosize=True,
        height=300,
        margin=dict(l=20, r=20, t=20, b=20),
        xaxis_title="Steps",
        xaxis_range=[0, denoise_steps + 1],
        xaxis=dict(
            scaleanchor="y",
            scaleratio=1.5,
            dtick=x_dtick,
        ),
        yaxis_title="RMSE",
        yaxis=dict(type="log"),
        hovermode="x unified",
        template="plotly_white",
    )
    return fig
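# A minimal sketch of the sparse-depth .npy format consumed by process() below
# (the file name and resolution here are hypothetical): an H x W float array
# aligned with the input image, where values > 0 are metric measurements and
# 0 marks missing pixels.
#
#   sparse = np.zeros((480, 640), dtype=np.float32)
#   sparse[240, 320] = 3.0  # a single LiDAR return of 3 depth units
#   np.save("sparse.npy", sparse)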
@spaces.GPU
def process(
    image,
    state_orig_img,
    table,
    path_sparse,
    denoise_steps=DEFAULT_DENOISE_STEPS,
    lr_latent=DEFAULT_LR_LATENT,
    lr_scale_shift=DEFAULT_LR_SCALE_SHIFT,
    override_shift=None,
    override_scale=None,
):
    # NaN in the override fields means "not set"; normalize to None or float.
    if override_shift is None:
        pass
    elif np.isnan(override_shift):
        override_shift = None
    else:
        override_shift = float(override_shift)
    if override_scale is None:
        pass
    elif np.isnan(override_scale):
        override_scale = None
    else:
        override_scale = float(override_scale)

    # Prefer the pristine image kept in the state over the annotated one.
    if isinstance(state_orig_img, str):
        image = Image.open(state_orig_img)
    elif isinstance(state_orig_img, PIL.Image.Image):
        image = state_orig_img
    elif isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, PIL.Image.Image):
        pass
    else:
        raise TypeError(f"Unknown image type: {type(image)}")

    if isinstance(table, pandas.DataFrame):
        table = table.values.tolist()

    if path_sparse is not None and os.path.exists(path_sparse):
        # Numpy file given (LiDAR): values > 0 are valid measurements.
        sparse_depth = np.load(path_sparse)
        sparse_depth_valid = sparse_depth[sparse_depth > 0]
        sparse_depth_min = np.min(sparse_depth_valid)
        sparse_depth_max = np.max(sparse_depth_valid)
        kernel_size = 5
    elif table is not None and len(table) >= 2:
        # Click annotations: rows are [tile, value, x, y, color_hex].
        sparse_depth = np.full((image.height, image.width), np.nan, dtype=np.float32)
        for entry in table:
            try:
                sparse_depth[entry[3], entry[2]] = float(entry[1])
            except Exception:
                pass  # skip rows with empty or invalid depth values
        sparse_depth_valid_mask = ~np.isnan(sparse_depth)
        sparse_depth_valid = sparse_depth[sparse_depth_valid_mask]
        sparse_depth_valid_num = np.sum(sparse_depth_valid_mask)
        if sparse_depth_valid_num >= 2:
            sparse_depth_min = np.min(sparse_depth_valid)
            sparse_depth_max = np.max(sparse_depth_valid)
            sparse_depth[~sparse_depth_valid_mask] = 0
            kernel_size = 10
        else:
            sparse_depth = None
            sparse_depth_min = 0
            sparse_depth_max = 1
            kernel_size = 5
    else:
        sparse_depth = None
        sparse_depth_min = 0
        sparse_depth_max = 1
        kernel_size = 5

    width, height = image.size
    max_dim = max(width, height)
    processing_resolution = 0
    if max_dim > 768:
        processing_resolution = 768

    metrics = []
    steps = []
    for step, (pred, rmse) in enumerate(
        pipe(
            image=image,
            sparse_depth=sparse_depth,
            num_inference_steps=denoise_steps + 1,
            processing_resolution=processing_resolution,
            lr_latent=lr_latent,
            lr_scale_shift=lr_scale_shift,
            override_shift=override_shift,
            override_scale=override_scale,
            dry_run=DRY_RUN,
        )
    ):
        # Shared visualization range covering both the prediction and the sparse points.
        min_both = pred.min().item()
        max_both = pred.max().item()
        if sparse_depth is not None:
            min_both = min(sparse_depth_min, min_both)
            max_both = max(sparse_depth_max, max_both)
        metrics.append(rmse)
        steps.append(step)
        vis_pred = pipe.image_processor.visualize_depth(pred, val_min=min_both, val_max=max_both)[0]
        if sparse_depth is not None:
            vis_sparse = pipe.image_processor.visualize_depth(
                sparse_depth, val_min=min_both, val_max=max_both
            )[0]
            vis_sparse = np.array(vis_sparse)
            vis_sparse[sparse_depth <= 0] = (0, 0, 0)
            vis_sparse = dilate_rgb_image(vis_sparse, kernel_size=kernel_size)
        else:
            vis_sparse = np.full_like(vis_pred, 0)
        vis_sparse = Image.fromarray(vis_sparse)
        plot = generate_rmse_plot(steps, metrics, denoise_steps)
        plot = gr.Plot(plot, visible=True)
        slider = gr.ImageSlider([vis_sparse, vis_pred], visible=True)
        yield slider, plot


# Debug output: dump installed packages and the environment to the Space logs.
os.system("pip freeze")
print("Environment:\n" + "\n".join(f"{k}: {os.environ[k]}" for k in sorted(os.environ.keys())))

if "HF_TOKEN_LOGIN" in os.environ:
    login(token=os.environ["HF_TOKEN_LOGIN"])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

pipe = MarigoldDepthCompletionPipeline.from_pretrained(
    "prs-eth/marigold-depth-v1-1",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
)
try:
    import xformers

    pipe.enable_xformers_memory_efficient_attention()
except Exception:
    print("Running without xformers")
pipe = pipe.to(device)

os.environ["GRADIO_ALLOW_FLAGGING"] = "never"
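# Minimal standalone sketch of driving the pipeline without the Gradio UI; the
# keyword arguments mirror the call inside process() above, and the image path
# is hypothetical:
#
#   img = Image.open("example.jpg")
#   for step, (pred, rmse) in enumerate(
#       pipe(
#           image=img,
#           sparse_depth=None,
#           num_inference_steps=DEFAULT_DENOISE_STEPS + 1,
#           processing_resolution=768,
#           lr_latent=DEFAULT_LR_LATENT,
#           lr_scale_shift=DEFAULT_LR_SCALE_SHIFT,
#           dry_run=DRY_RUN,
#       )
#   ):
#       print(f"step {step}: RMSE {rmse}")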

with gr.Blocks(
    theme=gr.themes.Default(
        primary_hue=gr.themes.colors.red,
        spacing_size=gr.themes.sizes.spacing_sm,
        radius_size="none",
        text_size="md",
    ).set(
        button_secondary_background_fill="black",
        button_secondary_text_color="white",
        body_background_fill="linear-gradient(to right, #FFE0D0, #E0F0FF)",
    ),
    analytics_enabled=False,
    title="Marigold Depth Completion",
    css="""
        .slider .inner { width: 4px; background: #FFF; }
        .slider .icon-wrap { fill: #FFF; background-color: #FFF; stroke: #FFF; stroke-width: 3px; }
        .viewport { aspect-ratio: 4/3; }
        h1 { text-align: center; display: block; }
        h2 { text-align: center; display: block; }
        h3 { text-align: center; display: block; }
    """,
) as demo:
    gr.HTML(
        """
        <h1>⇆ Marigold-DC: Zero-Shot Monocular Depth Completion with Guided Diffusion</h1>
        <p align="center">
            <!-- Badges: Website, arXiv, GitHub stars, social -->
        </p>
        <h3>Upload any image, annotate with a few clicks, and compute dense metric depth!</h3>
        <h3>Alternatively, explore advanced LiDAR functionality and examples at the bottom.</h3>
        """
    )
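    # Layout: the left column collects the input image, the click-annotation
    # table, and advanced options; the right column shows the completed depth
    # slider and the RMSE convergence plot.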
    state_orig_img = gr.State()

    with gr.Row():
        with gr.Column():
            thumb = gr.Image(
                label="Thumb Image",
                type="filepath",
                visible=False,
            )
            input_image = gr.Image(
                label="Input image (click to enter depth)",
                type="filepath",
                interactive=True,
            )
            table = gr.Dataframe(
                headers=["Color", "Enter depth estimates (any unit)", "x", "y", "_color"],
                datatype=["str", "number", "number", "number", "str"],
                column_widths=["30px", "120px", "0px", "0px", "0px"],
                static_columns=[0, 2, 3, 4],
                show_fullscreen_button=False,
                show_copy_button=False,
                show_row_numbers=False,
                show_search="none",
                row_count=0,
                interactive=True,
                visible=False,
            )
            with gr.Accordion("Advanced options", open=False):
                with gr.Row():
                    with gr.Column():
                        denoise_steps = gr.Slider(
                            label="Number of denoising steps",
                            minimum=4,
                            maximum=50,
                            step=1,
                            value=15,
                        )
                        lr_latent = gr.Number(
                            DEFAULT_LR_LATENT,
                            interactive=True,
                            label="Latent LR",
                            step=0.001,
                        )
                        with gr.Row():
                            lr_scale_shift = gr.Number(
                                DEFAULT_LR_SCALE_SHIFT,
                                interactive=True,
                                label="Scale-and-shift LR",
                                step=0.001,
                                min_width=90,
                            )
                            override_shift = gr.Number(
                                value=float("NaN"),
                                label="Shift override",
                                min_width=90,
                            )
                            override_scale = gr.Number(
                                value=float("NaN"),
                                label="Scale override",
                                min_width=90,
                            )
                    with gr.Column():
                        input_sparse = gr.File(
                            label="Input sparse depth (numpy file)",
                        )
            with gr.Row():
                submit_btn = gr.Button(value="Compute Depth", variant="primary")
                clear_btn = gr.Button(value="Clear")
        with gr.Column():
            output_slider = gr.ImageSlider(
                label="Completed depth (red-near, blue-far)",
                type="filepath",
                show_download_button=True,
                interactive=False,
                elem_classes="slider",
                slider_position=25,
            )
            plot = gr.Plot(
                label="RMSE between sparse measurements and densified depth",
                elem_classes="viewport",  # matches the .viewport CSS rule above
            )

    input_image.select(
        on_click,
        inputs=[input_image, state_orig_img, table],
        outputs=[input_image, state_orig_img, table],
    )
    input_image.upload(
        lambda: gr.update(label="Click and provide depth estimates in the table below"),
        outputs=input_image,
    )

    def submit_depth_fn(
        image,
        state_orig_img,
        table,
        path_sparse,
        denoise_steps,
        lr_latent,
        lr_scale_shift,
        override_shift,
        override_scale,
    ):
        yield from process(
            image,
            state_orig_img,
            table,
            path_sparse,
            denoise_steps,
            lr_latent,
            lr_scale_shift,
            override_shift,
            override_scale,
        )

    submit_btn.click(
        fn=submit_depth_fn,
        inputs=[
            input_image,
            state_orig_img,
            table,
            input_sparse,
            denoise_steps,
            lr_latent,
            lr_scale_shift,
            override_shift,
            override_scale,
        ],
        outputs=[output_slider, plot],
    )
    # Each example is keyed by its thumbnail; clicking one downloads the
    # full-resolution image (and, for the KITTI/teaser entries, a sparse LiDAR
    # .npy) from the Space's files/ directory and replays any stored clicks.
    def examples_depth_lidar_fn(path_thumb):
        def real_url(fname):
            return f"https://huggingface.co/spaces/obukhovai/marigold-dc-metric/resolve/main/files/{fname}"

        l_thumb = os.path.basename(path_thumb)
        d_thumb = os.path.dirname(path_thumb)
        l_image, l_sparse, clicks, nsteps = {
            "thumb_matterhorn_clicks.jpg": [
                "matterhorn.jpg",
                None,
                [
                    [TILE_CHAR, "3", 106, 276, "#%02x%02x%02x" % get_wrapped_color(0)],
                    [TILE_CHAR, "2", 527, 600, "#%02x%02x%02x" % get_wrapped_color(1)],
                ],
                15,
            ],
            "thumb_kitti_1.jpg": ["kitti_1.png", "kitti_1.npy", [], 25],
            "thumb_kitti_2.jpg": ["kitti_2.png", "kitti_2.npy", [], 25],
            "thumb_teaser_10.jpg": ["teaser.png", "teaser_10.npy", [], 25],
            "thumb_teaser_100.jpg": ["teaser.png", "teaser_100.npy", [], 25],
            "thumb_teaser_1000.jpg": ["teaser.png", "teaser_1000.npy", [], 25],
        }[l_thumb]

        u_image = real_url(l_image)
        l_down_image = os.path.join(d_thumb, l_image)
        response = requests.get(u_image)
        response.raise_for_status()
        with open(l_down_image, "wb") as f:
            f.write(response.content)

        table_visible = len(clicks) > 0

        l_down_sparse = None
        if l_sparse is not None:
            u_sparse = real_url(l_sparse)
            l_down_sparse = os.path.join(d_thumb, l_sparse)
            response = requests.get(u_sparse)
            response.raise_for_status()
            with open(l_down_sparse, "wb") as f:
                f.write(response.content)

        state_orig_img = None
        table = []
        for click in clicks:
            _, value, x, y, _ = click
            l_down_image, state_orig_img, table = process_click_data(
                l_down_image, state_orig_img, table, x, y, value
            )

        for outputs in process(l_down_image, state_orig_img, clicks, l_down_sparse, denoise_steps=nsteps):
            yield (
                l_down_image,
                l_down_sparse,
                state_orig_img,
                gr.Dataframe(table, visible=table_visible),
                *outputs,
            )

    examples = gr.Examples(
        fn=examples_depth_lidar_fn,
        examples=[
            "files/thumb_matterhorn_clicks.jpg",
            "files/thumb_kitti_1.jpg",
            "files/thumb_kitti_2.jpg",
            "files/thumb_teaser_10.jpg",
            "files/thumb_teaser_100.jpg",
            "files/thumb_teaser_1000.jpg",
        ],
        inputs=[thumb],
        outputs=[
            input_image,
            input_sparse,
            state_orig_img,
            table,
            output_slider,
            plot,
        ],
        cache_mode="lazy",
        cache_examples=False,
        run_on_click=True,
    )

    def clear_fn():
        return [
            gr.update(value=None, interactive=True, label="Input image"),
            gr.File(None, interactive=True),
            None,
            None,
            gr.Dataframe([[]], visible=False),
            None,
            gr.update(interactive=True),
        ]

    clear_btn.click(
        fn=clear_fn,
        inputs=[],
        outputs=[
            input_image,
            input_sparse,
            output_slider,
            plot,
            table,
            state_orig_img,
            submit_btn,
        ],
    )

demo.queue(
    api_open=False,
).launch(
    server_name="0.0.0.0",
    server_port=7860,
    ssr_mode=False,
)
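# Local run sketch (assumes this file is saved as app.py, which is hypothetical):
#   ACCELERATOR=gpu python app.py   # run the full pipeline on CUDA
#   python app.py                   # DRY_RUN mode when no accelerator is set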