### `app.py`

```python
import gradio as gr

from models import remix_images_inference  # Import the inference function from models.py


def main():
    with gr.Blocks(title="Image Remix with SDXL") as demo:
        gr.HTML(
            """
            <div style="text-align: center; max-width: 700px; margin: 0 auto;">
                <h1 style="font-weight: 900; font-size: 2.5em; margin-bottom: 0.5em;">
                    Image Remix with SDXL
                </h1>
                <p style="margin-bottom: 1em; font-size: 1.1em; color: #555;">
                    Drag and drop up to three images and provide a text prompt to remix them using Stable Diffusion XL.
                    If the first image is provided, it will be used as a base for image-to-image generation.
                </p>
                <p style="font-size: 0.9em; color: #777;">
                    Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank" style="color: #4CAF50; text-decoration: none;">anycoder</a>
                </p>
            </div>
            """
        )
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group():
                    image_input_1 = gr.Image(
                        label="Image Input 1 (Base for Img2Img)",
                        type="pil",
                        height=256,
                        width=256,
                        image_mode="RGBA"
                    )
                    image_input_2 = gr.Image(
                        label="Image Input 2 (Optional)",
                        type="pil",
                        height=256,
                        width=256,
                        image_mode="RGBA"
                    )
                    image_input_3 = gr.Image(
                        label="Image Input 3 (Optional)",
                        type="pil",
                        height=256,
                        width=256,
                        image_mode="RGBA"
                    )
                prompt_input = gr.Textbox(
                    label="Remix Prompt",
                    placeholder="A vibrant abstract painting blending elements of nature and technology",
                    lines=2
                )
                remix_button = gr.Button("Remix Images", variant="primary")
            with gr.Column(scale=2):
                output_image = gr.Image(
                    label="Remixed Image",
                    type="pil",
                    interactive=False,
                    height=512,
                    width=512
                )

        # Define the interaction
        remix_button.click(
            fn=remix_images_inference,
            inputs=[image_input_1, image_input_2, image_input_3, prompt_input],
            outputs=output_image,
            queue=True,
            show_progress="full"
        )
        gr.Examples(
            examples=[
                [
                    "https://www.kasandbox.org/programming-images/avatars/spunky-sam-headphones.png",
                    None,
                    None,
                    "a robot wearing headphones, futuristic, cyberpunk art style"
                ],
                [
                    "https://gradio-docs-json.s3.us-west-2.amazonaws.com/base.png",
                    "https://gradio-docs-json.s3.us-west-2.amazonaws.com/buildings.png",
                    None,
                    "a serene landscape with ancient ruins, overgrown with lush vegetation, concept art, fantasy"
                ],
                [
                    None,
                    None,
                    None,
                    "an astronaut riding a horse on the moon, cinematic, photorealistic"
                ]
            ],
            inputs=[image_input_1, image_input_2, image_input_3, prompt_input],
            outputs=output_image,
            fn=remix_images_inference,
            cache_examples=False,  # Cache examples can be set to True if the inference is fast enough
            run_on_click=True
        )
    demo.launch(enable_monitoring=True)


if __name__ == "__main__":
    main()
```
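Before wiring the backend into the UI, it can help to sanity-check `remix_images_inference` on its own. The snippet below is a minimal smoke-test sketch, assuming `models.py` is importable and a GPU allocation is available; the file names `sample.png` and `remixed.png` are placeholders, not project files.

```python
# Minimal backend smoke test (hypothetical file names; requires a GPU allocation).
from PIL import Image

from models import remix_images_inference

base = Image.open("sample.png")  # hypothetical local input image
remixed = remix_images_inference(base, None, None, "a watercolor cityscape at dusk")
remixed.save("remixed.png")
print(remixed.size)  # expected: (1024, 1024)
```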
### `models.py`

```python
import spaces
import torch
from diffusers import DiffusionPipeline, AutoPipelineForImage2Image
from PIL import Image
import numpy as np
import os
import gradio as gr

# Set a cache directory for models if not already set, for smoother experience on Spaces
os.environ["HF_HOME"] = os.getenv("HF_HOME", "/data/hf_cache")

MODEL_ID_TEXT2IMG = "stabilityai/stable-diffusion-xl-base-1.0"
MODEL_ID_IMG2IMG = "stabilityai/stable-diffusion-xl-refiner-1.0"  # SDXL Refiner for img2img

# Load models outside the GPU context first. Using fp16 for faster inference.
print("Loading models (this may take a moment)...")
pipe_t2i_raw = DiffusionPipeline.from_pretrained(MODEL_ID_TEXT2IMG, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
pipe_i2i_raw = AutoPipelineForImage2Image.from_pretrained(MODEL_ID_IMG2IMG, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
print("Models loaded.")

def prepare_unet_dummy_inputs(pipe, resolution=(1024, 1024), batch_size=2, dtype=torch.float16, device="cuda"):
    """
    Prepares dummy inputs for SDXL's UNet for AoT compilation.
    """
    height, width = resolution
    latent_height = height // 8
    latent_width = width // 8
    dummy_latents = torch.randn(batch_size, pipe.unet.config.in_channels, latent_height, latent_width, device=device, dtype=dtype)
    dummy_timestep = torch.tensor(1.0, device=device, dtype=dtype)
    dummy_encoder_hidden_states = torch.randn(batch_size, 77, pipe.unet.config.cross_attention_dim, device=device, dtype=dtype)
    # added_cond_kwargs contains text_embeds and time_ids for the SDXL UNet.
    # The pooled text embeddings come from the second text encoder (1280-dim for SDXL),
    # and the number of additional time ids differs between the base UNet (6) and the
    # refiner UNet (5), so both are derived from the model configs instead of hard-coded.
    pooled_dim = pipe.text_encoder_2.config.projection_dim
    num_time_ids = (pipe.unet.config.projection_class_embeddings_input_dim - pooled_dim) // pipe.unet.config.addition_time_embed_dim
    dummy_text_embeds = torch.randn(batch_size, pooled_dim, device=device, dtype=dtype)
    dummy_time_ids = torch.randn(batch_size, num_time_ids, device=device, dtype=dtype)
    unet_inputs = {
        "sample": dummy_latents,
        "timestep": dummy_timestep,
        "encoder_hidden_states": dummy_encoder_hidden_states,
        "added_cond_kwargs": {
            "text_embeds": dummy_text_embeds,
            "time_ids": dummy_time_ids,
        }
    }
    return unet_inputs

# Use the maximum allowed duration for compilation at startup.
@spaces.GPU(duration=1500)
def compile_optimized_models():
    """
    Compiles the UNet components of both the text-to-image and image-to-image pipelines
    using Ahead-of-Time (AoT) compilation for performance optimization.
    """
    print("Moving models to CUDA...")
    pipe_t2i_raw.to("cuda")
    pipe_i2i_raw.to("cuda")
    print("Models moved to CUDA.")

    # Compile the UNet of the text2img pipeline.
    print("Compiling Text2Image UNet...")
    dummy_inputs_t2i = prepare_unet_dummy_inputs(pipe_t2i_raw)
    # aoti_capture intercepts the next call to the wrapped module, so the UNet is
    # invoked with the dummy inputs inside the context to record args/kwargs for export.
    with spaces.aoti_capture(pipe_t2i_raw.unet) as call_t2i:
        pipe_t2i_raw.unet(**dummy_inputs_t2i)
    exported_t2i_unet = torch.export.export(pipe_t2i_raw.unet, args=call_t2i.args, kwargs=call_t2i.kwargs)
    compiled_t2i_unet = spaces.aoti_compile(exported_t2i_unet)
    spaces.aoti_apply(compiled_t2i_unet, pipe_t2i_raw.unet)
    print("Text2Image UNet compiled.")

    # Compile the UNet of the img2img (refiner) pipeline.
    print("Compiling Image2Image UNet...")
    dummy_inputs_i2i = prepare_unet_dummy_inputs(pipe_i2i_raw)
    with spaces.aoti_capture(pipe_i2i_raw.unet) as call_i2i:
        pipe_i2i_raw.unet(**dummy_inputs_i2i)
    exported_i2i_unet = torch.export.export(pipe_i2i_raw.unet, args=call_i2i.args, kwargs=call_i2i.kwargs)
    compiled_i2i_unet = spaces.aoti_compile(exported_i2i_unet)
    spaces.aoti_apply(compiled_i2i_unet, pipe_i2i_raw.unet)
    print("Image2Image UNet compiled.")

    # Return the now-compiled pipelines
    return pipe_t2i_raw, pipe_i2i_raw


# Execute compilation during startup
pipe_text2img, pipe_img2img = compile_optimized_models()

# Allocate a GPU for inference, for at most 120 seconds per call.
@spaces.GPU(duration=120)
def remix_images_inference(image1: Image.Image | None, image2: Image.Image | None, image3: Image.Image | None, prompt: str) -> Image.Image:
    """
    Remixes images based on a text prompt using a diffusion model.
    If image1 is provided, it uses an image-to-image pipeline for remixing.
    Otherwise, it falls back to a text-to-image pipeline.

    Args:
        image1 (Image.Image | None): The first input image. If provided, used as the base for img2img.
        image2 (Image.Image | None): The second input image (currently only influences the prompt slightly).
        image3 (Image.Image | None): The third input image (currently only influences the prompt slightly).
        prompt (str): The text prompt to guide the remixing.

    Returns:
        Image.Image: The remixed image.
    """
    if not prompt:
        raise gr.Error("Remix prompt cannot be empty!")

    output_resolution = (1024, 1024)  # Fixed resolution matching the compiled models

    # Build a more descriptive prompt if additional images are provided,
    # to somewhat acknowledge their presence in the "remix".
    extra_prompt_info = ""
    if image2 is not None and image3 is not None:
        extra_prompt_info = ", incorporating elements from other images"
    elif image2 is not None or image3 is not None:
        extra_prompt_info = ", with subtle influences from another image"
    full_prompt = f"{prompt}{extra_prompt_info}"
    print(f"Full prompt for generation: {full_prompt}")

    if image1 is not None:
        # Convert to RGB (the UI allows RGBA uploads) and resize to the target resolution;
        # for img2img the output size follows the resized input image.
        input_image_resized = image1.convert("RGB").resize(output_resolution, Image.LANCZOS)
        print("Performing image-to-image remixing...")
        generated_image = pipe_img2img(
            prompt=full_prompt,
            image=input_image_resized,
            strength=0.75,  # High strength allows significant transformation of the original image
            guidance_scale=8.0,
            num_inference_steps=50,
        ).images[0]
    else:
        print("Performing text-to-image generation...")
        generated_image = pipe_text2img(
            prompt=full_prompt,
            width=output_resolution[0],
            height=output_resolution[1],
            guidance_scale=8.0,
            num_inference_steps=50
        ).images[0]

    return generated_image