# ghibli-lora / app.py
# Author: osmr — commit d5501ef ("Add remove-bg")
import gradio as gr
import numpy as np
import random
from typing import Optional
from rembg import remove
# import spaces #[uncomment to use ZeroGPU]
from diffusers import StableDiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel
import torch
# Run on the GPU in half precision (float16) when CUDA is available;
# otherwise fall back to the CPU with float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Seed handling: default seed and the upper bound for randomly drawn seeds
# (largest signed 32-bit integer).
DEFAULT_SEED = 42
MAX_SEED = np.iinfo(np.int32).max

# Output image geometry, in pixels.
MAX_IMAGE_SIZE = 1024
DEFAULT_WIDTH = 512
DEFAULT_HEIGHT = 512

# Default values for the generation controls exposed in the UI.
DEFAULT_GS = 7.5             # guidance scale
DEFAULT_LS = 1.0             # LoRA scale
DEFAULT_NUM_INF_STEPS = 50   # number of inference steps
DEFAULT_CN_COND_SCALE = 1.0  # ControlNet conditioning scale
DEFAULT_IPA_SCALE = 0.5      # IP-Adapter scale
# Base checkpoint that each known Ghibli LoRA adapter is loaded on top of.
# Any model id not in this table is loaded directly as a base checkpoint
# (no LoRA).
_LORA_BASE_MODELS = {
    "osmr/stable-diffusion-v1-4-lora-iv-ghibli": "CompVis/stable-diffusion-v1-4",
    "osmr/stable-diffusion-v1-4-lora-db-ghibli": "CompVis/stable-diffusion-v1-4",
    "osmr/stable-diffusion-v1-5-lora-iv-ghibli": "runwayml/stable-diffusion-v1-5",
    "osmr/stable-diffusion-v1-5-lora-db-ghibli": "runwayml/stable-diffusion-v1-5",
}
_DEFAULT_LORA_MODEL_ID = "osmr/stable-diffusion-v1-4-lora-iv-ghibli"

# ControlNet checkpoints keyed by (SD version, ControlNet type).
_CONTROLNET_IDS = {
    ("1.4", "Edge-Detection"): "lllyasviel/sd-controlnet-canny",
    ("1.4", "Pose-Estimation"): "lllyasviel/sd-controlnet-openpose",
    ("1.5", "Edge-Detection"): "lllyasviel/control_v11p_sd15_canny",
    ("1.5", "Pose-Estimation"): "lllyasviel/control_v11p_sd15_openpose",
}


def _resolve_models(lora_model_id):
    """Map the "Model" selection to ``(base_model_id, lora_model_id_or_None)``.

    A known Ghibli LoRA id resolves to its base checkpoint plus the LoRA id.
    Any other id is used as the base checkpoint with no LoRA. An empty
    selection (``None`` or ``""``, e.g. from an unset dropdown) falls back to
    the default LoRA instead of reaching ``from_pretrained`` as an invalid id.
    """
    if not lora_model_id:
        lora_model_id = _DEFAULT_LORA_MODEL_ID
    base_model_id = _LORA_BASE_MODELS.get(lora_model_id)
    if base_model_id is None:
        return lora_model_id, None
    return base_model_id, lora_model_id


def _controlnet_id(sd_version, controlnet_type):
    """Return the ControlNet repo id for *sd_version* ("1.4"/"1.5") and type.

    Only "Pose-Estimation" selects OpenPose. Anything else, including an unset
    ``None``, uses edge detection, which matches ``infer``'s default.
    """
    if controlnet_type == "Pose-Estimation":
        kind = "Pose-Estimation"
    else:
        kind = "Edge-Detection"
    return _CONTROLNET_IDS[(sd_version, kind)]


# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(lora_model_id: Optional[str] = _DEFAULT_LORA_MODEL_ID,
          prompt: str = "",
          negative_prompt: str = "",
          seed: Optional[int] = DEFAULT_SEED,
          randomize_seed: bool = True,
          width: int = DEFAULT_WIDTH,
          height: int = DEFAULT_HEIGHT,
          guidance_scale: Optional[float] = DEFAULT_GS,
          lora_scale: Optional[float] = DEFAULT_LS,
          num_inference_steps: Optional[int] = DEFAULT_NUM_INF_STEPS,
          controlnet_type: str = "Edge-Detection",
          controlnet_cond_scale: float = DEFAULT_CN_COND_SCALE,
          controlnet_image: object = None,
          ipadapter_scale: float = DEFAULT_IPA_SCALE,
          ipadapter_image: object = None,
          do_remove_bg: bool = False,
          progress=gr.Progress(track_tqdm=True)):
    """Generate one image with Stable Diffusion.

    The base model can optionally be combined with a Ghibli LoRA, a
    ControlNet and/or an IP-Adapter.

    Args:
        lora_model_id: One of the known ``osmr/...-ghibli`` LoRA repos, which
            is loaded on top of its base checkpoint, or any other model id,
            which is loaded as a plain base checkpoint. ``None`` or empty
            falls back to the default LoRA.
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt.
        seed: Generator seed. It is replaced by a random seed in
            ``[0, MAX_SEED]`` when ``randomize_seed`` is true or when it is
            ``None``.
        randomize_seed: Draw a fresh random seed for this call.
        width: Output width in pixels (cast to ``int``).
        height: Output height in pixels (cast to ``int``).
        guidance_scale: Guidance scale. ``None`` keeps the pipeline's own
            default.
        lora_scale: LoRA strength, passed as ``cross_attention_kwargs["scale"]``.
            ``None`` keeps the default.
        num_inference_steps: Number of denoising steps. ``None`` keeps the
            pipeline's own default.
        controlnet_type: "Edge-Detection" or "Pose-Estimation".
        controlnet_cond_scale: ControlNet conditioning scale.
        controlnet_image: Conditioning image. When given, a ControlNet
            pipeline is used and the image is passed to it as-is (no edge or
            pose extraction happens here).
        ipadapter_scale: IP-Adapter scale.
        ipadapter_image: Reference image. When given, the IP-Adapter weights
            are loaded.
        do_remove_bg: Strip the background of the result with ``rembg.remove``.
        progress: Gradio progress tracker (tracks tqdm progress bars).

    Returns:
        ``(image, seed)``: the generated image and the integer seed actually
        used.
    """
    model_id, lora_model_id = _resolve_models(lora_model_id)
    use_lora = lora_model_id is not None
    # Only the runwayml SD 1.5 base selects the v1.1 SD1.5 ControlNets; every
    # other base (including custom ids) uses the SD 1.4-era ControlNets.
    sd_version = "1.5" if model_id == "runwayml/stable-diffusion-v1-5" else "1.4"
    use_controlnet = controlnet_image is not None
    use_ipadapter = ipadapter_image is not None

    if randomize_seed or seed is None:
        seed = random.randint(0, MAX_SEED)
    # Gradio sliders may deliver floats; torch.Generator.manual_seed needs an int.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)

    # NOTE: the pipeline is rebuilt (and its weights reloaded) on every call.
    # LoRA and IP-Adapter weights are loaded into `pipe` in place, so caching
    # it across calls would require unloading them first.
    if use_controlnet:
        controlnet = ControlNetModel.from_pretrained(
            pretrained_model_name_or_path=_controlnet_id(sd_version, controlnet_type),
            torch_dtype=torch_dtype)
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            pretrained_model_name_or_path=model_id,
            controlnet=controlnet,
            torch_dtype=torch_dtype)
    else:
        pipe = StableDiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path=model_id,
            torch_dtype=torch_dtype)

    if use_ipadapter:
        pipe.load_ip_adapter(
            "h94/IP-Adapter",
            subfolder="models",
            weight_name="ip-adapter_sd15.bin")
        pipe.set_ip_adapter_scale(ipadapter_scale)

    cross_attention_kwargs = None
    if use_lora:
        pipe.load_lora_weights(lora_model_id)
        if lora_scale is not None:
            cross_attention_kwargs = {"scale": lora_scale}

    pipe = pipe.to(device)

    # Pass only arguments the selected pipeline understands. The plain
    # StableDiffusionPipeline has no `image` / `controlnet_conditioning_scale`
    # parameters.
    call_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "width": int(width),
        "height": int(height),
        "generator": generator,
        "cross_attention_kwargs": cross_attention_kwargs,
    }
    if guidance_scale is not None:
        call_kwargs["guidance_scale"] = guidance_scale
    if num_inference_steps is not None:
        call_kwargs["num_inference_steps"] = int(num_inference_steps)
    if use_controlnet:
        call_kwargs["image"] = controlnet_image
        call_kwargs["controlnet_conditioning_scale"] = float(controlnet_cond_scale)
    if use_ipadapter:
        call_kwargs["ip_adapter_image"] = ipadapter_image

    image = pipe(**call_kwargs).images[0]
    if do_remove_bg:
        image = remove(image)
    return image, seed
