import os
import datetime
import tempfile

import cv2
import numpy as np
import gradio as gr
import trimesh
import spaces
from loguru import logger

from tools.common_utils import save_video
from dkt.pipelines.pipeline import DKTPipeline

#* better for bg: logs/outs/train/remote/sft-T2SQNet_glassverse_cleargrasp_HISS_DREDS_DREDS_glassverse_interiorverse-4gpus-origin-lora128-1.3B-rgb_depth-w832-h480-Wan2.1-Fun-Control-2025-10-28-23:26:41/epoch-0-20000.safetensors
PROMPT = 'depth'
NEGATIVE_PROMPT = ''
height = 480
width = 832
window_size = 21
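# NOTE: inference runs in sliding windows of `window_size` frames; consecutive
# windows share `overlap` frames (the "Overlap" slider below), which keeps the
# depth predictions consistent across window boundaries. This reading is
# inferred from how `overlap` is forwarded to DKT_PIPELINE in process_video.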
DKT_PIPELINE = DKTPipeline()

example_inputs = [
    ["examples/1.mp4", "1.3B", 5, 3],
    ["examples/33.mp4", "1.3B", 5, 3],
    ["examples/7.mp4", "1.3B", 5, 3],
    ["examples/8.mp4", "1.3B", 5, 3],
    ["examples/9.mp4", "1.3B", 5, 3],
    # ["examples/178db6e89ab682bfc612a3290fec58dd.mp4", "1.3B", 5, 3],
    ["examples/36.mp4", "1.3B", 5, 3],
    ["examples/39.mp4", "1.3B", 5, 3],
    # ["examples/b1f1fa44f414d7731cd7d77751093c44.mp4", "1.3B", 5, 3],
    ["examples/10.mp4", "1.3B", 5, 3],
    ["examples/30.mp4", "1.3B", 5, 3],
    ["examples/3.mp4", "1.3B", 5, 3],
    ["examples/32.mp4", "1.3B", 5, 3],
    ["examples/35.mp4", "1.3B", 5, 3],
    ["examples/40.mp4", "1.3B", 5, 3],
    ["examples/2.mp4", "1.3B", 5, 3],
    # ["examples/31.mp4", "1.3B", 5, 3],
    # ["examples/DJI_20250912164311_0007_D.mp4", "1.3B", 5, 3],
    # ["examples/DJI_20250912163642_0003_D.mp4", "1.3B", 5, 3],
    # ["examples/5.mp4", "1.3B", 5, 3],
    # ["examples/1b0daeb776471c7389b36cee53049417.mp4", "1.3B", 5, 3],
    # ["examples/8a6dfb8cfe80634f4f77ae9aa830d075.mp4", "1.3B", 5, 3],
    # ["examples/69230f105ad8740e08d743a8ee11c651.mp4", "1.3B", 5, 3],
    # ["examples/b68045aa2128ab63d9c7518f8d62eafe.mp4", "1.3B", 5, 3],
]

def pmap_to_glb(point_map, valid_mask, frame) -> trimesh.Scene:
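    """Pack masked 3D points plus per-point colors into a trimesh.Scene.

    `point_map` is an (N, 3) array, `valid_mask` an (N,) bool mask, and
    `frame` the color source aligned with `point_map`. The x/y axes are
    negated, presumably to match the gr.Model3D viewer's orientation.
    """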
    pts_3d = point_map[valid_mask] * np.array([-1, -1, 1])
    pts_rgb = frame[valid_mask]
    # Initialize a 3D scene
    scene_3d = trimesh.Scene()
    # Add point cloud data to the scene
    point_cloud_data = trimesh.PointCloud(vertices=pts_3d, colors=pts_rgb)
    scene_3d.add_geometry(point_cloud_data)
    return scene_3d

def create_simple_glb_from_pointcloud(points, colors, glb_filename):
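    """Export an (N, 3) point array with optional (N, 3) colors to a GLB file.

    Falls back to white points when no colors are given. Returns True on
    success and False on failure; errors are logged rather than raised.
    """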
    try:
        if len(points) == 0:
            logger.warning(f"No valid points to create GLB for {glb_filename}")
            return False
        if colors is not None:
            pts_rgb = colors
        else:
            logger.info("No colors provided, adding default white colors")
            pts_rgb = np.ones((len(points), 3))
        valid_mask = np.ones(len(points), dtype=bool)
        scene_3d = pmap_to_glb(points, valid_mask, pts_rgb)
        scene_3d.export(glb_filename)
        return True
    except Exception as e:
        logger.error(f"Error creating GLB from pointcloud using trimesh: {str(e)}")
        return False

# ZeroGPU: the otherwise-unused `spaces` import and the Space's Zero hardware
# suggest the GPU work belongs behind this decorator (assumption).
@spaces.GPU
def process_video(video_file, model_size, num_inference_steps, overlap):
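    """Run DKT depth inference on `video_file` and export visualization artifacts.

    Writes a colorized depth video and up to `vis_pc_num` key-frame point
    clouds (GLB) into a fresh temp directory, and returns
    (depth_video_path, list_of_glb_paths).
    """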
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    cur_save_dir = tempfile.mkdtemp(prefix=f'dkt_{timestamp}_{model_size}_')
    prediction_result = DKT_PIPELINE(
        video_file, prompt=PROMPT, negative_prompt=NEGATIVE_PROMPT,
        height=height, width=width, num_inference_steps=num_inference_steps,
        overlap=overlap, return_rgb=True
    )
    # Pick evenly spaced key frames for point-cloud visualization
    frame_length = len(prediction_result['rgb_frames'])
    vis_pc_num = 4
    indices = np.linspace(0, frame_length - 1, vis_pc_num)
    indices = np.round(indices).astype(np.int32)
    pcds = DKT_PIPELINE.prediction2pc_v2(prediction_result['depth_map'], prediction_result['rgb_frames'], indices, return_pcd=True)
    glb_files = []
    for idx, pcd in enumerate(pcds):
        points = np.asarray(pcd.points)
        colors = np.asarray(pcd.colors) if pcd.has_colors() else None
        # Flip x/z so the cloud matches the viewer's coordinate convention
        points[:, 2] = -points[:, 2]
        points[:, 0] = -points[:, 0]
        glb_filename = os.path.join(cur_save_dir, f'{timestamp}_{idx:02d}.glb')
        success = create_simple_glb_from_pointcloud(points, colors, glb_filename)
        if not success:
            logger.warning(f"Failed to save GLB file: {glb_filename}")
        glb_files.append(glb_filename)
    #* save depth predictions video
    output_filename = f"output_{timestamp}.mp4"
    output_path = os.path.join(cur_save_dir, output_filename)
    cap = cv2.VideoCapture(video_file)
    input_fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    if not input_fps or input_fps <= 0:
        input_fps = 24  # cv2 can report 0 fps for some containers; fall back
    save_video(prediction_result['colored_depth_map'], output_path, fps=input_fps, quality=8)
    return output_path, glb_files
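
# Minimal standalone usage sketch of the pipeline API exercised above (the
# call signature mirrors process_video; the paths and fps are illustrative):
#
#   pipeline = DKTPipeline()
#   result = pipeline("examples/1.mp4", prompt=PROMPT, negative_prompt=NEGATIVE_PROMPT,
#                     height=height, width=width, num_inference_steps=5,
#                     overlap=3, return_rgb=True)
#   save_video(result['colored_depth_map'], 'depth.mp4', fps=24, quality=8)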

#* gradio creation and initialization
css = """
#video-display-container {
    max-height: 100vh;
}
#video-display-input {
    max-height: 80vh;
}
#video-display-output {
    max-height: 80vh;
}
#download {
    height: 62px;
}
.title {
    text-align: center;
}
.description {
    text-align: center;
}
.gradio-examples {
    max-height: 400px;
    overflow-y: auto;
}
.gradio-examples .examples-container {
    display: grid;
    grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
    gap: 10px;
    padding: 10px;
}
.gradio-container .gradio-examples .pagination,
.gradio-container .gradio-examples .pagination button,
div[data-testid="examples"] .pagination,
div[data-testid="examples"] .pagination button {
    font-size: 28px !important;
    font-weight: bold !important;
    padding: 15px 20px !important;
    min-width: 60px !important;
    height: 60px !important;
    border-radius: 10px !important;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    color: white !important;
    border: none !important;
    cursor: pointer !important;
    margin: 8px !important;
    display: inline-block !important;
    box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
    transition: all 0.3s ease !important;
}
div[data-testid="examples"] .pagination button:not(.active),
.gradio-container .gradio-examples .pagination button:not(.active) {
    font-size: 32px !important;
    font-weight: bold !important;
    padding: 15px 20px !important;
    min-width: 60px !important;
    height: 60px !important;
    background: linear-gradient(135deg, #8a9cf0 0%, #9a6bb2 100%) !important;
    opacity: 0.8 !important;
}
div[data-testid="examples"] .pagination button:hover,
.gradio-container .gradio-examples .pagination button:hover {
    background: linear-gradient(135deg, #5a6fd8 0%, #6a4190 100%) !important;
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 12px rgba(0,0,0,0.3) !important;
    opacity: 1 !important;
}
div[data-testid="examples"] .pagination button.active,
.gradio-container .gradio-examples .pagination button.active {
    background: linear-gradient(135deg, #11998e 0%, #38ef7d 100%) !important;
    box-shadow: 0 4px 8px rgba(17,153,142,0.4) !important;
    opacity: 1 !important;
}
button[class*="pagination"],
button[class*="page"] {
    font-size: 28px !important;
    font-weight: bold !important;
    padding: 15px 20px !important;
    min-width: 60px !important;
    height: 60px !important;
    border-radius: 10px !important;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    color: white !important;
    border: none !important;
    cursor: pointer !important;
    margin: 8px !important;
    box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
    transition: all 0.3s ease !important;
}
"""
| head_html = """ | |
| <link rel="icon" type="image/svg+xml" href="data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'%3E%3Ctext y='.9em' font-size='90'%3E🦾%3C/text%3E%3C/svg%3E"> | |
| <link rel="shortcut icon" type="image/svg+xml" href="data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'%3E%3Ctext y='.9em' font-size='90'%3E🦾%3C/text%3E%3C/svg%3E"> | |
| <link rel="icon" type="image/png" href="data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'%3E%3Ctext y='.9em' font-size='90'%3E🦾%3C/text%3E%3C/svg%3E"> | |
| <meta name="viewport" content="width=device-width, initial-scale=1.0"> | |
| """ | |

# description = """Official demo for **DKT**."""
# with gr.Blocks(css=css, title="DKT - Diffusion Knows Transparency", favicon_path="favicon.ico") as demo:
with gr.Blocks(css=css, title="DKT", head=head_html) as demo:
    # gr.Markdown(title, elem_classes=["title"])
| """ | |
| <a title="Website" href="https://stable-x.github.io/StableNormal/" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> | |
| <img src="https://www.obukhov.ai/img/badges/badge-website.svg"> | |
| </a> | |
| <a title="arXiv" href="https://arxiv.org/abs/2406.16864" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> | |
| <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg"> | |
| </a> | |
| <a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> | |
| <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social"> | |
| </a> | |
| """ | |
    gr.Markdown(
        """
# Diffusion Knows Transparency: Repurposing Video Diffusion for Transparent Object Depth and Normal Estimation
<p align="center">
    <a title="Github" href="https://github.com/Daniellli/DKT" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://img.shields.io/github/stars/Daniellli/DKT?style=social" alt="badge-github-stars">
    </a>
</p>
"""
    )
    # gr.Markdown(description, elem_classes=["description"])
    # gr.Markdown("### Video Processing Demo", elem_classes=["description"])
    with gr.Row():
        with gr.Column():
            input_video = gr.Video(label="Input Video", elem_id='video-display-input')
            model_size = gr.Radio(
                # choices=["1.3B", "14B"],
                choices=["1.3B"],
                value="1.3B",
                label="Model Size"
            )
            with gr.Accordion("Advanced Parameters", open=False):
                num_inference_steps = gr.Slider(
                    minimum=1, maximum=50, value=5, step=1,
                    label="Number of Inference Steps"
                )
                overlap = gr.Slider(
                    minimum=1, maximum=20, value=3, step=1,
                    label="Overlap"
                )
            submit = gr.Button(value="Compute Depth", variant="primary")
        with gr.Column():
            output_video = gr.Video(
                label="Depth Outputs",
                elem_id='video-display-output',
                autoplay=True
            )
            vis_video = gr.Video(
                label="Visualization Video",
                visible=False,
                autoplay=True
            )
    with gr.Row():
        gr.Markdown("### 3D Point Cloud Visualization", elem_classes=["title"])
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            output_point_map0 = gr.Model3D(
                label="Point Cloud Key Frame 1",
                clear_color=[1.0, 1.0, 1.0, 1.0],
                interactive=False,
            )
        with gr.Column(scale=1):
            output_point_map1 = gr.Model3D(
                label="Point Cloud Key Frame 2",
                clear_color=[1.0, 1.0, 1.0, 1.0],
                interactive=False
            )
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            output_point_map2 = gr.Model3D(
                label="Point Cloud Key Frame 3",
                clear_color=[1.0, 1.0, 1.0, 1.0],
                interactive=False
            )
        with gr.Column(scale=1):
            output_point_map3 = gr.Model3D(
                label="Point Cloud Key Frame 4",
                clear_color=[1.0, 1.0, 1.0, 1.0],
                interactive=False
            )

    def on_submit(video_file, model_size, num_inference_steps, overlap):
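        """Gradio click handler: run process_video and fan the results out to
        the six outputs (depth video, hidden vis video, four Model3D viewers)."""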
        logger.info('on_submit called')
        if video_file is None:
            gr.Warning("Please upload a video file")
            return None, None, None, None, None, None
        try:
            output_path, glb_files = process_video(
                video_file, model_size, num_inference_steps, overlap
            )
            if output_path is None:
                return None, None, None, None, None, None
            model3d_outputs = [None] * 4
            if glb_files:
                for i, glb_file in enumerate(glb_files[:4]):
                    if os.path.exists(glb_file):
                        model3d_outputs[i] = glb_file
            return output_path, None, *model3d_outputs
        except Exception as e:
            logger.exception(e)
            return None, None, None, None, None, None

    submit.click(
        on_submit,
        inputs=[input_video, model_size, num_inference_steps, overlap],
        outputs=[
            output_video, vis_video,
            output_point_map0, output_point_map1, output_point_map2, output_point_map3
        ]
    )
    logger.info(f'there are {len(example_inputs)} demo files')
    examples = gr.Examples(
        examples=example_inputs,
        inputs=[input_video, model_size, num_inference_steps, overlap],
        outputs=[
            output_video, vis_video,
            output_point_map0, output_point_map1, output_point_map2, output_point_map3
        ],
        fn=on_submit,
        examples_per_page=12,
        cache_examples=False
    )

if __name__ == '__main__':
    #* launch the demo (queue enables concurrent request handling)
    demo.queue().launch(share=True)