# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Copyright (c) VectorSpaceLab and its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import dotenv

dotenv.load_dotenv(override=True)

import gradio as gr
import spaces

import argparse
import json
import random
from datetime import datetime
from glob import glob
from typing import Literal

import torch
from torchvision.transforms.functional import to_pil_image, to_tensor

from accelerate import Accelerator
from huggingface_hub import hf_hub_download
from peft import LoraConfig
from safetensors.torch import load_file

from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
from omnigen2.models.transformers.transformer_omnigen2 import OmniGen2Transformer2DModel
from omnigen2.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from omnigen2.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from omnigen2.utils.img_util import create_collage

NEGATIVE_PROMPT = "(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar"

SAVE_DIR = "output/gradio"

pipeline = None
accelerator = None
save_images = False

enable_taylorseer = False
enable_teacache = False


def load_pipeline(accelerator, weight_dtype, args):
    pipeline = OmniGen2Pipeline.from_pretrained(
        args.model_path,
        torch_dtype=weight_dtype,
        trust_remote_code=True,
    )
    pipeline.transformer = OmniGen2Transformer2DModel.from_pretrained(
        args.model_path,
        subfolder="transformer",
        torch_dtype=weight_dtype,
    )

    # Download the UMO LoRA checkpoint from the Hub unless a local path is given.
    lora_path = (
        hf_hub_download("bytedance-research/UMO", "UMO_OmniGen2.safetensors")
        if args.lora_path is None
        else args.lora_path
    )
    target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
    lora_config = LoraConfig(
        r=512,
        lora_alpha=512,
        lora_dropout=0,
        init_lora_weights="gaussian",
        target_modules=target_modules,
    )
    pipeline.transformer.add_adapter(lora_config)
    lora_state_dict = load_file(lora_path, device=str(accelerator.device))
    pipeline.transformer.load_state_dict(lora_state_dict, strict=False)
    pipeline.transformer.fuse_lora(lora_scale=1, safe_fusing=False, adapter_names=["default"])
    pipeline.transformer.unload_lora()

    if args.enable_sequential_cpu_offload:
        pipeline.enable_sequential_cpu_offload()
    elif args.enable_model_cpu_offload:
        pipeline.enable_model_cpu_offload()
    else:
        pipeline = pipeline.to(accelerator.device)

    return pipeline
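
# Note on the LoRA handling in `load_pipeline` above: the UMO weights are attached as a
# PEFT adapter, loaded from the safetensors checkpoint, fused into the base transformer
# weights (`fuse_lora`) and then removed as separate modules (`unload_lora`), so
# inference runs on plain fused weights with no per-step LoRA overhead.
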
@spaces.GPU(duration=120)
def run(
    instruction,
    width_input,
    height_input,
    image_input_1,
    image_input_2,
    image_input_3,
    scheduler: Literal["euler", "dpmsolver++"] = "euler",
    num_inference_steps: int = 50,
    negative_prompt: str = NEGATIVE_PROMPT,
    guidance_scale_input: float = 5.0,
    img_guidance_scale_input: float = 2.0,
    cfg_range_start: float = 0.0,
    cfg_range_end: float = 1.0,
    num_images_per_prompt: int = 1,
    max_input_image_side_length: int = 2048,
    max_pixels: int = 1024 * 1024,
    seed_input: int = -1,
    align_res: bool = True,
):
    if enable_taylorseer:
        pipeline.enable_taylorseer = True
    elif enable_teacache:
        pipeline.transformer.enable_teacache = True
        pipeline.transformer.teacache_rel_l1_thresh = 0.05

    input_images = [image_input_1, image_input_2, image_input_3]
    input_images = [img for img in input_images if img is not None]

    if len(input_images) == 0:
        input_images = None

    if seed_input == -1:
        seed_input = random.randint(0, 2**16 - 1)

    # Seed the generator on CPU so results are reproducible across different GPUs.
    generator = torch.Generator(device="cpu").manual_seed(seed_input)

    if scheduler == "euler" and not isinstance(pipeline.scheduler, FlowMatchEulerDiscreteScheduler):
        pipeline.scheduler = FlowMatchEulerDiscreteScheduler()
    elif scheduler == "dpmsolver++" and not isinstance(pipeline.scheduler, DPMSolverMultistepScheduler):
        pipeline.scheduler = DPMSolverMultistepScheduler(
            algorithm_type="dpmsolver++",
            solver_type="midpoint",
            solver_order=2,
            prediction_type="flow_prediction",
        )

    results = pipeline(
        prompt=instruction,
        input_images=input_images,
        width=width_input,
        height=height_input,
        align_res=align_res,
        max_input_image_side_length=max_input_image_side_length,
        max_pixels=max_pixels,
        num_inference_steps=num_inference_steps,
        max_sequence_length=1024,
        text_guidance_scale=guidance_scale_input,
        image_guidance_scale=img_guidance_scale_input,
        cfg_range=(cfg_range_start, cfg_range_end),
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_images_per_prompt,
        generator=generator,
        output_type="pil",
    )

    vis_images = [to_tensor(image) * 2 - 1 for image in results.images]
    output_image = create_collage(vis_images)

    output_path = ""
    if save_images:
        # Create the output directory if it doesn't exist.
        output_dir = SAVE_DIR
        os.makedirs(output_dir, exist_ok=True)

        # Generate a unique filename with a timestamp.
        timestamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
        output_path = os.path.join(output_dir, f"{timestamp}_seed{seed_input}_{instruction[:20]}.png")

        # Save the collage.
        output_image.save(output_path)

        # Save all generated images individually.
        if len(results.images) > 1:
            for i, image in enumerate(results.images):
                image_name, ext = os.path.splitext(output_path)
                image.save(f"{image_name}_{i}{ext}")

    return output_image, output_path


def get_examples(base_dir="assets/examples/OmniGen2"):
    example_keys = [
        "instruction",
        "width_input",
        "height_input",
        "image_input_1",
        "image_input_2",
        "image_input_3",
        "seed_input",
        "align_res",
        "output_image",
        "output_image_OmniGen2",
    ]
    examples = []
    example_configs = glob(os.path.join(base_dir, "*", "config.json"))
    for config_path in example_configs:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        _example = [config.get(k, None) for k in example_keys]
        examples.append(_example)
    return examples
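
# For reference, `get_examples` expects each example directory under `base_dir` to hold a
# config.json whose keys mirror `example_keys`; missing keys simply default to None.
# A hypothetical entry (illustrative values only) might look like:
# {
#     "instruction": "Put the man from the first image next to the woman from the second image.",
#     "width_input": 1024,
#     "height_input": 1024,
#     "image_input_1": "man.png",
#     "image_input_2": "woman.png",
#     "seed_input": 0,
#     "align_res": false
# }
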
with open("assets/logo.svg", "r", encoding="utf-8") as svg_file:
    logo_content = svg_file.read()

title = f"""
{logo_content}

UMO (based on OmniGen2) by UXO Team
""".strip() badges_text = r"""
<!-- project badges -->
""".strip() tips = """ 📌 ***UMO*** is a **U**nified **M**ulti-identity **O**ptimization framework to *boost the multi-ID fidelity and mitigate confusion* for image customization model, and the latest addition to the UXO family ( UMO, USO and UNO). 🎨 UMO in the demo is trained based on OmniGen2. 💡 We provide step-by-step instructions in our Github Repo. Additionally, try the examples and comparison provided below the demo to quickly get familiar with UMO and spark your creativity!

⚡️ Tips from the base model, OmniGen2:

- Image Quality: Use high-resolution images (**at least 512x512 recommended**).
- Be Specific: Instead of "Add bird to desk", try "Add the bird from image 1 to the desk in image 2".
- Use English: English prompts currently yield better results.
- Increase `image_guidance_scale` for better consistency with the reference image:
  - Image Editing: 1.3 - 2.0
  - In-context Generation: 2.0 - 3.0
- For in-context editing (editing based on multiple images), we recommend the following prompt format: "Edit the first image: add/replace (the [object] with) the [object] from the second image. [description of your target image]."
  - For example: "Edit the first image: add the man from the second image. The man is talking with a woman in the kitchen."
""".strip()

article = """
```bibtex
@misc{cheng2025umoscalingmultiidentityconsistency,
    title={UMO: Scaling Multi-Identity Consistency for Image Customization via Matching Reward},
    author={Yufeng Cheng and Wenxu Wu and Shaojin Wu and Mengqi Huang and Fei Ding and Qian He},
    year={2025},
    eprint={2509.06818},
    archivePrefix={arXiv},
    primaryClass={cs.CV},
    url={https://arxiv.org/abs/2509.06818},
}
```
""".strip()

star = f"""
If UMO is helpful, please help to ⭐ our GitHub repo or cite our paper. Thanks a lot!
{article}
"""


def main(args):
    # Build the Gradio UI.
    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(badges_text)
        gr.Markdown(tips)
        with gr.Row():
            with gr.Column():
                # Text prompt.
                instruction = gr.Textbox(
                    label="Enter your prompt",
                    info='Use "first/second image" or “第一张图/第二张图” as reference.',
                    placeholder="Type your prompt here...",
                )

                with gr.Row(equal_height=True):
                    # Reference images (up to three).
                    image_input_1 = gr.Image(label="First Image", type="pil")
                    image_input_2 = gr.Image(label="Second Image", type="pil")
                    image_input_3 = gr.Image(label="Third Image", type="pil")

                generate_button = gr.Button("Generate Image")

                negative_prompt = gr.Textbox(
                    label="Enter your negative prompt",
                    placeholder="Type your negative prompt here...",
                    value=NEGATIVE_PROMPT,
                )

                # Output resolution sliders.
                with gr.Row(equal_height=True):
                    height_input = gr.Slider(
                        label="Height", minimum=256, maximum=2048, value=1024, step=128
                    )
                    width_input = gr.Slider(
                        label="Width", minimum=256, maximum=2048, value=1024, step=128
                    )

                with gr.Accordion("Speed Up Options", open=True):
                    with gr.Row(equal_height=True):
                        global enable_taylorseer
                        global enable_teacache
                        enable_taylorseer = gr.Checkbox(label="Use TaylorSeer to speed up", value=True)
                        enable_teacache = gr.Checkbox(label="Use TeaCache to speed up", value=False)

                    with gr.Row(equal_height=True):
                        scheduler_input = gr.Dropdown(
                            label="Scheduler",
                            choices=["euler", "dpmsolver++"],
                            value="euler",
                            info="The scheduler to use for the model.",
                        )

                        num_inference_steps = gr.Slider(
                            label="Inference Steps", minimum=20, maximum=100, value=50, step=1
                        )

                with gr.Accordion("Advanced Options", open=False):
                    with gr.Row(equal_height=True):
                        align_res = gr.Checkbox(
                            label="Align Resolution",
                            info="Align the output resolution with the first reference image. Only valid when exactly one reference image is provided.",
                            value=True,
                        )
                    with gr.Row(equal_height=True):
                        text_guidance_scale_input = gr.Slider(
                            label="Text Guidance Scale",
                            minimum=1.0,
                            maximum=8.0,
                            value=5.0,
                            step=0.1,
                        )

                        image_guidance_scale_input = gr.Slider(
                            label="Image Guidance Scale",
                            minimum=1.0,
                            maximum=3.0,
                            value=2.0,
                            step=0.1,
                        )

                    with gr.Row(equal_height=True):
                        cfg_range_start = gr.Slider(
                            label="CFG Range Start",
                            minimum=0.0,
                            maximum=1.0,
                            value=0.0,
                            step=0.1,
                        )

                        cfg_range_end = gr.Slider(
                            label="CFG Range End",
                            minimum=0.0,
                            maximum=1.0,
                            value=1.0,
                            step=0.1,
                        )

                    # Keep the CFG range sliders consistent (start <= end).
                    def adjust_end_slider(start_val, end_val):
                        return max(start_val, end_val)

                    def adjust_start_slider(end_val, start_val):
                        return min(end_val, start_val)

                    cfg_range_start.input(
                        fn=adjust_end_slider,
                        inputs=[cfg_range_start, cfg_range_end],
                        outputs=[cfg_range_end],
                    )

                    cfg_range_end.input(
                        fn=adjust_start_slider,
                        inputs=[cfg_range_end, cfg_range_start],
                        outputs=[cfg_range_start],
                    )

                    with gr.Row(equal_height=True):
                        num_images_per_prompt = gr.Slider(
                            label="Number of images per prompt",
                            minimum=1,
                            maximum=4,
                            value=1,
                            step=1,
                        )

                        seed_input = gr.Slider(
                            label="Seed", minimum=-1, maximum=2147483647, value=-1, step=1
                        )

                    with gr.Row(equal_height=True):
                        max_input_image_side_length = gr.Slider(
                            label="max_input_image_side_length",
                            minimum=256,
                            maximum=2048,
                            value=2048,
                            step=256,
                        )
                        max_pixels = gr.Slider(
                            label="max_pixels",
                            minimum=256 * 256,
                            maximum=1536 * 1536,
                            value=1024 * 1024,
                            step=256 * 256,
                        )

            with gr.Column():
                with gr.Column():
                    # Output image.
                    output_image = gr.Image(label="Output Image")
                    global save_images
                    # save_images = gr.Checkbox(label="Save generated images", value=True)
                    save_images = True
                    with gr.Accordion("Examples Comparison with OmniGen2", open=False):
                        output_image_omnigen2 = gr.Image(label="Generated Image (OmniGen2)")
                    download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)
                gr.Markdown(star)

        global accelerator
        global pipeline

        bf16 = True
        accelerator = Accelerator(mixed_precision="bf16" if bf16 else "no")
        weight_dtype = torch.bfloat16 if bf16 else torch.float32

        pipeline = load_pipeline(accelerator, weight_dtype, args)

        # Wire the generate button to the inference function.
        generate_button.click(
            run,
            inputs=[
                instruction,
                width_input,
                height_input,
                image_input_1,
                image_input_2,
                image_input_3,
                scheduler_input,
                num_inference_steps,
                negative_prompt,
                text_guidance_scale_input,
                image_guidance_scale_input,
                cfg_range_start,
                cfg_range_end,
                num_images_per_prompt,
                max_input_image_side_length,
                max_pixels,
                seed_input,
                align_res,
            ],
            outputs=[output_image, download_btn],
        )

        gr.Examples(
            examples=get_examples("assets/examples/OmniGen2"),
            inputs=[
                instruction,
                width_input,
                height_input,
                image_input_1,
                image_input_2,
                image_input_3,
                seed_input,
                align_res,
                output_image,
                output_image_omnigen2,
            ],
            label="We provide examples for academic research. The vast majority of images used in this demo are either generated or from open-source datasets. If you have any concerns, please contact us, and we will promptly remove any inappropriate content.",
            examples_per_page=15,
        )
    # Launch the demo.
    demo.launch(share=args.share, server_port=args.port, server_name=args.server_name, ssr_mode=False)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true", help="Share the Gradio app")
    parser.add_argument(
        "--port", type=int, default=7860, help="Port to use for the Gradio app"
    )
    parser.add_argument(
        "--server_name", type=str, default=None, help="Server name (host) for the Gradio app"
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="OmniGen2/OmniGen2",
        help="Path or HuggingFace name of the model to load."
    )
    parser.add_argument(
        "--enable_model_cpu_offload", action="store_true", help="Enable model CPU offload."
    )
    parser.add_argument(
        "--enable_sequential_cpu_offload",
        action="store_true",
        help="Enable sequential CPU offload."
    )
    parser.add_argument(
        "--lora_path",
        type=str,
        default=None,
        help="Path to the LoRA checkpoint to load."
    )
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    main(args)
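
# Example invocations (the script filename "app.py" is assumed):
#   python app.py                                    # local demo on the default port 7860
#   python app.py --share                            # also create a public Gradio link
#   python app.py --enable_model_cpu_offload         # lower GPU memory use at some speed cost
#   python app.py --lora_path /path/to/UMO_OmniGen2.safetensors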