# Prompts offered under the prompt box via gr.Examples. Most start with
# "GBL, ", presumably the LoRA's trigger token (TODO confirm). The second
# entry repeats the first without it, apparently for comparison.
examples = [
    # Two people at a table, with and without the "GBL, " prefix.
    "GBL, a man and a woman sitting at a table with glasses of wine in front of them",
    "a man and a woman sitting at a table with glasses of wine in front of them",
    # Single-subject scenes.
    "GBL, a man sitting at a desk in a library with a book open in front of him",
    "GBL, a cartoon woman is standing in front of a wall",
]

# Page styling: center the main column and cap its width at 640px.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 640px;\n"
    "}\n"
)
# Gradio UI: a model picker, a prompt box with a Run button, the result image,
# collapsible settings panels, and example prompts. `gr.on` wires the Run
# button and prompt submission to `infer`.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Ghibli LoRA generation")

        with gr.Row():
            # Ghibli LoRA repos or plain SD base checkpoints. Custom Hub ids
            # may also be typed in (allow_custom_value=True).
            lora_model_id = gr.Dropdown(
                choices=[
                    "osmr/stable-diffusion-v1-4-lora-iv-ghibli",
                    "osmr/stable-diffusion-v1-4-lora-db-ghibli",
                    "osmr/stable-diffusion-v1-5-lora-iv-ghibli",
                    "osmr/stable-diffusion-v1-5-lora-db-ghibli",
                    "CompVis/stable-diffusion-v1-4",
                    "runwayml/stable-diffusion-v1-5"],
                # Explicit initial selection, so `infer` never receives an
                # empty model id from an unset dropdown.
                value="osmr/stable-diffusion-v1-4-lora-iv-ghibli",
                multiselect=False,
                allow_custom_value=True,
                label="Model",
            )

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                value="GBL, a man and a woman sitting at a table with glasses of wine in front of them",
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                value="low quality, deformed, ugly, bad art, poorly drawn, bad anatomy, low detail, unrealistic",
                placeholder="Enter a negative prompt",
                visible=True,
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=DEFAULT_SEED,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=DEFAULT_WIDTH,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=DEFAULT_HEIGHT,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=DEFAULT_GS,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=DEFAULT_NUM_INF_STEPS,
                )
            with gr.Row():
                lora_scale = gr.Slider(
                    label="LoRA scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=DEFAULT_LS,
                )
            do_remove_bg = gr.Checkbox(label="Remove background", value=False)

        # ControlNet is used only when a control image is provided.
        with gr.Accordion("ControlNet Settings", open=False):
            controlnet_type = gr.Dropdown(
                choices=[
                    "Edge-Detection",
                    "Pose-Estimation"],
                # Explicit initial selection matching `infer`'s default.
                value="Edge-Detection",
                interactive=True,
                label="ControlNet Type",
            )
            controlnet_cond_scale = gr.Slider(
                label="ControlNet Conditioning Scale",
                minimum=0.0,
                maximum=2.0,
                step=0.1,
                value=DEFAULT_CN_COND_SCALE
            )
            controlnet_image = gr.Image(
                label="Control Image",
                type="pil",
                show_label=True)

        # The IP-Adapter is used only when a reference image is provided.
        with gr.Accordion("IP-adapter Settings", open=False):
            ipadapter_scale = gr.Slider(
                label="IP-adapter Scale",
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=DEFAULT_IPA_SCALE
            )
            ipadapter_image = gr.Image(
                label="IP-adapter Image",
                type="pil",
                show_label=True)

        gr.Examples(examples=examples, inputs=[prompt])

    # The input order must match `infer`'s positional parameter order. `infer`
    # returns (image, seed_used), which update the result image and the seed slider.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            lora_model_id,
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            lora_scale,
            num_inference_steps,
            controlnet_type,
            controlnet_cond_scale,
            controlnet_image,
            ipadapter_scale,
            ipadapter_image,
            do_remove_bg,
        ],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch(share=True)