import gradio as gr

from models import remix_images_inference  # Import the inference function from models.py


def main():
    with gr.Blocks(title="Image Remix with SDXL") as demo:
        gr.HTML(
            """
            <div style="text-align: center;">
                <h1>Image Remix with SDXL</h1>
                <p>
                    Drag and drop up to three images and provide a text prompt to remix them
                    using Stable Diffusion XL. If the first image is provided, it will be used
                    as a base for image-to-image generation.
                </p>
                <p>Built with anycoder</p>
            </div>
""" ) with gr.Row(): with gr.Column(scale=1): with gr.Group(): image_input_1 = gr.Image( label="Image Input 1 (Base for Img2Img)", type="pil", height=256, width=256, image_mode="RGBA", drop_threshold_height=200 ) image_input_2 = gr.Image( label="Image Input 2 (Optional)", type="pil", height=256, width=256, image_mode="RGBA", drop_threshold_height=200 ) image_input_3 = gr.Image( label="Image Input 3 (Optional)", type="pil", height=256, width=256, image_mode="RGBA", drop_threshold_height=200 ) prompt_input = gr.Textbox( label="Remix Prompt", placeholder="A vibrant abstract painting blending elements of nature and technology", lines=2 ) remix_button = gr.Button("Remix Images", variant="primary") with gr.Column(scale=2): output_image = gr.Image( label="Remixed Image", type="pil", interactive=False, height=512, width=512 ) # Define the interaction remix_button.click( fn=remix_images_inference, inputs=[image_input_1, image_input_2, image_input_3, prompt_input], outputs=output_image, queue=True, show_progress="full" ) gr.Examples( examples=[ [ "https://www.kasandbox.org/programming-images/avatars/spunky-sam-headphones.png", None, None, "a robot wearing headphones, futuristic, cyberpunk art style" ], [ "https://gradio-docs-json.s3.us-west-2.amazonaws.com/base.png", "https://gradio-docs-json.s3.us-west-2.amazonaws.com/buildings.png", None, "a serene landscape with ancient ruins, overgrown with lush vegetation, concept art, fantasy" ], [ None, None, None, "an astronaut riding a horse on the moon, cinematic, photorealistic" ] ], inputs=[image_input_1, image_input_2, image_input_3, prompt_input], outputs=output_image, fn=remix_images_inference, cache_examples=False, # Cache examples can be set to True if the inference is fast enough run_on_click=True ) demo.launch(enable_monitoring=True) if __name__ == "__main__": main() ``` ### `models.py` ```python import spaces import torch from diffusers import DiffusionPipeline, AutoPipelineForImage2Image from PIL import Image import numpy as np import os import gradio as gr # Set a cache directory for models if not already set, for smoother experience on Spaces os.environ["HF_HOME"] = os.getenv("HF_HOME", "/data/hf_cache") MODEL_ID_TEXT2IMG = "stabilityai/stable-diffusion-xl-base-1.0" MODEL_ID_IMG2IMG = "stabilityai/stable-diffusion-xl-refiner-1.0" # SDXL Refiner for img2img # Load models outside the GPU context first. Using fp16 for faster inference. print("Loading models (this may take a moment)...") pipe_t2i_raw = DiffusionPipeline.from_pretrained(MODEL_ID_TEXT2IMG, torch_dtype=torch.float16, use_safetensors=True, variant="fp16") pipe_i2i_raw = AutoPipelineForImage2Image.from_pretrained(MODEL_ID_IMG2IMG, torch_dtype=torch.float16, use_safetensors=True, variant="fp16") print("Models loaded.") def prepare_unet_dummy_inputs(pipe, resolution=(1024, 1024), batch_size=2, dtype=torch.float16, device="cuda"): """ Prepares dummy inputs for SDXL's UNet for AoT compilation. 
""" height, width = resolution latent_height = height // 8 latent_width = width // 8 dummy_latents = torch.randn(batch_size, pipe.unet.config.in_channels, latent_height, latent_width, device=device, dtype=dtype) dummy_timestep = torch.tensor(1.0, device=device, dtype=dtype) dummy_encoder_hidden_states = torch.randn(batch_size, 77, pipe.unet.config.cross_attention_dim, device=device, dtype=dtype) # added_cond_kwargs contains text_embeds and time_ids for SDXL UNet dummy_text_embeds = torch.randn(batch_size, pipe.unet.config.addition_embed_type_num_vector_context_tokens, device=device, dtype=dtype) dummy_time_ids = torch.randn(batch_size, 6, device=device, dtype=dtype) unet_inputs = { "sample": dummy_latents, "timestep": dummy_timestep, "encoder_hidden_states": dummy_encoder_hidden_states, "added_cond_kwargs": { "text_embeds": dummy_text_embeds, "time_ids": dummy_time_ids, } } return unet_inputs @spaces.GPU(duration=1500) # Use max duration for compilation at startup def compile_optimized_models(): """ Compiles the UNet components of both text-to-image and image-to-image pipelines using Ahead-of-Time (AoT) compilation for performance optimization. """ print("Moving models to CUDA...") pipe_t2i_raw.to("cuda") pipe_i2i_raw.to("cuda") print("Models moved to CUDA.") # Compile UNet for text2img pipeline print("Compiling Text2Image UNet...") dummy_inputs_t2i = prepare_unet_dummy_inputs(pipe_t2i_raw) with spaces.aoti_capture(pipe_t2i_raw.unet, **dummy_inputs_t2i) as call_t2i: pass # Inputs are passed directly to aoti_capture for explicit tracing exported_t2i_unet = torch.export.export(pipe_t2i_raw.unet, args=call_t2i.args, kwargs=call_t2i.kwargs) compiled_t2i_unet = spaces.aoti_compile(exported_t2i_unet) spaces.aoti_apply(compiled_t2i_unet, pipe_t2i_raw.unet) print("Text2Image UNet compiled.") # Compile UNet for img2img (refiner) pipeline print("Compiling Image2Image UNet...") dummy_inputs_i2i = prepare_unet_dummy_inputs(pipe_i2i_raw) with spaces.aoti_capture(pipe_i2i_raw.unet, **dummy_inputs_i2i) as call_i2i: pass # Inputs are passed directly to aoti_capture for explicit tracing exported_i2i_unet = torch.export.export(pipe_i2i_raw.unet, args=call_i2i.args, kwargs=call_i2i.kwargs) compiled_i2i_unet = spaces.aoti_compile(exported_i2i_unet) spaces.aoti_apply(compiled_i2i_unet, pipe_i2i_raw.unet) print("Image2Image UNet compiled.") # Return the now-compiled pipelines return pipe_t2i_raw, pipe_i2i_raw # Execute compilation during startup pipe_text2img, pipe_img2img = compile_optimized_models() @spaces.GPU(duration=120) # Allocate GPU for inference, max 120 seconds def remix_images_inference(image1: Image.Image | None, image2: Image.Image | None, image3: Image.Image | None, prompt: str) -> Image.Image: """ Remixes images based on a text prompt using a diffusion model. If image1 is provided, it uses an image-to-image pipeline for remixing. Otherwise, it falls back to a text-to-image pipeline. Args: image1 (Image.Image | None): The first input image. If provided, used as base for img2img. image2 (Image.Image | None): The second input image (currently influences prompt slightly). image3 (Image.Image | None): The third input image (currently influences prompt slightly). prompt (str): The text prompt to guide the remixing. Returns: Image.Image: The remixed image. 
""" if not prompt: raise gr.Error("Remix prompt cannot be empty!") output_resolution = (1024, 1024) # Fixed resolution for compiled models # Build a more descriptive prompt if additional images are provided, # to somewhat acknowledge their presence in the "remix". extra_prompt_info = "" if image2 is not None and image3 is not None: extra_prompt_info = ", incorporating elements from other images" elif image2 is not None or image3 is not None: extra_prompt_info = ", with subtle influences from another image" full_prompt = f"{prompt}{extra_prompt_info}" print(f"Full prompt for generation: {full_prompt}") if image1 is not None: # Resize the input image to the target resolution for img2img input_image_resized = image1.resize(output_resolution, Image.LANCZOS) print("Performing image-to-image remixing...") generated_image = pipe_img2img( prompt=full_prompt, image=input_image_resized, strength=0.75, # High strength allows significant transformation from original image guidance_scale=8.0, num_inference_steps=50, width=output_resolution[0], height=output_resolution[1], ).images[0] else: print("Performing text-to-image generation...") generated_image = pipe_text2img( prompt=full_prompt, height=output_resolution[0], width=output_resolution[1], guidance_scale=8.0, num_inference_steps=50 ).images[0] return generated_image