import gradio as gr
import torch
import io
from PIL import Image
import numpy as np
import spaces  # Import spaces for ZeroGPU compatibility
import math
import re
from einops import rearrange
from mmengine.config import Config
from src.builder import BUILDER
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from scripts.camera.cam_dataset import Cam_Generator
from scripts.camera.visualization.visualize_batch import make_perspective_figures
from huggingface_hub import snapshot_download
import os

# Regex for extracting the three predicted camera parameters (roll, pitch, fov) from model text.
NUM = r"[+-]?(?:\d+(?:\.\d+)?|\.\d+)(?:[eE][+-]?\d+)?"
CAM_PATTERN = re.compile(
    r"(?:camera parameters.*?:|roll.*?:)\s*(" + NUM + r")\s*,\s*(" + NUM + r")\s*,\s*(" + NUM + r")",
    re.IGNORECASE | re.DOTALL,
)


def center_crop(image):
    """Crop the largest centered square from a PIL image."""
    w, h = image.size
    s = min(w, h)
    l = (w - s) // 2
    t = (h - s) // 2
    return image.crop((l, t, l + s, t + s))


##### load model
config = "configs/pipelines/stage_2_base.py"
config = Config.fromfile(config)
model = BUILDER.build(config.model).eval()

# Download the base checkpoint, load it, then delete the file to save disk space.
_ = snapshot_download(
    repo_id="KangLiao/Puffin",
    repo_type="model",
    allow_patterns="Puffin-Base.pth",
    local_dir="checkpoints/",
    local_dir_use_symlinks=False,
    revision="main",
)
_ = model.load_state_dict(torch.load("checkpoints/Puffin-Base.pth", map_location='cpu'), strict=False)
os.remove("checkpoints/Puffin-Base.pth")

# Download and load the VAE weights.
_ = snapshot_download(
    repo_id="wusize/Puffin",
    repo_type="model",
    local_dir="checkpoints/",
    local_dir_use_symlinks=False,
    revision="main",
)
_ = model.vae.load_state_dict(torch.load('checkpoints/vae.pth', map_location='cpu'), strict=True)
os.remove('checkpoints/vae.pth')

if torch.cuda.is_available():
    model = model.to(torch.bfloat16).cuda()
else:
    model = model.to(torch.float32)


def fig_to_image(fig):
    """Render a matplotlib figure to a PIL RGB image."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    img = Image.open(buf).convert('RGB')
    buf.close()
    return img


def extract_up_lat_figs(fig_dict):
    """Split a figure dict into the first up-field figure, the first latitude-field figure, and the rest."""
    fig_up, fig_lat = None, None
    others = {}
    for k, fig in fig_dict.items():
        if ("up_field" in k) and (fig_up is None):
            fig_up = fig
        elif ("latitude_field" in k) and (fig_lat is None):
            fig_lat = fig
        else:
            others[k] = fig
    return fig_up, fig_lat, others


# Camera understanding function
@torch.inference_mode()
@spaces.GPU(duration=120)
def camera_understanding(image_src, seed, progress=gr.Progress(track_tqdm=True)):
    # Clear CUDA cache before generating
    torch.cuda.empty_cache()
    # set seed
    # torch.manual_seed(seed)
    # np.random.seed(seed)
    # torch.cuda.manual_seed(seed)
    print(torch.cuda.is_available())

    prompt = ("Describe the image in detail. "
              "Then reason its spatial distribution and estimate its camera parameters "
              "(roll, pitch, and field-of-view).")

    # Preprocess: center-crop, resize to 512x512, and normalize to [-1, 1] in CHW layout.
    image = Image.fromarray(image_src).convert('RGB')
    image = center_crop(image)
    image = image.resize((512, 512))
    x = torch.from_numpy(np.array(image)).float()
    x = x / 255.0
    x = 2 * x - 1
    x = rearrange(x, 'h w c -> c h w')

    with torch.no_grad():
        outputs = model.understand(prompt=[prompt], pixel_values=[x], progress_bar=False)
    text = outputs[0]

    # Build the camera (up/latitude) maps from the predicted text and render visualizations.
    gen = Cam_Generator(mode="base")
    cam = gen.get_cam(text)

    rgb = np.array(image).astype(np.float32) / 255.0
    image_tensor = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)

    single_batch = {}
    single_batch["image"] = image_tensor
    single_batch["up_field"] = cam[:2].unsqueeze(0)
    single_batch["latitude_field"] = cam[2:].unsqueeze(0)
    figs = make_perspective_figures(single_batch, single_batch, n_pairs=1)

    up_img = lat_img = None
    for k, fig in figs.items():
        if "up_field" in k:
            up_img = fig_to_image(fig)
        elif "latitude_field" in k:
            lat_img = fig_to_image(fig)
        plt.close(fig)

    return text  # , up_img, lat_img


@torch.inference_mode()
@spaces.GPU(duration=120)  # Specify a duration to avoid timeout
def generate_image(prompt_scene, seed=42, roll=0.1, pitch=0.1, fov=1.0, progress=gr.Progress(track_tqdm=True)):
    # Clear CUDA cache and avoid tracking gradients
    torch.cuda.empty_cache()
    # Set the seed for reproducible results
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    print(torch.cuda.is_available())
    generator = torch.Generator().manual_seed(seed)

    prompt_camera = (
        "The camera parameters (roll, pitch, and field-of-view) are: "
        f"{roll:.4f}, {pitch:.4f}, {fov:.4f}."
    )

    # Build the conditioning camera map and normalize it by pi/2.
    gen = Cam_Generator()
    cam_map = gen.get_cam(prompt_camera).to(model.device)
    cam_map = cam_map / (math.pi / 2)

    prompt = prompt_scene + " " + prompt_camera
    print("prompt:", prompt)

    bsz = 4
    with torch.no_grad():
        images, output_reasoning = model.generate(
            prompt=[prompt] * bsz,
            cfg_prompt=[""] * bsz,
            pixel_values_init=None,
            cfg_scale=4.5,
            num_steps=50,
            cam_values=[[cam_map]] * bsz,
            progress_bar=False,
            reasoning=False,
            prompt_reasoning=[""] * bsz,
            generator=generator,
            height=512,
            width=512,
        )

    # Convert from [-1, 1] float tensors to uint8 HWC images.
    images = rearrange(images, 'b c h w -> b h w c')
    images = torch.clamp(127.5 * images + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
    ret_images = [Image.fromarray(image) for image in images]
    return ret_images


# Gradio interface
css = '''
.gradio-container {max-width: 960px !important}
'''

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Puffin")

    with gr.Tab("Camera-controllable Image Generation"):
        gr.Markdown(value="## Camera-controllable Image Generation")
        prompt_input = gr.Textbox(label="Prompt")
        with gr.Accordion("Camera Parameters", open=True):
            with gr.Row():
                roll = gr.Slider(minimum=-0.7854, maximum=0.7854, value=0.1000, step=0.1000, label="roll value")
                pitch = gr.Slider(minimum=-0.7854, maximum=0.7854, value=-0.1000, step=0.1000, label="pitch value")
                fov = gr.Slider(minimum=0.3491, maximum=1.8326, value=1.5000, step=0.1000, label="fov value")
            seed_input = gr.Number(label="Seed (Optional)", precision=0, value=42)
        generation_button = gr.Button("Generate Images")
        image_output = gr.Gallery(label="Generated Images", columns=4, rows=1)

        examples_t2i = gr.Examples(
            label="Prompt examples",
            examples=[
                "A sunny day casts light on two warmly colored buildings—yellow with green accents and deeper orange—framed by a lush green tree, with a blue sign and street lamp adding details in the foreground.",
                "A high-vantage-point view of lush, autumn-colored mountains blanketed in green and gold, set against a clear blue sky with scattered white clouds, offering a tranquil and breathtaking vista of a serene valley below.",
                "A grand, historic castle with pointed spires and elaborate stone structures stands against a clear blue sky, flanked by a circular fountain, vibrant red flowers, and neatly trimmed hedges in a beautifully landscaped garden.",
                "A serene aerial view of a coastal landscape at sunrise/sunset, featuring warm pink and orange skies transitioning to cool blues, with calm waters stretching to rugged, snow-capped mountains in the background, creating a tranquil and picturesque scene.",
                "A worn room with light-yellow walls, herringbone terracotta floors, and three large arched windows framed in pink trim and white panes, showcasing signs of age and disrepair, overlooks a residential area through glimpses of greenery and neighboring buildings.",
            ],
            inputs=prompt_input,
        )

    with gr.Tab("Camera Understanding"):
        gr.Markdown(value="## Camera Understanding")
        image_input = gr.Image()
        understanding_button = gr.Button("Chat")
        understanding_output = gr.Textbox(label="Response")
        # camera1 = gr.Gallery(label="Camera Maps", columns=1, rows=1)
        # camera2 = gr.Gallery(label="Camera Maps", columns=1, rows=1)
        with gr.Accordion("Advanced options", open=False):
            und_seed_input = gr.Number(label="Seed", precision=0, value=42)

        examples_understanding = gr.Examples(
            label="Camera Understanding examples",
            examples=[
                "assets/1.jpg",
                "assets/2.jpg",
                "assets/3.jpg",
                "assets/4.jpg",
                "assets/5.jpg",
                "assets/6.jpg",
            ],
            inputs=image_input,
        )

    generation_button.click(
        fn=generate_image,
        inputs=[prompt_input, seed_input, roll, pitch, fov],
        outputs=image_output,
    )

    understanding_button.click(
        camera_understanding,
        inputs=[image_input, und_seed_input],
        outputs=[understanding_output],  # , camera1, camera2]
    )

demo.launch(share=True)
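
# Note: CAM_PATTERN above is compiled but never used in this demo. The commented-out helper
# below is a minimal, illustrative sketch (not wired into the UI) of how it could pull the
# numeric (roll, pitch, fov) values back out of the understanding response, assuming the model
# phrases them as "... camera parameters ...: <roll>, <pitch>, <fov>".
#
#   def parse_camera_params(text):
#       """Return (roll, pitch, fov) as floats if found in `text`, else None."""
#       match = CAM_PATTERN.search(text)
#       if match is None:
#           return None
#       roll_val, pitch_val, fov_val = (float(v) for v in match.groups())
#       return roll_val, pitch_val, fov_val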