Spaces:
Sleeping
Sleeping
| # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import dataclasses | |
| import json | |
| import os | |
| from pathlib import Path | |
| import gradio as gr | |
| import torch | |
| import spaces | |
| from uso.flux.pipeline import USOPipeline | |
| from transformers import SiglipVisionModel, SiglipImageProcessor | |
| with open("assets/uso_text.svg", "r", encoding="utf-8") as svg_file: | |
| text_content = svg_file.read() | |
| with open("assets/uso_logo.svg", "r", encoding="utf-8") as svg_file: | |
| logo_content = svg_file.read() | |
| title = f""" | |
| <div style="display: flex; align-items: center; justify-content: center;"> | |
| <span style="transform: scale(0.7);margin-right: -5px;">{text_content}</span> | |
| <span style="font-size: 1.8em;margin-left: -10px;font-weight: bold; font-family: Gill Sans;">바이트댄스 USO</span> | |
| <span style="margin-left: 0px; transform: scale(0.85); display: inline-block;">{logo_content}</span> | |
| </div> | |
| """.strip() | |
| badges_text = r""" | |
| <div style="text-align: center; display: flex; justify-content: center; gap: 5px;"> | |
| <a href="https://github.com/bytedance/USO"><img src="https://img.shields.io/static/v1?label=GitHub&message=Code&color=green&logo=github"></a> | |
| <a href="https://bytedance.github.io/USO/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-USO-yellow"></a> | |
| <a href="https://arxiv.org/abs/2504.02160"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-USO-b31b1b.svg"></a> | |
| <a href="https://huggingface.co/bytedance-research/USO"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Model&color=orange"></a> | |
| </div> | |
| """.strip() | |
| tips = """ | |
| **USO란 무엇인가요?** 🎨 | |
| USO는 통합 스타일-주제 최적화 커스터마이징 모델이며, UXO 패밀리(<a href='https://github.com/bytedance/USO' target='_blank'> USO</a>와 <a href='https://github.com/bytedance/UNO' target='_blank'> UNO</a>)의 최신 버전입니다. | |
| 모든 주제를 모든 스타일과 모든 시나리오에서 자유롭게 결합할 수 있습니다. | |
| **사용 방법** 💡 | |
| <a href='https://github.com/bytedance/USO' target='_blank'> Github 레포지토리</a>에서 단계별 지침을 제공합니다. | |
| 또한 아래 데모에서 제공되는 예제들을 시도해보시고, USO에 빠르게 익숙해지고 창의력을 발휘해보세요! | |
| <details> | |
| <summary style="cursor: pointer; color: #d34c0e; font-weight: 500;">이 모델은 1024x1024 해상도로 학습되었으며 3가지 유형의 사용을 지원합니다. 📌 팁:</summary> | |
| * **콘텐츠 이미지만**: 다음 유형을 지원합니다: | |
| * 주제/아이덴티티 기반 (자연스러운 프롬프트 지원, 예: *탁자 위의 시계.* *바다 근처의 여성.*, **사실적인 초상화** 생성에 뛰어남) | |
| * 스타일 편집 (레이아웃 보존): *이미지를 지브리 스타일/픽셀 스타일/레트로 만화 스타일/수채화 스타일로 변환...*. | |
| * 스타일 편집 (레이아웃 변경): *지브리 스타일, 해변의 남자.*. | |
| * **스타일 이미지만**: 입력 스타일을 참조하여 프롬프트에 따라 무엇이든 생성합니다. 이 분야에서 뛰어나며 다중 스타일 참조도 지원합니다 (베타). | |
| * **콘텐츠 이미지 + 스타일 이미지**: 콘텐츠를 원하는 스타일로 배치합니다. | |
| * 레이아웃 보존: 프롬프트를 **비워두세요**. | |
| * 레이아웃 변경: 자연스러운 프롬프트를 사용하세요.</details>""" | |
| star = r""" | |
| USO가 도움이 되셨다면, <a href='https://github.com/bytedance/USO' target='_blank'> Github 레포지토리</a>에 ⭐를 눌러주세요. 감사합니다!""" | |
| def get_examples(examples_dir: str = "assets/examples") -> list: | |
| examples = Path(examples_dir) | |
| ans = [] | |
| for example in examples.iterdir(): | |
| if not example.is_dir() or len(os.listdir(example)) == 0: | |
| continue | |
| with open(example / "config.json") as f: | |
| example_dict = json.load(f) | |
| example_list = [] | |
| # example_list.append(example_dict["usage"]) # case for | |
| example_list.append(example_dict["prompt"]) # prompt | |
| for key in ["image_ref1", "image_ref2", "image_ref3"]: | |
| if key in example_dict: | |
| example_list.append(str(example / example_dict[key])) | |
| else: | |
| example_list.append(None) | |
| example_list.append(example_dict["seed"]) | |
| ans.append(example_list) | |
| return ans | |
| def create_demo( | |
| model_type: str, | |
| device: str = "cuda" if torch.cuda.is_available() else "cpu", | |
| offload: bool = False, | |
| ): | |
| pipeline = USOPipeline( | |
| model_type, device, offload, only_lora=True, lora_rank=128, hf_download=True | |
| ) | |
| print("USOPipeline loaded successfully") | |
| siglip_processor = SiglipImageProcessor.from_pretrained( | |
| "google/siglip-so400m-patch14-384" | |
| ) | |
| siglip_model = SiglipVisionModel.from_pretrained( | |
| "google/siglip-so400m-patch14-384" | |
| ) | |
| siglip_model.eval() | |
| siglip_model.to(device) | |
| pipeline.model.vision_encoder = siglip_model | |
| pipeline.model.vision_encoder_processor = siglip_processor | |
| print("SigLIP model loaded successfully") | |
| pipeline.gradio_generate = spaces.GPU(duration=120)(pipeline.gradio_generate) | |
| with gr.Blocks() as demo: | |
| gr.Markdown(title) | |
| gr.Markdown(badges_text) | |
| gr.Markdown(tips) | |
| with gr.Row(): | |
| with gr.Column(): | |
| prompt = gr.Textbox(label="프롬프트", value="아름다운 여성.") | |
| with gr.Row(): | |
| image_prompt1 = gr.Image( | |
| label="콘텐츠 참조 이미지", visible=True, interactive=True, type="pil" | |
| ) | |
| image_prompt2 = gr.Image( | |
| label="스타일 참조 이미지", visible=True, interactive=True, type="pil" | |
| ) | |
| image_prompt3 = gr.Image( | |
| label="추가 스타일 참조 이미지 (베타)", visible=True, interactive=True, type="pil" | |
| ) | |
| with gr.Row(): | |
| with gr.Row(): | |
| width = gr.Slider( | |
| 512, 1536, 1024, step=16, label="생성 너비" | |
| ) | |
| height = gr.Slider( | |
| 512, 1536, 1024, step=16, label="생성 높이" | |
| ) | |
| with gr.Row(): | |
| with gr.Row(): | |
| keep_size = gr.Checkbox( | |
| label="입력 크기 유지", | |
| value=False, | |
| interactive=True | |
| ) | |
| with gr.Column(): | |
| gr.Markdown("스타일 편집만 필요하거나 레이아웃을 유지하려면 True로 설정하세요.") | |
| with gr.Accordion("고급 옵션", open=True): | |
| with gr.Row(): | |
| num_steps = gr.Slider( | |
| 1, 50, 25, step=1, label="단계 수" | |
| ) | |
| guidance = gr.Slider( | |
| 1.0, 5.0, 4.0, step=0.1, label="가이던스", interactive=True | |
| ) | |
| content_long_size = gr.Slider( | |
| 0, 1024, 512, step=16, label="콘텐츠 참조 크기" | |
| ) | |
| seed = gr.Number(-1, label="시드 (랜덤: -1)") | |
| generate_btn = gr.Button("생성") | |
| gr.Markdown(star) | |
| with gr.Column(): | |
| output_image = gr.Image(label="생성된 이미지") | |
| download_btn = gr.File( | |
| label="고해상도 다운로드", type="filepath", interactive=False | |
| ) | |
| inputs = [ | |
| prompt, | |
| image_prompt1, | |
| image_prompt2, | |
| image_prompt3, | |
| seed, | |
| width, | |
| height, | |
| guidance, | |
| num_steps, | |
| keep_size, | |
| content_long_size, | |
| ] | |
| generate_btn.click( | |
| fn=pipeline.gradio_generate, | |
| inputs=inputs, | |
| outputs=[output_image, download_btn], | |
| ) | |
| example_text = gr.Text("", visible=False, label="사용 사례:") | |
| examples = get_examples("./assets/gradio_examples") | |
| gr.Examples( | |
| examples=examples, | |
| inputs=[ | |
| prompt, | |
| image_prompt1, | |
| image_prompt2, | |
| image_prompt3, | |
| seed, | |
| ], | |
| # cache_examples='lazy', | |
| outputs=[output_image, download_btn], | |
| fn=pipeline.gradio_generate, | |
| label='행 1-4: 아이덴티티/주제 기반; 행 5-7: 스타일-주제 기반; 행 8-9: 스타일 기반; 행 10-12: 다중 스타일 기반 작업; 행 13: 텍스트-이미지', | |
| examples_per_page=15 | |
| ) | |
| return demo | |
| if __name__ == "__main__": | |
| from typing import Literal | |
| from transformers import HfArgumentParser | |
| class AppArgs: | |
| name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell", "flux-krea-dev"] = "flux-dev" | |
| device: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu" | |
| offload: bool = dataclasses.field( | |
| default=False, | |
| metadata={ | |
| "help": "If True, sequantial offload the models(ae, dit, text encoder) to CPU if not used." | |
| }, | |
| ) | |
| port: int = 7860 | |
| parser = HfArgumentParser([AppArgs]) | |
| args_tuple = parser.parse_args_into_dataclasses() # type: tuple[AppArgs] | |
| args = args_tuple[0] | |
| demo = create_demo(args.name, args.device, args.offload) | |
| demo.launch(server_port=args.port, ssr_mode=False) | |