tool-uilny5se / app.py
Gertie01's picture
Update app.py (#2)
b7be347 verified
import gradio as gr
from models import remix_images_inference # Import the inference function from models.py
def _make_image_input(label):
    """Create one 256x256 PIL image upload slot with the settings shared by all three inputs."""
    # NOTE: `drop_threshold_height` was removed: it is not a gr.Image parameter and
    # made component construction fail with an unexpected-keyword TypeError.
    return gr.Image(
        label=label,
        type="pil",
        height=256,
        width=256,
        image_mode="RGBA",
    )


def main():
    """Build the "Image Remix with SDXL" Gradio app and launch it.

    The UI collects up to three images plus a text prompt and forwards them to
    ``remix_images_inference`` (imported from ``models``), both from the
    "Remix Images" button and from the clickable examples.
    """
    with gr.Blocks(title="Image Remix with SDXL") as demo:
        gr.HTML(
            """
            <div style="text-align: center; max-width: 700px; margin: 0 auto;">
            <h1 style="font-weight: 900; font-size: 2.5em; margin-bottom: 0.5em;">
            Image Remix with SDXL
            </h1>
            <p style="margin-bottom: 1em; font-size: 1.1em; color: #555;">
            Drag and drop up to three images and provide a text prompt to remix them using Stable Diffusion XL.
            If the first image is provided, it will be used as a base for image-to-image generation.
            </p>
            <p style="font-size: 0.9em; color: #777;">
            Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank" style="color: #4CAF50; text-decoration: none;">anycoder</a>
            </p>
            </div>
            """
        )
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group():
                    image_input_1 = _make_image_input("Image Input 1 (Base for Img2Img)")
                    image_input_2 = _make_image_input("Image Input 2 (Optional)")
                    image_input_3 = _make_image_input("Image Input 3 (Optional)")
                prompt_input = gr.Textbox(
                    label="Remix Prompt",
                    placeholder="A vibrant abstract painting blending elements of nature and technology",
                    lines=2,
                )
                remix_button = gr.Button("Remix Images", variant="primary")
            with gr.Column(scale=2):
                output_image = gr.Image(
                    label="Remixed Image",
                    type="pil",
                    interactive=False,
                    height=512,
                    width=512,
                )

        # Same input order as remix_images_inference(image1, image2, image3, prompt).
        remix_inputs = [image_input_1, image_input_2, image_input_3, prompt_input]

        # Define the interaction
        remix_button.click(
            fn=remix_images_inference,
            inputs=remix_inputs,
            outputs=output_image,
            queue=True,
            show_progress="full",
        )
        gr.Examples(
            examples=[
                [
                    "https://www.kasandbox.org/programming-images/avatars/spunky-sam-headphones.png",
                    None,
                    None,
                    "a robot wearing headphones, futuristic, cyberpunk art style",
                ],
                [
                    "https://gradio-docs-json.s3.us-west-2.amazonaws.com/base.png",
                    "https://gradio-docs-json.s3.us-west-2.amazonaws.com/buildings.png",
                    None,
                    "a serene landscape with ancient ruins, overgrown with lush vegetation, concept art, fantasy",
                ],
                [
                    None,
                    None,
                    None,
                    "an astronaut riding a horse on the moon, cinematic, photorealistic",
                ],
            ],
            inputs=remix_inputs,
            outputs=output_image,
            fn=remix_images_inference,
            cache_examples=False,  # Cache examples can be set to True if the inference is fast enough
            run_on_click=True,
        )

    # Launch outside the `with` block, once the layout and event wiring are complete.
    demo.launch(enable_monitoring=True)


if __name__ == "__main__":
    main()
```
### `models.py`
```python
import spaces
import torch
from diffusers import DiffusionPipeline, AutoPipelineForImage2Image
from PIL import Image
import numpy as np
import os
import gradio as gr
# Set a cache directory for models if not already set, for smoother experience on Spaces
os.environ["HF_HOME"] = os.getenv("HF_HOME", "/data/hf_cache")


def _model_cache_dir() -> str | None:
    """Return the hub cache directory implied by ``HF_HOME``, or ``None`` for the default.

    huggingface_hub resolves its cache location from ``HF_HOME`` when it is first
    imported, which already happened through the ``diffusers`` import above, i.e.
    before ``HF_HOME`` is assigned on the preceding line. The directory therefore
    has to be passed to ``from_pretrained`` explicitly for that setting to matter.

    Returns ``None`` (library default) when ``HF_HUB_CACHE`` is configured
    explicitly, when ``HF_HOME`` is empty, or when the target location is not
    writable. For example, ``/data`` typically exists only when persistent storage
    is enabled. Returning ``None`` in these cases keeps startup from failing on an
    unwritable path.
    """
    if os.getenv("HF_HUB_CACHE"):
        return None  # an explicit hub cache is already honored by huggingface_hub
    hf_home = os.environ.get("HF_HOME")
    if not hf_home:
        return None
    # Walk up to the nearest existing ancestor and check we could create the dir there.
    probe = os.path.abspath(hf_home)
    while not os.path.exists(probe):
        probe = os.path.dirname(probe)
    return os.path.join(hf_home, "hub") if os.access(probe, os.W_OK) else None


MODEL_CACHE_DIR = _model_cache_dir()

MODEL_ID_TEXT2IMG = "stabilityai/stable-diffusion-xl-base-1.0"
MODEL_ID_IMG2IMG = "stabilityai/stable-diffusion-xl-refiner-1.0"  # SDXL Refiner for img2img

# Load models outside the GPU context first. Using fp16 for faster inference.
print("Loading models (this may take a moment)...")
pipe_t2i_raw = DiffusionPipeline.from_pretrained(
    MODEL_ID_TEXT2IMG,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
    cache_dir=MODEL_CACHE_DIR,
)
pipe_i2i_raw = AutoPipelineForImage2Image.from_pretrained(
    MODEL_ID_IMG2IMG,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
    cache_dir=MODEL_CACHE_DIR,
)
print("Models loaded.")
def prepare_unet_dummy_inputs(pipe, resolution=(1024, 1024), batch_size=2, dtype=torch.float16, device="cuda", text_seq_len=77):
    """Build random keyword inputs for one SDXL UNet forward pass (e.g. for AoT tracing).

    All widths are read from the pipeline's own configs, so the same helper works
    for both the SDXL base UNet and the SDXL refiner UNet:

    * ``text_embeds`` (pooled text embedding) width is
      ``pipe.text_encoder_2.config.projection_dim``.
    * the number of ``time_ids`` is what remains of the UNet's
      ``projection_class_embeddings_input_dim`` after the pooled text embedding,
      divided by ``addition_time_embed_dim``. This is 6 for SDXL base and 5 for
      the refiner, which conditions on an aesthetic score.

    Args:
        pipe: An SDXL pipeline exposing ``unet`` and ``text_encoder_2``.
        resolution: ``(height, width)`` of the image in pixels; latents are 1/8 of it.
        batch_size: UNet batch size (2 when classifier-free guidance doubles the batch).
        dtype: Tensor dtype for all floating inputs.
        device: Device to allocate the tensors on.
        text_seq_len: Token length of ``encoder_hidden_states``.

    Returns:
        dict: kwargs ``sample``, ``timestep``, ``encoder_hidden_states`` and
        ``added_cond_kwargs`` (``text_embeds``, ``time_ids``).

    Raises:
        ValueError: If the UNet's added-conditioning width is inconsistent with the
            text encoder's projection size.
    """
    unet_cfg = pipe.unet.config
    height, width = resolution
    latent_height = height // 8
    latent_width = width // 8

    text_embeds_dim = pipe.text_encoder_2.config.projection_dim
    extra_dim = unet_cfg.projection_class_embeddings_input_dim - text_embeds_dim
    num_time_ids, remainder = divmod(extra_dim, unet_cfg.addition_time_embed_dim)
    if remainder or num_time_ids <= 0:
        raise ValueError(
            "UNet projection_class_embeddings_input_dim "
            f"({unet_cfg.projection_class_embeddings_input_dim}) is not text_embeds_dim "
            f"({text_embeds_dim}) plus a positive multiple of addition_time_embed_dim "
            f"({unet_cfg.addition_time_embed_dim})."
        )

    return {
        "sample": torch.randn(batch_size, unet_cfg.in_channels, latent_height, latent_width, device=device, dtype=dtype),
        "timestep": torch.tensor(1.0, device=device, dtype=dtype),
        "encoder_hidden_states": torch.randn(batch_size, text_seq_len, unet_cfg.cross_attention_dim, device=device, dtype=dtype),
        # added_cond_kwargs carries SDXL's pooled text embeddings and micro-conditioning ids.
        "added_cond_kwargs": {
            "text_embeds": torch.randn(batch_size, text_embeds_dim, device=device, dtype=dtype),
            "time_ids": torch.randn(batch_size, num_time_ids, device=device, dtype=dtype),
        },
    }
# Move the pipelines to the GPU at import time. On ZeroGPU, state changed inside a
# @spaces.GPU call (device moves, aoti_apply) is not carried back to the main
# process, so these steps must happen at module level.
print("Moving models to CUDA...")
pipe_t2i_raw.to("cuda")
pipe_i2i_raw.to("cuda")
print("Models moved to CUDA.")

# (width, height) the UNets are compiled for. AoT-compiled graphs only accept the
# traced input shapes, so inference must use this same fixed resolution (and
# classifier-free guidance, which doubles the UNet batch).
AOT_RESOLUTION = (1024, 1024)
_AOT_CAPTURE_PROMPT = "arbitrary example prompt"


@spaces.GPU(duration=1500)  # Use max duration for compilation at startup
def compile_optimized_models():
    """Ahead-of-Time compile the UNets of the text-to-image and image-to-image pipelines.

    Each UNet's real call signature is captured by running its pipeline once at
    ``AOT_RESOLUTION`` with guidance enabled. The exported graph therefore sees
    exactly the args/kwargs structure the pipeline passes at inference time.
    The original version captured nothing: ``aoti_capture`` received dummy
    kwargs and had an empty body.

    Returns:
        tuple: ``(compiled_t2i_unet, compiled_i2i_unet)`` as produced by
        ``spaces.aoti_compile``; they are installed with ``spaces.aoti_apply``
        in the main process.
    """
    width, height = AOT_RESOLUTION

    # Compile UNet for text2img pipeline
    print("Compiling Text2Image UNet...")
    with spaces.aoti_capture(pipe_t2i_raw.unet) as call_t2i:
        pipe_t2i_raw(
            prompt=_AOT_CAPTURE_PROMPT,
            height=height,
            width=width,
            guidance_scale=8.0,
            num_inference_steps=2,
        )
    exported_t2i_unet = torch.export.export(pipe_t2i_raw.unet, args=call_t2i.args, kwargs=call_t2i.kwargs)
    compiled_t2i_unet = spaces.aoti_compile(exported_t2i_unet)
    print("Text2Image UNet compiled.")

    # Compile UNet for img2img (refiner) pipeline; the image fixes the latent size.
    print("Compiling Image2Image UNet...")
    capture_image = Image.new("RGB", (width, height))
    with spaces.aoti_capture(pipe_i2i_raw.unet) as call_i2i:
        pipe_i2i_raw(
            prompt=_AOT_CAPTURE_PROMPT,
            image=capture_image,
            strength=0.75,
            guidance_scale=8.0,
            num_inference_steps=4,  # with strength 0.75 this still runs >= 1 denoising step
        )
    exported_i2i_unet = torch.export.export(pipe_i2i_raw.unet, args=call_i2i.args, kwargs=call_i2i.kwargs)
    compiled_i2i_unet = spaces.aoti_compile(exported_i2i_unet)
    print("Image2Image UNet compiled.")

    return compiled_t2i_unet, compiled_i2i_unet


# Execute compilation during startup, then install the compiled UNets in this process.
_compiled_t2i_unet, _compiled_i2i_unet = compile_optimized_models()
spaces.aoti_apply(_compiled_t2i_unet, pipe_t2i_raw.unet)
spaces.aoti_apply(_compiled_i2i_unet, pipe_i2i_raw.unet)
# Names used by remix_images_inference.
pipe_text2img, pipe_img2img = pipe_t2i_raw, pipe_i2i_raw
def _flatten_to_rgb(image: Image.Image) -> Image.Image:
    """Return an RGB version of *image*, compositing any transparency onto white.

    The UI delivers RGBA images. The SDXL img2img image processor does not drop
    the alpha channel by default, and the VAE encoder expects 3 input channels.
    """
    if image.mode == "RGB":
        return image
    rgba = image.convert("RGBA")
    background = Image.new("RGB", rgba.size, (255, 255, 255))
    background.paste(rgba, mask=rgba.getchannel("A"))
    return background


def _extra_images_suffix(image2: Image.Image | None, image3: Image.Image | None) -> str:
    """Return the prompt suffix acknowledging the optional images 2 and 3 (presence only)."""
    provided = sum(img is not None for img in (image2, image3))
    if provided == 2:
        return ", incorporating elements from other images"
    if provided == 1:
        return ", with subtle influences from another image"
    return ""


@spaces.GPU(duration=120)  # Allocate GPU for inference, max 120 seconds
def remix_images_inference(image1: Image.Image | None, image2: Image.Image | None, image3: Image.Image | None, prompt: str) -> Image.Image:
    """Remix images based on a text prompt using SDXL.

    If ``image1`` is provided, it is flattened to RGB, resized to the fixed output
    resolution and used as the base for the image-to-image pipeline. Otherwise
    the text-to-image pipeline is used.

    Args:
        image1: Base image for img2img, or None for pure text-to-image.
        image2: Optional second image. Only its presence changes the prompt
            wording; its pixels are not used.
        image3: Optional third image, handled like ``image2``.
        prompt: The text prompt guiding the remix; must contain non-whitespace text.

    Returns:
        The generated 1024x1024 image.

    Raises:
        gr.Error: If the prompt is empty or only whitespace.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Remix prompt cannot be empty!")

    # Fixed resolution for compiled models (must match the shapes the UNets were traced with).
    width, height = 1024, 1024

    full_prompt = f"{prompt}{_extra_images_suffix(image2, image3)}"
    print(f"Full prompt for generation: {full_prompt}")

    if image1 is not None:
        base_image = _flatten_to_rgb(image1).resize((width, height), Image.LANCZOS)
        print("Performing image-to-image remixing...")
        # The img2img pipeline takes its output size from `image`; it has no
        # height/width parameters, so none are passed.
        result = pipe_img2img(
            prompt=full_prompt,
            image=base_image,
            strength=0.75,  # High strength allows significant transformation from original image
            guidance_scale=8.0,
            num_inference_steps=50,
        )
    else:
        print("Performing text-to-image generation...")
        result = pipe_text2img(
            prompt=full_prompt,
            height=height,
            width=width,
            guidance_scale=8.0,
            num_inference_steps=50,
        )
    return result.images[0]