# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Copyright (c) VectorSpaceLab and its affiliates. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import dotenv
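# Load .env overrides before the remaining imports so credentials (e.g. Hugging
# Face Hub tokens) are already in os.environ when libraries read them.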
dotenv.load_dotenv(override=True)

import gradio as gr
import spaces

import argparse
import json
import random
from datetime import datetime
from glob import glob
from typing import Literal

import torch
from torchvision.transforms.functional import to_tensor

from accelerate import Accelerator
from huggingface_hub import hf_hub_download
from peft import LoraConfig
from safetensors.torch import load_file

from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
from omnigen2.models.transformers.transformer_omnigen2 import OmniGen2Transformer2DModel
from omnigen2.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from omnigen2.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from omnigen2.utils.img_util import create_collage

NEGATIVE_PROMPT = "(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar"
SAVE_DIR = "output/gradio"
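
# Module-level state shared between main() and run(); populated once in main().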
pipeline = None
accelerator = None
save_images = False

def load_pipeline(accelerator, weight_dtype, args):
    pipeline = OmniGen2Pipeline.from_pretrained(
        args.model_path,
        torch_dtype=weight_dtype,
        trust_remote_code=True,
    )
    pipeline.transformer = OmniGen2Transformer2DModel.from_pretrained(
        args.model_path,
        subfolder="transformer",
        torch_dtype=weight_dtype,
    )
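    # Fall back to the released UMO LoRA on the Hub unless a local checkpoint is given.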
    lora_path = hf_hub_download("bytedance-research/UMO", "UMO_OmniGen2.safetensors") if args.lora_path is None else args.lora_path
    target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
    lora_config = LoraConfig(
        r=512,
        lora_alpha=512,
        lora_dropout=0,
        init_lora_weights="gaussian",
        target_modules=target_modules,
    )
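    # Attach the adapter, load its weights, fuse them into the base transformer,
    # then drop the adapter wrapper so inference runs on the merged weights.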
    pipeline.transformer.add_adapter(lora_config)
    lora_state_dict = load_file(lora_path, device=str(accelerator.device))
    pipeline.transformer.load_state_dict(lora_state_dict, strict=False)
    pipeline.transformer.fuse_lora(lora_scale=1, safe_fusing=False, adapter_names=["default"])
    pipeline.transformer.unload_lora()

    if args.enable_sequential_cpu_offload:
        pipeline.enable_sequential_cpu_offload()
    elif args.enable_model_cpu_offload:
        pipeline.enable_model_cpu_offload()
    else:
        pipeline = pipeline.to(accelerator.device)

    return pipeline
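
# NOTE: on a ZeroGPU Space the GPU-using entry point is expected to carry the
# @spaces.GPU decorator; it is assumed here, since `spaces` is imported for
# exactly this purpose but was otherwise unused.
@spaces.GPU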
def run(
    instruction,
    width_input,
    height_input,
    image_input_1,
    image_input_2,
    image_input_3,
    scheduler: Literal["euler", "dpmsolver++"] = "euler",
    num_inference_steps: int = 50,
    negative_prompt: str = NEGATIVE_PROMPT,
    guidance_scale_input: float = 5.0,
    img_guidance_scale_input: float = 2.0,
    cfg_range_start: float = 0.0,
    cfg_range_end: float = 1.0,
    num_images_per_prompt: int = 1,
    max_input_image_side_length: int = 2048,
    max_pixels: int = 1024 * 1024,
    seed_input: int = -1,
    align_res: bool = True,
    enable_taylorseer: bool = True,
    enable_teacache: bool = False,
):
    # Speed-up toggles arrive as plain booleans from the UI checkboxes.
    if enable_taylorseer:
        pipeline.enable_taylorseer = True
    elif enable_teacache:
        pipeline.transformer.enable_teacache = True
        pipeline.transformer.teacache_rel_l1_thresh = 0.05

    input_images = [image_input_1, image_input_2, image_input_3]
    input_images = [img for img in input_images if img is not None]
    if len(input_images) == 0:
        input_images = None

    if seed_input == -1:
        seed_input = random.randint(0, 2**16 - 1)
    # Seed a CPU generator so the same seed reproduces the same result on any GPU.
    generator = torch.Generator(device="cpu").manual_seed(seed_input)
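
    # Swap the scheduler in place only when the UI selection differs from the current one.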
    if scheduler == 'euler' and not isinstance(pipeline.scheduler, FlowMatchEulerDiscreteScheduler):
        pipeline.scheduler = FlowMatchEulerDiscreteScheduler()
    elif scheduler == 'dpmsolver++' and not isinstance(pipeline.scheduler, DPMSolverMultistepScheduler):
        pipeline.scheduler = DPMSolverMultistepScheduler(
            algorithm_type="dpmsolver++",
            solver_type="midpoint",
            solver_order=2,
            prediction_type="flow_prediction",
        )

    results = pipeline(
        prompt=instruction,
        input_images=input_images,
        width=width_input,
        height=height_input,
        align_res=align_res,
        max_input_image_side_length=max_input_image_side_length,
        max_pixels=max_pixels,
        num_inference_steps=num_inference_steps,
        max_sequence_length=1024,
        text_guidance_scale=guidance_scale_input,
        image_guidance_scale=img_guidance_scale_input,
        cfg_range=(cfg_range_start, cfg_range_end),
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_images_per_prompt,
        generator=generator,
        output_type="pil",
    )
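
    # Map the PIL outputs to [-1, 1] tensors, the range the collage helper works in.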
    vis_images = [to_tensor(image) * 2 - 1 for image in results.images]
    output_image = create_collage(vis_images)

    output_path = ""
    if save_images:
        # Create the output directory if it doesn't exist
        output_dir = SAVE_DIR
        os.makedirs(output_dir, exist_ok=True)

        # Build a unique filename from the timestamp, seed, and prompt prefix
        timestamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
        output_path = os.path.join(output_dir, f"{timestamp}_seed{seed_input}_{instruction[:20]}.png")

        # Save the collage
        output_image.save(output_path)

        # Also save each generated image individually
        if len(results.images) > 1:
            for i, image in enumerate(results.images):
                image_name, ext = os.path.splitext(output_path)
                image.save(f"{image_name}_{i}{ext}")

    return output_image, output_path
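
# Each example lives in assets/examples/OmniGen2/<name>/config.json; values are
# pulled out by key, and missing keys fall back to None.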
def get_examples(base_dir="assets/examples/OmniGen2"):
    example_keys = ["instruction", "width_input", "height_input", "image_input_1", "image_input_2", "image_input_3", "seed_input", "align_res", "output_image", "output_image_OmniGen2"]
    examples = []
    example_configs = glob(os.path.join(base_dir, "*", "config.json"))
    for config_path in example_configs:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        _example = [config.get(k, None) for k in example_keys]
        examples.append(_example)
    return examples

with open("assets/logo.svg", "r", encoding="utf-8") as svg_file:
    logo_content = svg_file.read()

title = f"""
<div style="display: flex; align-items: center; justify-content: center;">
    <span style="transform: scale(0.7);margin-right: -5px;">{logo_content}</span>
    <span style="font-size: 1.8em;margin-left: -10px;font-weight: bold; font-family: Gill Sans;">UMO (based on OmniGen2) by UXO Team</span>
</div>
""".strip()

badges_text = r"""
<div style="text-align: center; display: flex; justify-content: center; gap: 5px;">
<a href="https://github.com/bytedance/UMO"><img alt="Build" src="https://img.shields.io/github/stars/bytedance/UMO"></a>
<a href="https://bytedance.github.io/UMO/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-UMO-blue"></a>
<a href="https://huggingface.co/bytedance-research/UMO"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Model&color=green"></a>
<a href="https://arxiv.org/abs/2509.06818"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-UMO-b31b1b.svg"></a>
<a href="https://huggingface.co/spaces/bytedance-research/UMO_UNO"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Demo&message=UMO-UNO&color=orange"></a>
<a href="https://huggingface.co/spaces/bytedance-research/UMO_OmniGen2"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Demo&message=UMO-OmniGen2&color=orange"></a>
</div>
""".strip()

tips = """
📌 ***UMO*** is a **U**nified **M**ulti-identity **O**ptimization framework that *boosts multi-ID fidelity and mitigates identity confusion* for image customization models, and the latest addition to the UXO family (<a href='https://github.com/bytedance/UMO' target='_blank'> UMO</a>, <a href='https://github.com/bytedance/USO' target='_blank'> USO</a> and <a href='https://github.com/bytedance/UNO' target='_blank'> UNO</a>).

🎨 The UMO in this demo is trained on top of <a href='https://github.com/VectorSpaceLab/OmniGen2' target='_blank'> OmniGen2</a>.

💡 We provide step-by-step instructions in our <a href='https://github.com/bytedance/UMO' target='_blank'> GitHub repo</a>. Additionally, try the examples and comparisons provided below the demo to quickly get familiar with UMO and spark your creativity!

<details>
<summary style="cursor: pointer; color: #d34c0e; font-weight: 500;"> ⚡️ Tips from the underlying OmniGen2</summary>

- Image quality: use high-resolution images (**at least 512x512 recommended**).
- Be specific: instead of "Add bird to desk", try "Add the bird from image 1 to the desk in image 2".
- Use English: English prompts currently yield better results.
- Increase image_guidance_scale for better consistency with the reference image:
    - Image editing: 1.3 - 2.0
    - In-context generation: 2.0 - 3.0
- For in-context editing (editing based on multiple images), we recommend the following prompt format: "Edit the first image: add/replace (the [object] with) the [object] from the second image. [description of your target image]."
    - For example: "Edit the first image: add the man from the second image. The man is talking with a woman in the kitchen."

</details>
""".strip()

article = """
```bibtex
@misc{cheng2025umoscalingmultiidentityconsistency,
      title={UMO: Scaling Multi-Identity Consistency for Image Customization via Matching Reward},
      author={Yufeng Cheng and Wenxu Wu and Shaojin Wu and Mengqi Huang and Fei Ding and Qian He},
      year={2025},
      eprint={2509.06818},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2509.06818},
}
```
""".strip()

star = f"""
If UMO is helpful, please help to ⭐ our <a href='https://github.com/bytedance/UMO' target='_blank'> GitHub repo</a> or cite our paper. Thanks a lot!
{article}
"""

def main(args):
    # Gradio UI
    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(badges_text)
        gr.Markdown(tips)
        with gr.Row():
            with gr.Column():
                # text prompt
                instruction = gr.Textbox(
                    label='Enter your prompt',
                    info='Use "first/second image" (or the Chinese "第一张图/第二张图") to refer to the reference images.',
                    placeholder="Type your prompt here...",
                )
                with gr.Row(equal_height=True):
                    # input images
                    image_input_1 = gr.Image(label="First Image", type="pil")
                    image_input_2 = gr.Image(label="Second Image", type="pil")
                    image_input_3 = gr.Image(label="Third Image", type="pil")
                generate_button = gr.Button("Generate Image")
                negative_prompt = gr.Textbox(
                    label="Enter your negative prompt",
                    placeholder="Type your negative prompt here...",
                    value=NEGATIVE_PROMPT,
                )

                # sliders
                with gr.Row(equal_height=True):
                    height_input = gr.Slider(
                        label="Height", minimum=256, maximum=2048, value=1024, step=128
                    )
                    width_input = gr.Slider(
                        label="Width", minimum=256, maximum=2048, value=1024, step=128
                    )
| with gr.Accordion("Speed Up Options", open=True): | |
| with gr.Row(equal_height=True): | |
| global enable_taylorseer | |
| global enable_teacache | |
| enable_taylorseer = gr.Checkbox(label="Using TaylorSeer to speed up", value=True) | |
| enable_teacache = gr.Checkbox(label="Using TeaCache to speed up", value=False) | |
                    with gr.Row(equal_height=True):
                        scheduler_input = gr.Dropdown(
                            label="Scheduler",
                            choices=["euler", "dpmsolver++"],
                            value="euler",
                            info="The scheduler to use for the model.",
                        )
                        num_inference_steps = gr.Slider(
                            label="Inference Steps", minimum=20, maximum=100, value=50, step=1
                        )
| with gr.Accordion("Advanced Options", open=False): | |
| with gr.Row(equal_height=True): | |
| align_res = gr.Checkbox( | |
| label="Align Resolution", | |
| info="Align output's resolution with the first reference image. Only valid when there is only one reference image.", | |
| value=True | |
| ) | |
                    with gr.Row(equal_height=True):
                        text_guidance_scale_input = gr.Slider(
                            label="Text Guidance Scale",
                            minimum=1.0,
                            maximum=8.0,
                            value=5.0,
                            step=0.1,
                        )
                        image_guidance_scale_input = gr.Slider(
                            label="Image Guidance Scale",
                            minimum=1.0,
                            maximum=3.0,
                            value=2.0,
                            step=0.1,
                        )
                    with gr.Row(equal_height=True):
                        cfg_range_start = gr.Slider(
                            label="CFG Range Start",
                            minimum=0.0,
                            maximum=1.0,
                            value=0.0,
                            step=0.1,
                        )
                        cfg_range_end = gr.Slider(
                            label="CFG Range End",
                            minimum=0.0,
                            maximum=1.0,
                            value=1.0,
                            step=0.1,
                        )
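
                    # Keep the CFG window valid: moving either slider clamps the other
                    # so that start <= end always holds.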
                    def adjust_end_slider(start_val, end_val):
                        return max(start_val, end_val)

                    def adjust_start_slider(end_val, start_val):
                        return min(end_val, start_val)

                    cfg_range_start.input(
                        fn=adjust_end_slider,
                        inputs=[cfg_range_start, cfg_range_end],
                        outputs=[cfg_range_end],
                    )
                    cfg_range_end.input(
                        fn=adjust_start_slider,
                        inputs=[cfg_range_end, cfg_range_start],
                        outputs=[cfg_range_start],
                    )
                    with gr.Row(equal_height=True):
                        num_images_per_prompt = gr.Slider(
                            label="Number of images per prompt",
                            minimum=1,
                            maximum=4,
                            value=1,
                            step=1,
                        )
                        seed_input = gr.Slider(
                            label="Seed", minimum=-1, maximum=2147483647, value=-1, step=1
                        )
                    with gr.Row(equal_height=True):
                        max_input_image_side_length = gr.Slider(
                            label="max_input_image_side_length",
                            minimum=256,
                            maximum=2048,
                            value=2048,
                            step=256,
                        )
                        max_pixels = gr.Slider(
                            label="max_pixels",
                            minimum=256 * 256,
                            maximum=1536 * 1536,
                            value=1024 * 1024,
                            step=256 * 256,
                        )
            with gr.Column():
                # output image
                output_image = gr.Image(label="Output Image")

                global save_images
                # save_images = gr.Checkbox(label="Save generated images", value=True)
                save_images = True

                with gr.Accordion("Examples Comparison with OmniGen2", open=False):
                    output_image_omnigen2 = gr.Image(label="Generated Image (OmniGen2)")
                download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)

        gr.Markdown(star)

        global accelerator
        global pipeline
        bf16 = True
        accelerator = Accelerator(mixed_precision="bf16" if bf16 else "no")
        weight_dtype = torch.bfloat16 if bf16 else torch.float32
        pipeline = load_pipeline(accelerator, weight_dtype, args)
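
        # Wire the button to run(); inputs must stay in the same order as run()'s
        # parameters, including the two speed-up checkboxes at the end.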
        generate_button.click(
            run,
            inputs=[
                instruction,
                width_input,
                height_input,
                image_input_1,
                image_input_2,
                image_input_3,
                scheduler_input,
                num_inference_steps,
                negative_prompt,
                text_guidance_scale_input,
                image_guidance_scale_input,
                cfg_range_start,
                cfg_range_end,
                num_images_per_prompt,
                max_input_image_side_length,
                max_pixels,
                seed_input,
                align_res,
                enable_taylorseer,
                enable_teacache,
            ],
            outputs=[output_image, download_btn],
        )

        gr.Examples(
            examples=get_examples("assets/examples/OmniGen2"),
            inputs=[
                instruction,
                width_input,
                height_input,
                image_input_1,
                image_input_2,
                image_input_3,
                seed_input,
                align_res,
                output_image,
                output_image_omnigen2,
            ],
            label="We provide examples for academic research. The vast majority of images used in this demo are either generated or from open-source datasets. If you have any concerns, please contact us, and we will promptly remove any inappropriate content.",
            examples_per_page=15,
        )

    # launch
    demo.launch(share=args.share, server_port=args.port, server_name=args.server_name, ssr_mode=False)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true", help="Share the Gradio app")
    parser.add_argument(
        "--port", type=int, default=7860, help="Port to use for the Gradio app"
    )
    parser.add_argument(
        "--server_name", type=str, default=None, help="Server name/IP to bind the Gradio app to"
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="OmniGen2/OmniGen2",
        help="Path or HuggingFace name of the model to load."
    )
    parser.add_argument(
        "--enable_model_cpu_offload",
        action="store_true",
        help="Enable model CPU offload."
    )
    parser.add_argument(
        "--enable_sequential_cpu_offload",
        action="store_true",
        help="Enable sequential CPU offload."
    )
    parser.add_argument(
        "--lora_path",
        type=str,
        default=None,
        help="Path to the LoRA checkpoint to load."
    )
    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = parse_args()
    main(args)
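
# Example invocations (assumed; flags are defined in parse_args above):
#   python app.py                              # defaults: OmniGen2/OmniGen2 base + UMO LoRA from the Hub
#   python app.py --share --port 7861          # public Gradio link on a custom port
#   python app.py --enable_model_cpu_offload   # trade speed for lower GPU memory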