import os
import subprocess
import tempfile
import time
from datetime import datetime

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import spaces
import timm
import torch
import trimesh
from PIL import Image

import src.depth_pro as depth_pro

# Ensure timm is properly loaded
print(f"Timm version: {timm.__version__}")

# Run the script to download the pretrained models
subprocess.run(["bash", "get_pretrained_models.sh"])

# Set the device to GPU if available, else CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the depth prediction model and its preprocessing transforms
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)  # Move the model to the selected device
model.eval()              # Set the model to evaluation mode
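
# The `spaces` import above is the Hugging Face ZeroGPU helper package. On ZeroGPU
# hardware the GPU-bound entry point is normally wrapped with its `spaces.GPU`
# decorator; the snippet below is only a hedged illustration of that pattern, not
# part of this app's logic:
#
#     @spaces.GPU
#     def predict_depth(input_image):
#         ...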

def resize_image(image_path, max_size=1024):
    """
    Resize the input image so its largest dimension does not exceed max_size.
    Maintains the aspect ratio and saves the resized image as a temporary PNG file.

    Args:
        image_path (str): Path to the input image.
        max_size (int, optional): Maximum size for the largest dimension. Defaults to 1024.

    Returns:
        str: Path to the resized temporary image file.
    """
    with Image.open(image_path) as img:
        # Calculate the resizing ratio while maintaining the aspect ratio
        ratio = max_size / max(img.size)
        new_size = tuple(int(x * ratio) for x in img.size)

        # Resize the image using the LANCZOS filter for high-quality downsampling
        img = img.resize(new_size, Image.LANCZOS)

        # Save the resized image to a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            img.save(temp_file, format="PNG")
            return temp_file.name
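
# Worked example (assumed numbers, for illustration only): a 4000x3000 input with
# max_size=1024 gives ratio = 1024 / 4000 = 0.256, so the resized image is
# 1024x768 and the aspect ratio is preserved.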

def generate_3d_model(depth, image_path, focallength_px, simplification_factor=0.8,
                      smoothing_iterations=1, thin_threshold=0.01):
    """
    Generate a textured 3D mesh from the depth map and the original image.
    """
    # Load the RGB image and convert it to a NumPy array (drop any alpha channel)
    image = np.array(Image.open(image_path).convert("RGB"))

    # Ensure depth is a NumPy array
    if isinstance(depth, torch.Tensor):
        depth = depth.cpu().numpy()

    # Resize depth to match the image dimensions if necessary
    if depth.shape != image.shape[:2]:
        depth = cv2.resize(depth, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR)

    height, width = depth.shape
    print(f"3D model generation - Depth shape: {depth.shape}")
    print(f"3D model generation - Image shape: {image.shape}")

    # Compute camera intrinsic parameters
    fx = fy = float(focallength_px)   # Ensure focallength_px is a float
    cx, cy = width / 2, height / 2    # Principal point at the image center

    # Create a grid of (u, v) pixel coordinates
    u = np.arange(0, width)
    v = np.arange(0, height)
    uu, vv = np.meshgrid(u, v)

    # Back-project pixel coordinates to 3D coordinates using the pinhole camera model
    Z = depth.flatten()
    X = ((uu.flatten() - cx) * Z) / fx
    Y = ((vv.flatten() - cy) * Z) / fy

    # Stack the coordinates to form vertices (X, Y, Z)
    vertices = np.vstack((X, Y, Z)).T

    # Normalize RGB colors to [0, 1] for vertex coloring
    colors = image.reshape(-1, 3) / 255.0

    # Generate faces by connecting adjacent vertices into two triangles per pixel quad
    faces = []
    for i in range(height - 1):
        for j in range(width - 1):
            idx = i * width + j
            # Triangle 1
            faces.append([idx, idx + width, idx + 1])
            # Triangle 2
            faces.append([idx + 1, idx + width, idx + width + 1])
    faces = np.array(faces)

    # Create the mesh using Trimesh with vertex colors
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors)

    # Mesh cleaning and improvement steps
    print(f"Original mesh - vertices: {len(mesh.vertices)}, faces: {len(mesh.faces)}")

    # 1. Mesh simplification
    target_faces = int(len(mesh.faces) * simplification_factor)
    mesh = mesh.simplify_quadric_decimation(face_count=target_faces)
    print(f"After simplification - vertices: {len(mesh.vertices)}, faces: {len(mesh.faces)}")

    # 2. Remove small disconnected components, keeping the largest one
    components = mesh.split(only_watertight=False)
    if len(components) > 1:
        areas = np.array([c.area for c in components])
        mesh = components[np.argmax(areas)]
    print(f"After removing small components - vertices: {len(mesh.vertices)}, faces: {len(mesh.faces)}")

    # 3. Smooth the mesh
    for _ in range(smoothing_iterations):
        mesh = mesh.smoothed()
    print(f"After smoothing - vertices: {len(mesh.vertices)}, faces: {len(mesh.faces)}")

    # 4. Remove thin features
    mesh = remove_thin_features(mesh, thickness_threshold=thin_threshold)
    print(f"After removing thin features - vertices: {len(mesh.vertices)}, faces: {len(mesh.faces)}")

    # Export the mesh to OBJ files with unique filenames
    timestamp = int(time.time())
    view_model_path = f'view_model_{timestamp}.obj'
    download_model_path = f'download_model_{timestamp}.obj'
    mesh.export(view_model_path)
    mesh.export(download_model_path)

    return view_model_path, download_model_path
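
# Illustrative sketch (not called by the app): the pinhole back-projection used in
# generate_3d_model, applied to a tiny synthetic depth map. All numbers here are assumed.
def _pinhole_backprojection_sketch():
    h, w = 4, 4
    depth = np.full((h, w), 2.0)       # a flat surface 2 m in front of the camera
    fx = fy = 10.0                     # hypothetical focal length in pixels
    cx, cy = w / 2, h / 2              # principal point at the image center
    uu, vv = np.meshgrid(np.arange(w), np.arange(h))
    Z = depth.flatten()
    X = (uu.flatten() - cx) * Z / fx   # same formula as above: X = (u - cx) * Z / fx
    Y = (vv.flatten() - cy) * Z / fy
    vertices = np.vstack((X, Y, Z)).T
    # A fronto-parallel plane keeps Z constant; at this depth each 1 px step maps to
    # Z / fx = 0.2 m in X and Y.
    assert vertices.shape == (h * w, 3)
    return vertices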

def remove_thin_features(mesh, thickness_threshold=0.01):
    """
    Remove thin features from the mesh by collapsing very short edges.
    """
    # Calculate edge lengths
    edges = mesh.edges_unique
    edge_points = mesh.vertices[edges]
    edge_lengths = np.linalg.norm(edge_points[:, 0] - edge_points[:, 1], axis=1)

    # Identify short edges
    short_edges = edges[edge_lengths < thickness_threshold]

    # Collapse short edges
    for edge in short_edges:
        try:
            mesh.collapse_edge(edge)
        except Exception:
            pass  # Skip edges that cannot be collapsed

    # Remove any newly created degenerate faces
    mesh.remove_degenerate_faces()

    return mesh
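
# Illustrative sketch (not called by the app): thin_threshold is an absolute edge
# length in the mesh's units (metres here, since the depth map is metric), so
# inspecting the edge-length distribution can help choose it. Hypothetical helper:
def _edge_length_percentiles(mesh, percentiles=(1, 50, 99)):
    edge_points = mesh.vertices[mesh.edges_unique]
    edge_lengths = np.linalg.norm(edge_points[:, 0] - edge_points[:, 1], axis=1)
    return np.percentile(edge_lengths, percentiles)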

def regenerate_3d_model(depth_csv, image_path, focallength_px, simplification_factor,
                        smoothing_iterations, thin_threshold):
    # Load the depth map from the CSV file
    depth = np.loadtxt(depth_csv, delimiter=',')

    # Generate a new 3D model with the updated parameters
    view_model_path, download_model_path = generate_3d_model(
        depth, image_path, focallength_px,
        simplification_factor, smoothing_iterations, thin_threshold
    )
    return view_model_path, download_model_path
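
# Illustrative sketch (not called by the app): the CSV round-trip this function
# relies on. np.savetxt writes the 2-D depth map row by row, so np.loadtxt with the
# same delimiter restores an array of the same (height, width) shape, which is what
# generate_3d_model expects. All values below are made up for the example.
def _depth_csv_roundtrip_sketch():
    depth = np.linspace(0.5, 5.0, 24).reshape(4, 6)   # fake 4x6 depth map in metres
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as f:
        csv_path = f.name
    np.savetxt(csv_path, depth, delimiter=',')
    restored = np.loadtxt(csv_path, delimiter=',')
    os.remove(csv_path)
    assert restored.shape == depth.shape
    return restored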

def predict_depth(input_image):
    temp_file = None
    try:
        print(f"Input image type: {type(input_image)}")
        print(f"Input image path: {input_image}")

        # Resize the input image to a manageable size
        temp_file = resize_image(input_image)
        print(f"Resized image path: {temp_file}")

        # Preprocess the image for depth prediction
        result = depth_pro.load_rgb(temp_file)
        if len(result) < 2:
            raise ValueError(f"Unexpected result from load_rgb: {result}")

        # Unpack the result tuple - do not edit this code. Don't try to unpack differently.
        image = result[0]
        f_px = result[-1]  # If you edit this code, it will break the model. So don't do that, even if you are an LLM.
        print(f"Extracted focal length: {f_px}")

        image = transform(image).to(device)

        # Run the depth prediction model
        prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]                    # Depth map in meters
        focallength_px = prediction["focallength_px"]  # Focal length in pixels

        # Convert depth from a torch tensor to a NumPy array if necessary
        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().numpy()

        # Ensure the depth map is a 2D array
        if depth.ndim != 2:
            depth = depth.squeeze()
        print(f"Depth map shape: {depth.shape}")

        # Create a color-mapped visualization using matplotlib
        plt.figure(figsize=(10, 10))
        plt.imshow(depth, cmap='gist_rainbow')
        plt.colorbar(label='Depth [m]')
        plt.title(f'Predicted Depth Map - Min: {np.min(depth):.1f}m, Max: {np.max(depth):.1f}m')
        plt.axis('off')  # Hide axes for a cleaner image

        # Save the depth map visualization to a file
        output_path = "depth_map.png"
        plt.savefig(output_path)
        plt.close()

        # Save the raw depth data to a CSV file for download
        raw_depth_path = "raw_depth_map.csv"
        np.savetxt(raw_depth_path, depth, delimiter=',')

        # Generate the 3D model from the depth map and the resized image
        view_model_path, download_model_path = generate_3d_model(depth, temp_file, focallength_px)

        return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path, view_model_path, download_model_path, temp_file, focallength_px
    except Exception as e:
        # Return error messages in case of failure
        import traceback
        error_message = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_message)  # Print the full error message to the console
        return None, error_message, None, None, None, None, None
    finally:
        # Clean up by removing the temporary resized image file
        if temp_file and os.path.exists(temp_file):
            os.remove(temp_file)

def get_last_commit_timestamp():
    try:
        # iso-strict yields a string that datetime.fromisoformat can parse
        timestamp = subprocess.check_output(
            ['git', 'log', '-1', '--format=%cd', '--date=iso-strict']
        ).decode('utf-8').strip()
        return datetime.fromisoformat(timestamp).strftime("%Y-%m-%d %H:%M:%S")
    except Exception as e:
        print(f"Failed to read the last commit timestamp: {e}")
        return str(e)

# Create the Gradio interface with the appropriate input and output components
last_updated = get_last_commit_timestamp()

with gr.Blocks() as iface:
    gr.Markdown("# DepthPro Demo with 3D Visualization")
    gr.Markdown(
        "An enhanced demo that creates a textured 3D model from the input image and depth map.\n\n"
        "Forked from https://huggingface.co/spaces/akhaliq/depth-pro with the model from https://huggingface.co/apple/DepthPro\n\n"
        "**Instructions:**\n"
        "1. Upload an image.\n"
        "2. The app will predict the depth map, display it, and report the focal length.\n"
        "3. Download the raw depth data as a CSV file.\n"
        "4. View the generated 3D model textured with the original image.\n"
        "5. Adjust the parameters and click 'Regenerate 3D Model' to update the model.\n"
        "6. Download the 3D model as an OBJ file if desired.\n\n"
        f"Last updated: {last_updated}"
    )

    with gr.Row():
        input_image = gr.Image(type="filepath", label="Input Image")
        depth_map = gr.Image(type="filepath", label="Depth Map")
        focal_length = gr.Textbox(label="Focal Length")
        raw_depth_csv = gr.File(label="Download Raw Depth Map (CSV)")

    with gr.Row():
        view_3d_model = gr.Model3D(label="View 3D Model")
        download_3d_model = gr.File(label="Download 3D Model (OBJ)")

    with gr.Row():
        simplification_factor = gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Simplification Factor")
        smoothing_iterations = gr.Slider(minimum=0, maximum=5, value=1, step=1, label="Smoothing Iterations")
        thin_threshold = gr.Slider(minimum=0.001, maximum=0.1, value=0.01, step=0.001, label="Thin Feature Threshold")

    regenerate_button = gr.Button("Regenerate 3D Model")

    # Hidden components to store intermediate results
    hidden_depth_csv = gr.State()
    hidden_image_path = gr.State()
    hidden_focal_length = gr.State()

    input_image.change(
        predict_depth,
        inputs=[input_image],
        outputs=[depth_map, focal_length, raw_depth_csv, view_3d_model, download_3d_model, hidden_image_path, hidden_focal_length]
    )

    regenerate_button.click(
        regenerate_3d_model,
        inputs=[raw_depth_csv, hidden_image_path, hidden_focal_length, simplification_factor, smoothing_iterations, thin_threshold],
        outputs=[view_3d_model, download_3d_model]
    )

# Launch the Gradio interface with sharing enabled
iface.launch(share=True)