Spaces:
Running
on
Zero
Running
on
Zero
added sdxl from sd community model
Browse files
app.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import torch
|
| 2 |
-
from diffusers import StableDiffusion3Pipeline, StableDiffusionPipeline, DiffusionPipeline
|
| 3 |
import gradio as gr
|
| 4 |
import os
|
| 5 |
import random
|
|
@@ -34,6 +34,12 @@ sdxl_pipe = DiffusionPipeline.from_pretrained(
|
|
| 34 |
)
|
| 35 |
sdxl_pipe.to(device)
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
# Define the image generation function for the Arena tab
|
| 38 |
@spaces.GPU(duration=80)
|
| 39 |
def generate_arena_images(
|
|
@@ -103,6 +109,8 @@ def generate_single_image(
|
|
| 103 |
pipe = sd2_1_pipe
|
| 104 |
elif model_choice == "sdxl":
|
| 105 |
pipe = sdxl_pipe
|
|
|
|
|
|
|
| 106 |
else:
|
| 107 |
raise ValueError(f"Invalid model choice: {model_choice}")
|
| 108 |
|
|
@@ -192,13 +200,13 @@ with gr.Blocks(css=css) as demo:
|
|
| 192 |
)
|
| 193 |
model_choice_1 = gr.Dropdown(
|
| 194 |
label="Stable Diffusion Model 1",
|
| 195 |
-
choices=["sd3 medium", "sd2.1", "sdxl"],
|
| 196 |
value="sd3 medium",
|
| 197 |
)
|
| 198 |
model_choice_2 = gr.Dropdown(
|
| 199 |
label="Stable Diffusion Model 2",
|
| 200 |
-
choices=["sd3 medium", "sd2.1", "sdxl"],
|
| 201 |
-
value="
|
| 202 |
)
|
| 203 |
run_button = gr.Button("Run")
|
| 204 |
result_1 = gr.Gallery(label="Generated Images (Model 1)", elem_id="gallery_1")
|
|
@@ -301,7 +309,7 @@ with gr.Blocks(css=css) as demo:
|
|
| 301 |
)
|
| 302 |
model_choice = gr.Dropdown(
|
| 303 |
label="Stable Diffusion Model",
|
| 304 |
-
choices=["sd3 medium", "sd2.1", "sdxl"],
|
| 305 |
value="sd3 medium",
|
| 306 |
)
|
| 307 |
run_button = gr.Button("Run")
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from diffusers import StableDiffusion3Pipeline, StableDiffusionPipeline, DiffusionPipeline, DPMSolverSinglestepScheduler
|
| 3 |
import gradio as gr
|
| 4 |
import os
|
| 5 |
import random
|
|
|
|
| 34 |
)
|
| 35 |
sdxl_pipe.to(device)
|
| 36 |
|
| 37 |
+
sdxl_flash_pipe = StableDiffusionXLPipeline.from_pretrained("sd-community/sdxl-flash", torch_dtype=torch.float16)
|
| 38 |
+
sdxl_flash_pipe.to(device)
|
| 39 |
+
|
| 40 |
+
# Ensure sampler uses "trailing" timesteps.
|
| 41 |
+
sdxl_flash_pipe.scheduler = DPMSolverSinglestepScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
| 42 |
+
|
| 43 |
# Define the image generation function for the Arena tab
|
| 44 |
@spaces.GPU(duration=80)
|
| 45 |
def generate_arena_images(
|
|
|
|
| 109 |
pipe = sd2_1_pipe
|
| 110 |
elif model_choice == "sdxl":
|
| 111 |
pipe = sdxl_pipe
|
| 112 |
+
elif model_choice == "sdxl flash":
|
| 113 |
+
pipe = sdxl_flash_pipe
|
| 114 |
else:
|
| 115 |
raise ValueError(f"Invalid model choice: {model_choice}")
|
| 116 |
|
|
|
|
| 200 |
)
|
| 201 |
model_choice_1 = gr.Dropdown(
|
| 202 |
label="Stable Diffusion Model 1",
|
| 203 |
+
choices=["sd3 medium", "sd2.1", "sdxl", "sdxl flash"],
|
| 204 |
value="sd3 medium",
|
| 205 |
)
|
| 206 |
model_choice_2 = gr.Dropdown(
|
| 207 |
label="Stable Diffusion Model 2",
|
| 208 |
+
choices=["sd3 medium", "sd2.1", "sdxl", "sdxl flash"],
|
| 209 |
+
value="sdxl",
|
| 210 |
)
|
| 211 |
run_button = gr.Button("Run")
|
| 212 |
result_1 = gr.Gallery(label="Generated Images (Model 1)", elem_id="gallery_1")
|
|
|
|
| 309 |
)
|
| 310 |
model_choice = gr.Dropdown(
|
| 311 |
label="Stable Diffusion Model",
|
| 312 |
+
choices=["sd3 medium", "sd2.1", "sdxl", "sdxl flash"],
|
| 313 |
value="sd3 medium",
|
| 314 |
)
|
| 315 |
run_button = gr.Button("Run")
|