Spaces: Running on Zero
Commit 76bf5ed · 1 Parent: e6c446b
added load model pipeline
Files changed: app.py (+52 -28), pipeline/util.py (+1 -31)
app.py CHANGED
@@ -8,34 +8,51 @@ from pipeline.util import (
     SAMPLERS,
     create_hdr_effect,
     progressive_upscale,
-    select_scheduler
+    select_scheduler,
+    torch_gc,
 )
 
 device = "cuda"
-pipe = None
-last_loaded_model = None
 MODELS = {"RealVisXL 5 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
           "RealVisXL 5": "SG161222/RealVisXL_V5.0"
 }
-
-def load_model(model_id):
-    global pipe, last_loaded_model
-
-    if model_id != last_loaded_model:
-        controlnet = ControlNetUnionModel.from_pretrained(
-            "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
-        ).to(device)
-        vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
-        pipe = StableDiffusionXLControlNetTileSRPipeline.from_pretrained(
-            MODELS[model_id], controlnet=controlnet, vae=vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
-        ).to(device)
-        #pipe.enable_model_cpu_offload() # << Enable this if you have limited VRAM
-        pipe.enable_vae_tiling() # << Enable this if you have limited VRAM
-        pipe.enable_vae_slicing() # << Enable this if you have limited VRAM
-        last_loaded_model = model_id
-
+class Pipeline:
+    def __init__(self):
+        self.pipe = None
+        self.controlnet = None
+        self.vae = None
+        self.last_loaded_model = None
+
+    def load_model(self, model_id):
+        if model_id != self.last_loaded_model:
+            print(f"\n--- Loading model: {model_id} ---")
+            if self.pipe is not None:
+                self.pipe.to("cpu")
+                del self.pipe
+                self.pipe = None
+                del self.controlnet
+                self.controlnet = None
+                del self.vae
+                self.vae = None
+                torch_gc()
+
+            self.controlnet = ControlNetUnionModel.from_pretrained(
+                "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
+            ).to(device=device)
+            self.vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device=device)
+
+            self.pipe = StableDiffusionXLControlNetTileSRPipeline.from_pretrained(
+                MODELS[model_id], controlnet=self.controlnet, vae=self.vae, torch_dtype=torch.float16, variant="fp16"
+            ).to(device=device)
+
+            self.pipe.enable_model_cpu_offload()
+            self.pipe.enable_vae_tiling()
+            self.pipe.enable_vae_slicing()
+            self.last_loaded_model = model_id
+            print(f"Model {model_id} loaded.")
+
+    def __call__(self, *args, **kwargs):
+        return self.pipe(*args, **kwargs)
 
 # region functions
 @spaces.GPU(duration=120)
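Note: `Pipeline.load_model` moves the old pipeline to CPU and drops every owning reference (`self.pipe`, `self.controlnet`, `self.vae`) before calling `torch_gc()`; CUDA memory can only be returned to the driver once no Python object still points at the weights. The `__call__` passthrough is what lets `predict` keep calling `pipeline(...)` exactly as it previously called `pipe(...)`. Below is a minimal sketch of the same reclamation pattern, independent of diffusers (class and method names are illustrative, not part of this commit):

```python
import gc
import torch

class ModelSlot:
    """Holds at most one CUDA model; replacing it reclaims the old one's VRAM."""

    def __init__(self):
        self.model = None

    def replace(self, new_model):
        if self.model is not None:
            self.model.to("cpu")          # move weights off the GPU first
            self.model = None             # drop the owning reference
            gc.collect()                  # collect anything kept alive by cycles
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # hand cached blocks back to the driver
                torch.cuda.ipc_collect()
        self.model = new_model
```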
@@ -56,14 +73,12 @@ def predict(
     tile_weighting_method,
     progress=gr.Progress(track_tqdm=True),
 ):
-    global pipe
-
     # Load model if changed
     load_model(model_id)
-
+
     # Set selected scheduler
     print(f"Using scheduler: {scheduler}...")
-    pipe.scheduler = select_scheduler(pipe, scheduler)
+    pipeline.pipe.scheduler = select_scheduler(pipeline.pipe, scheduler)
 
     # Get current image size
     original_height = image.height
@@ -86,7 +101,7 @@ def predict(
 
     # Image generation
     print("Diffusion kicking in... almost done, coffee's on you!")
-    image = pipe(
+    image = pipeline(
         image=image,
         control_image=control_image,
         control_mode=[6],
@@ -112,6 +127,14 @@ def predict(
 def clear_result():
     return gr.update(value=None)
 
+def load_model(model_name, on_load=False):
+    global pipeline  # Declare pipeline as global
+    if on_load and 'pipeline' not in globals():  # Prevent reload page
+        pipeline = Pipeline()  # Create pipeline inside the function
+        pipeline.load_model(model_name)  # Load the initial model
+    elif pipeline is not None and not on_load:
+        pipeline.load_model(model_name)  # Switch model
+
 def set_maximum_resolution(max_tile_size, current_value):
     max_scale = 8  # <- you can try increase it to 12x, 16x if you wish!
     maximum_value = max_tile_size * max_scale
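Note: the new module-level `load_model(model_name, on_load=False)` acts as a lazy singleton around `Pipeline`: the `'pipeline' not in globals()` guard builds the wrapper exactly once per process, while later dropdown changes (`on_load=False`) only swap the underlying model. The same guard in isolation, as a hypothetical standalone version rather than this app's code:

```python
_instance = None  # module-level slot

def get_instance(factory):
    """Build the expensive object on first use, then reuse it for every caller."""
    global _instance
    if _instance is None:
        _instance = factory()
    return _instance

# e.g. get_instance(dict) returns the same dict object on every call
```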
@@ -213,7 +236,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
         with gr.Column(scale=3):
             with gr.Row():
                 with gr.Column():
-                    input_image = gr.Image(type="pil", label="Input Image",sources=["upload"], height=500)
+                    input_image = gr.Image(type="pil", label="Input Image", sources=["upload"], height=500)
             with gr.Column():
                 result = gr.Image(
                     label="Generated Image", show_label=True, format="png", interactive=False, scale=1, height=500, min_width=670
@@ -245,7 +268,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
         with gr.Row(elem_id="parameters_row"):
             gr.Markdown("### General parameters")
             model = gr.Dropdown(
-                label="Model", choices=MODELS.keys(), value=list(MODELS.keys())[0]
+                label="Model", choices=list(MODELS.keys()), value=list(MODELS.keys())[0]
             )
             tile_weighting_method = gr.Dropdown(
                 label="Tile Weighting Method", choices=["Cosine", "Gaussian"], value="Cosine"
@@ -446,6 +469,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
 
     max_tile_size.select(fn=set_maximum_resolution, inputs=[max_tile_size, resolution], outputs=resolution)
     tile_weighting_method.change(fn=select_tile_weighting_method, inputs=tile_weighting_method, outputs=tile_gaussian_sigma)
+
    generate_button.click(
         fn=clear_result,
         inputs=None,
@@ -468,8 +492,8 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
             max_tile_size,
             tile_weighting_method,
         ],
-        outputs=result,
-        show_progress="full"
+        outputs=result,
     )
     gr.Markdown(about)
+    app.load(fn=load_model, inputs=[model, gr.State(value=True)], outputs=None, concurrency_limit=1)  # Load initial model on app load
 app.launch(share=False)
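Note: `app.load(...)` registers `load_model` to run whenever a client connects, with `gr.State(value=True)` supplying `on_load=True` and `concurrency_limit=1` serializing concurrent loads. Because the hook fires on every page (re)load, the function must be idempotent, which is what the `globals()` guard above provides. A hedged minimal demo of the same hook (illustrative names, not this app):

```python
import gradio as gr

with gr.Blocks() as demo:
    status = gr.Textbox(label="Status")
    # Blocks.load runs its fn on every client connection, so keep it idempotent.
    demo.load(fn=lambda: "pipeline ready", inputs=None, outputs=status)

if __name__ == "__main__":
    demo.launch()
```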
pipeline/util.py CHANGED
@@ -16,8 +16,6 @@
 import gc
 import cv2
 import numpy as np
-from torch import nn
-from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
 import torch
 from PIL import Image
 
@@ -98,32 +96,6 @@ def select_scheduler(pipe, selected_sampler):
 
     return scheduler.from_config(config, **add_kwargs)
 
-def optionally_disable_offloading(_pipeline):
-    """
-    Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
-
-    Args:
-        _pipeline (`DiffusionPipeline`):
-            The pipeline to disable offloading for.
-
-    Returns:
-        tuple:
-            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
-    """
-    is_model_cpu_offload = False
-    is_sequential_cpu_offload = False
-    if _pipeline is not None:
-        for _, component in _pipeline.components.items():
-            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
-                if not is_model_cpu_offload:
-                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
-                if not is_sequential_cpu_offload:
-                    is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
-
-                remove_hook_from_module(component, recurse=True)
-
-    return (is_model_cpu_offload, is_sequential_cpu_offload)
 
 # This function was copied and adapted from https://huggingface.co/spaces/gokaygokay/TileUpscalerV2, licensed under Apache 2.0.
 def progressive_upscale(input_image, target_resolution, steps=3):
@@ -210,14 +182,12 @@ def create_hdr_effect(original_image, hdr):
 
 
 def torch_gc():
+    gc.collect()
     if torch.cuda.is_available():
         with torch.cuda.device("cuda"):
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
 
-    gc.collect()
-
-
 def quantize_8bit(unet):
     if unet is None:
         return
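Note: `gc.collect()` now runs before the CUDA calls rather than after. Collecting first destroys Python objects that are unreachable but still holding tensor references, so the subsequent `torch.cuda.empty_cache()` can actually release their blocks. A hedged usage sketch (assumes a CUDA device is present; the tensor size is illustrative):

```python
import torch
from pipeline.util import torch_gc

x = torch.empty(512, 1024, 1024, device="cuda")  # reserves ~2 GiB
del x                                            # blocks stay cached by the allocator
print(torch.cuda.memory_reserved())              # still ~2 GiB reserved
torch_gc()
print(torch.cuda.memory_reserved())              # cache returned to the driver
```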