import os
import copy
import datetime
import glob
import shutil
import tempfile
import time
from os.path import join

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
import trimesh
from PIL import Image
from loguru import logger
from tqdm import tqdm

from tools.common_utils import save_video
from tools.depth2pcd import depth2pcd
from tools.eval_utils import transfer_pred_disp2depth, colorize_depth_map
from dkt.pipelines.pipeline import DKTPipeline, ModelConfig
# from moge.model.v2 import MoGeModel

#* better for bg: logs/outs/train/remote/sft-T2SQNet_glassverse_cleargrasp_HISS_DREDS_DREDS_glassverse_interiorverse-4gpus-origin-lora128-1.3B-rgb_depth-w832-h480-Wan2.1-Fun-Control-2025-10-28-23:26:41/epoch-0-20000.safetensors
PROMPT = 'depth'
NEGATIVE_PROMPT = ''
height = 480
width = 832
window_size = 21

DKT_PIPELINE = DKTPipeline()

# Each entry: [video path, model size, num_inference_steps, overlap].
example_inputs = [
    ["examples/1.mp4", "1.3B", 5, 3],
    ["examples/33.mp4", "1.3B", 5, 3],
    ["examples/7.mp4", "1.3B", 5, 3],
    ["examples/8.mp4", "1.3B", 5, 3],
    ["examples/9.mp4", "1.3B", 5, 3],
    # ["examples/178db6e89ab682bfc612a3290fec58dd.mp4", "1.3B", 5, 3],
    ["examples/36.mp4", "1.3B", 5, 3],
    ["examples/39.mp4", "1.3B", 5, 3],
    # ["examples/b1f1fa44f414d7731cd7d77751093c44.mp4", "1.3B", 5, 3],
    ["examples/10.mp4", "1.3B", 5, 3],
    ["examples/30.mp4", "1.3B", 5, 3],
    ["examples/3.mp4", "1.3B", 5, 3],
    ["examples/32.mp4", "1.3B", 5, 3],
    ["examples/35.mp4", "1.3B", 5, 3],
    ["examples/40.mp4", "1.3B", 5, 3],
    ["examples/2.mp4", "1.3B", 5, 3],
    # ["examples/31.mp4", "1.3B", 5, 3],
    # ["examples/DJI_20250912164311_0007_D.mp4", "1.3B", 5, 3],
    # ["examples/DJI_20250912163642_0003_D.mp4", "1.3B", 5, 3],
    # ["examples/5.mp4", "1.3B", 5, 3],
    # ["examples/1b0daeb776471c7389b36cee53049417.mp4", "1.3B", 5, 3],
    # ["examples/8a6dfb8cfe80634f4f77ae9aa830d075.mp4", "1.3B", 5, 3],
    # ["examples/69230f105ad8740e08d743a8ee11c651.mp4", "1.3B", 5, 3],
    # ["examples/b68045aa2128ab63d9c7518f8d62eafe.mp4", "1.3B", 5, 3],
]


def pmap_to_glb(point_map, valid_mask, frame) -> trimesh.Scene:
    # Negate x and y so the exported cloud appears upright in GLB viewers.
    pts_3d = point_map[valid_mask] * np.array([-1, -1, 1])
    pts_rgb = frame[valid_mask]

    # Initialize a 3D scene and add the colored point cloud to it.
    scene_3d = trimesh.Scene()
    point_cloud_data = trimesh.PointCloud(vertices=pts_3d, colors=pts_rgb)
    scene_3d.add_geometry(point_cloud_data)
    return scene_3d


def create_simple_glb_from_pointcloud(points, colors, glb_filename):
    try:
        if len(points) == 0:
            logger.warning(f"No valid points to create GLB for {glb_filename}")
            return False

        if colors is not None:
            pts_rgb = colors
        else:
            logger.info("No colors provided, adding default white colors")
            pts_rgb = np.ones((len(points), 3))

        valid_mask = np.ones(len(points), dtype=bool)
        scene_3d = pmap_to_glb(points, valid_mask, pts_rgb)
        scene_3d.export(glb_filename)
        return True
    except Exception as e:
        logger.error(f"Error creating GLB from pointcloud using trimesh: {str(e)}")
        return False


def process_video(video_file, model_size, num_inference_steps, overlap):
    global height
    global width
    global window_size
    global DKT_PIPELINE

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    cur_save_dir = tempfile.mkdtemp(prefix=f'dkt_{timestamp}_{model_size}_')

    start_time = time.time()
    prediction_result = DKT_PIPELINE(
        video_file,
        prompt=PROMPT,
        negative_prompt=NEGATIVE_PROMPT,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
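        # `overlap` (assumed semantics, based on the pipeline's sliding-window
        # inference over clips of `window_size` frames): the number of frames
        # shared between consecutive windows, trading speed for temporal
        # consistency at window seams.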
        overlap=overlap,
        return_rgb=True
    )
    end_time = time.time()
    spend_time = end_time - start_time
    logger.info(f"DKT_PIPELINE spend time: {spend_time:.2f} seconds for depth prediction")

    # Pick vis_pc_num key frames, evenly spaced over the clip, for the
    # point-cloud previews.
    frame_length = len(prediction_result['rgb_frames'])
    vis_pc_num = 4
    indices = np.linspace(0, frame_length - 1, vis_pc_num)
    indices = np.round(indices).astype(np.int32)

    pc_start_time = time.time()
    pcds = DKT_PIPELINE.prediction2pc_v2(
        prediction_result['depth_map'],
        prediction_result['rgb_frames'],
        indices,
        return_pcd=True
    )
    pc_spend_time = time.time() - pc_start_time
    logger.info(f"prediction2pc_v2 spend time: {pc_spend_time:.2f} seconds for point cloud extraction")

    glb_files = []
    for idx, pcd in enumerate(pcds):
        points = np.asarray(pcd.points)
        colors = np.asarray(pcd.colors) if pcd.has_colors() else None
        logger.info(f'points: {points.shape}, colors: {colors.shape if colors is not None else None}')
        # Negate x and z to match the viewer's coordinate convention.
        points[:, 2] = -points[:, 2]
        points[:, 0] = -points[:, 0]
        glb_filename = os.path.join(cur_save_dir, f'{timestamp}_{idx:02d}.glb')
        success = create_simple_glb_from_pointcloud(points, colors, glb_filename)
        if not success:
            logger.warning(f"Failed to save GLB file: {glb_filename}")
        glb_files.append(glb_filename)

    #* save the colored depth predictions at the input video's frame rate
    output_filename = f"output_{timestamp}.mp4"
    output_path = os.path.join(cur_save_dir, output_filename)
    cap = cv2.VideoCapture(video_file)
    input_fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    save_video(prediction_result['colored_depth_map'], output_path, fps=input_fps, quality=8)
    return output_path, glb_files
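

#* Minimal sketch (not part of the demo's UI path): run the pipeline once on a
#* bundled example clip and log where the outputs land. `_smoke_test` is a
#* hypothetical helper added for local debugging; it only reuses
#* `process_video` above with the same defaults as the examples.
def _smoke_test(video_path="examples/1.mp4"):
    out, glbs = process_video(video_path, "1.3B", num_inference_steps=5, overlap=3)
    logger.info(f"smoke test -> depth video: {out}, glb files: {glbs}")
    return out, glbs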

#* gradio creation and initialization
css = """
#download {
    height: 118px;
}
.slider .inner {
    width: 5px;
    background: #FFF;
}
.viewport {
    aspect-ratio: 4/3;
}
.tabs button.selected {
    font-size: 20px !important;
    color: crimson !important;
}
h1, h2, h3 {
    text-align: center;
    display: block;
}
.md_feedback li {
    margin-bottom: 0px !important;
}
"""

head_html = """ """

with gr.Blocks(css=css, title="DKT", head=head_html) as demo:
    # gr.Markdown(title, elem_classes=["title"])
    gr.Markdown(
        """
# Diffusion Knows Transparency: Repurposing Video Diffusion for Transparent Object Depth and Normal Estimation
"""
    )
    # gr.Markdown(description, elem_classes=["description"])
    # gr.Markdown("### Video Processing Demo", elem_classes=["description"])

    with gr.Row():
        with gr.Column():
            input_video = gr.Video(label="Input Video", elem_id='video-display-input')
            model_size = gr.Radio(
                # choices=["1.3B", "14B"],
                choices=["1.3B"],
                value="1.3B",
                label="Model Size"
            )
            with gr.Accordion("Advanced Parameters", open=False):
                num_inference_steps = gr.Slider(
                    minimum=1,
                    maximum=50,
                    value=5,
                    step=1,
                    label="Number of Inference Steps"
                )
                overlap = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=3,
                    step=1,
                    label="Overlap"
                )
            submit = gr.Button(value="Compute Depth", variant="primary")
        with gr.Column():
            output_video = gr.Video(
                label="Depth Outputs",
                elem_id='video-display-output',
                autoplay=True
            )
            vis_video = gr.Video(
                label="Visualization Video",
                visible=False,
                autoplay=True
            )

    with gr.Row():
        gr.Markdown("### 3D Point Cloud Visualization", elem_classes=["title"])

    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            output_point_map0 = gr.Model3D(
                label="Point Cloud Key Frame 1",
                clear_color=[1.0, 1.0, 1.0, 1.0],
                interactive=False
            )
        with gr.Column(scale=1):
            output_point_map1 = gr.Model3D(
                label="Point Cloud Key Frame 2",
                clear_color=[1.0, 1.0, 1.0, 1.0],
                interactive=False
            )
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            output_point_map2 = gr.Model3D(
                label="Point Cloud Key Frame 3",
                clear_color=[1.0, 1.0, 1.0, 1.0],
                interactive=False
            )
        with gr.Column(scale=1):
            output_point_map3 = gr.Model3D(
                label="Point Cloud Key Frame 4",
                clear_color=[1.0, 1.0, 1.0, 1.0],
                interactive=False
            )

    def on_submit(video_file, model_size, num_inference_steps, overlap):
        logger.info('on_submit was called')
        # The click handler below is bound to six outputs, so every return
        # path must yield exactly six values.
        if video_file is None:
            raise gr.Error("Please upload a video file")
        try:
            start_time = time.time()
            output_path, glb_files = process_video(
                video_file, model_size, num_inference_steps, overlap
            )
            spend_time = time.time() - start_time
            logger.info(f"Total spend time in on_submit: {spend_time:.2f} seconds")

            if output_path is None:
                return None, None, None, None, None, None

            # Route up to four GLB files into the four Model3D viewers.
            model3d_outputs = [None] * 4
            if glb_files:
                for i, glb_file in enumerate(glb_files[:4]):
                    if os.path.exists(glb_file):
                        model3d_outputs[i] = glb_file
            return output_path, None, *model3d_outputs
        except Exception as e:
            logger.error(e)
            return None, None, None, None, None, None

    submit.click(
        on_submit,
        inputs=[input_video, model_size, num_inference_steps, overlap],
        outputs=[
            output_video,
            vis_video,
            output_point_map0,
            output_point_map1,
            output_point_map2,
            output_point_map3
        ]
    )

    logger.info(f'there are {len(example_inputs)} demo files')
    examples = gr.Examples(
        examples=example_inputs,
        inputs=[input_video, model_size, num_inference_steps, overlap],
        outputs=[
            output_video,
            vis_video,
            output_point_map0,
            output_point_map1,
            output_point_map2,
            output_point_map3
        ],
        fn=on_submit,
        examples_per_page=12,
        cache_examples=False
    )

if __name__ == '__main__':
    #* entry point: launch the demo
    demo.queue().launch(share=True)
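
# Note: `share=True` requests a public tunnel link from Gradio. For a
# self-hosted deployment, an explicit bind is a common alternative (a sketch,
# not what this demo is configured to use):
#   demo.queue().launch(server_name="0.0.0.0", server_port=7860)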