# app.py - Loads the SDXL base pipeline and swaps in a fine-tuned UNet that is
# stored in a separate (private) Hugging Face repository.
# Reference timestamp left by the original author: Sunday, May 4, 2025 at 8:23:22 PM -03
import json  # noqa: F401  (kept from the original file; currently unused)
import os
import random
import time
import uuid  # noqa: F401  (kept from the original file; currently unused)

import gradio as gr
import numpy as np
from PIL import Image  # noqa: F401  (kept from the original file; currently unused)
import spaces  # kept before `import torch`, matching the original import order
import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerAncestralDiscreteScheduler
from huggingface_hub import HfApi  # noqa: F401  (kept from the original file; currently unused)

# --- Configuration ---
base_model_id = "sd-community/sdxl-flash"  # base model the Space builds on
# PRIVATE repository that contains ONLY the fine-tuned UNet.
tuned_unet_repo_id = "borsojj/unet"

DESCRIPTION = f"Interface usando base `{base_model_id}` com UNet de `{tuned_unet_repo_id}`."
if not torch.cuda.is_available():
    DESCRIPTION += "\n**Atenção:** Rodando em CPU 🥶 - A geração pode ser muito lenta ou falhar."

MAX_SEED = np.iinfo(np.int32).max
# Example outputs are pre-computed only on GPU and only when CACHE_EXAMPLES=1 (default).
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Half precision on GPU, full precision on CPU.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# --- Model loading (runs once, at import time) ---
print(f"Carregando pipeline base de: {base_model_id}...")
hf_token = os.getenv("HF_TOKEN")  # read from the Space's secrets
if hf_token:
    print("Segredo HF_TOKEN encontrado.")
else:
    # Without a token, the private UNet repository cannot be accessed.
    print("AVISO: Segredo HF_TOKEN NÃO encontrado. O carregamento do UNet treinado falhará se o repositório for privado.")

start_time = time.time()
pipe = None
loading_error_message = ""
try:
    # 1. Load the complete base pipeline.
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base_model_id,
        torch_dtype=torch_dtype,
        use_safetensors=True,
        add_watermarker=False,
        token=hf_token,  # forwarded in case the base repo also requires auth
    )
    print(f"Pipeline base '{base_model_id}' carregado.")

    # 2. Try to load the UNet from the separate private repo and swap it in.
    if hf_token:  # only attempt when a token is available
        print(f"Tentando carregar UNet treinado do repo privado: {tuned_unet_repo_id}")
        try:
            # NOTE(review): assumes the UNet config/weights sit at the repo root;
            # pass subfolder="unet" if they live in a subdirectory.
            tuned_unet = UNet2DConditionModel.from_pretrained(
                tuned_unet_repo_id,
                torch_dtype=torch_dtype,  # same dtype as the rest of the pipeline
                token=hf_token,  # needed to access the private repo
            )
            print("UNet treinado carregado. Substituindo UNet no pipeline...")
            pipe.unet = tuned_unet  # the actual swap
            print("UNet substituído com sucesso.")
        except Exception as unet_load_e:
            # Best effort: keep the base UNet and surface the error in the UI.
            loading_error_message = f"**ERRO:** Falha ao carregar UNet de `{tuned_unet_repo_id}` (Verifique token e repo). Usando UNet base. Erro: `{unet_load_e}`"
            print(loading_error_message)
            DESCRIPTION += "\n" + loading_error_message
    else:
        # No token: the private UNet cannot be loaded, fall back to the base UNet.
        loading_error_message = f"**AVISO:** HF_TOKEN não encontrado. Não é possível carregar UNet do repositório privado `{tuned_unet_repo_id}`. Usando UNet base."
        print(loading_error_message)
        DESCRIPTION += "\n" + loading_error_message

    # Replace the scheduler, reusing the existing scheduler's config.
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    print(f"Scheduler configurado para: {pipe.scheduler.__class__.__name__}")

    print(f"Movendo pipeline final para o device: {device}")
    pipe.to(device)  # moved AFTER the UNet swap so the new UNet lands on the device too
    print("Pipeline pronto no device.")
    print(f"Tempo total de carregamento: {time.time() - start_time:.2f} segundos.")
except Exception as e:
    # Reached when the base pipeline fails to load, and also when the scheduler
    # setup or the device move fails. In every case the app runs without a model.
    print(f"Erro CRÍTICO ao carregar o pipeline base de '{base_model_id}': {e}")
    loading_error_message = f"**ERRO CRÍTICO:** Não foi possível carregar o pipeline base `{base_model_id}`. Erro: `{e}`."
    DESCRIPTION += "\n" + loading_error_message
    pipe = None


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a random seed in [0, MAX_SEED] when *randomize_seed* is set, else *seed* unchanged."""
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


@spaces.GPU(duration=90)
def generate(
    prompt: str,
    negative_prompt: str = "",
    use_negative_prompt: bool = False,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 7.0,
    num_inference_steps: int = 25,
    randomize_seed: bool = False,
    progress=gr.Progress(track_tqdm=True),  # lets Gradio mirror the pipeline's tqdm bar
):
    """Generate images for *prompt* with the loaded SDXL pipeline.

    Args:
        prompt: Text prompt.
        negative_prompt: Used only when *use_negative_prompt* is true.
        use_negative_prompt: Whether to pass *negative_prompt* to the pipeline.
        seed: Seed for the torch generator (ignored if *randomize_seed*).
        width, height: Output size in pixels.
        guidance_scale: Classifier-free guidance scale.
        num_inference_steps: Number of denoising steps.
        randomize_seed: Draw a fresh random seed instead of using *seed*.

    Returns:
        ``(images, seed)``: the generated PIL images and the seed actually used,
        so the UI seed slider can be updated.

    Raises:
        gr.Error: If the pipeline failed to load or generation raised.
    """
    if pipe is None:
        raise gr.Error(f"Pipeline não carregado. {loading_error_message}")

    pipe.to(device)
    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator(device=device).manual_seed(seed)

    if not use_negative_prompt:
        negative_prompt = None

    options = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        # Coerce numeric inputs: slider/API values may arrive as floats, while
        # the pipeline needs integer sizes and step counts.
        "width": int(width),
        "height": int(height),
        "guidance_scale": float(guidance_scale),
        "num_inference_steps": int(num_inference_steps),
        "generator": generator,
        "output_type": "pil",
    }
    print(f"Gerando imagem com seed: {seed}, Steps: {num_inference_steps}, Guidance: {guidance_scale}")
    start_gen_time = time.time()
    try:
        images = pipe(**options).images
    except Exception as e:
        print(f"Erro durante a geração: {e}")
        raise gr.Error(f"Erro durante a geração: {e}") from e
    print(f"Gerado {len(images)} imagem(s) em {time.time() - start_gen_time:.2f} segundos.")
    return images, seed


# --- Gradio interface ---
examples = [
    "photo of a futuristic city skyline at sunset, high detail",
    "an oil painting of a cute cat wearing a wizard hat",
    "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
    "An alien grasping a sign board contain word 'Flash', futuristic, neonpunk, detailed",
]

css = ".gradio-container { max-width: 800px !important; margin: 0 auto !important; } h1{ text-align:center }"

with gr.Blocks(css=css) as demo:
    gr.Markdown(
        f"""# SDXL Base com UNet Fine-tuned (`{tuned_unet_repo_id}`)
Base: `{base_model_id}`

{DESCRIPTION}

**Aviso:** O filtro de conteúdo explícito foi desativado. Use prompts com cuidado."""
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(label="Prompt", show_label=False, max_lines=3, placeholder="Descreva a imagem...", container=False)
            run_button = gr.Button("Gerar Imagem", variant="primary", scale=0)
        result = gr.Gallery(label="Resultado", show_label=False, elem_id="gallery", columns=1, height=768)

    with gr.Accordion("Opções Avançadas", open=False):
        with gr.Row():
            use_negative_prompt = gr.Checkbox(label="Usar prompt negativo", value=False)
            # NOTE(review): the "(term:1.3)" weighting is A1111/Compel-style syntax; the
            # stock diffusers SDXL pipeline likely passes it through as plain text — verify.
            negative_prompt = gr.Text(
                label="Prompt Negativo",
                max_lines=3,
                lines=2,
                placeholder="O que evitar na imagem (ex.: blurry, deformed)...",
                value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, blurry, amputation",
                visible=False,
            )
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
        randomize_seed = gr.Checkbox(label="Seed Aleatória", value=True)
        with gr.Row():
            width = gr.Slider(label="Largura", minimum=512, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
            height = gr.Slider(label="Altura", minimum=512, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
        with gr.Row():
            guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=20.0, step=0.5, value=7.0)
            num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=100, step=1, value=25)

    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=[result, seed],
        fn=generate,
        # generate() raises gr.Error when pipe is None, so only cache example
        # outputs when a model is actually loaded.
        cache_examples=CACHE_EXAMPLES and pipe is not None,
    )

    # Show the negative-prompt box only while the checkbox is ticked.
    use_negative_prompt.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_negative_prompt,
        outputs=negative_prompt,
        api_name=False,
    )

    # Order must match generate()'s positional parameters.
    generate_inputs = [prompt, negative_prompt, use_negative_prompt, seed, width, height, guidance_scale, num_inference_steps, randomize_seed]
    gr.on(
        triggers=[prompt.submit, run_button.click],
        fn=generate,
        inputs=generate_inputs,
        outputs=[result, seed],
        api_name="generate_image",
    )

# Launch the interface in the Space environment.
if __name__ == "__main__":
    if pipe is not None:
        print("Lançando interface Gradio no Space...")
        demo.queue().launch()
    else:
        # No usable pipeline: serve a minimal page that explains the failure.
        print("ERRO CRÍTICO: Pipeline não carregado ou UNet não substituído corretamente. Lançando UI de erro.")
        with gr.Blocks() as error_demo:
            gr.Markdown(f"# Erro ao Carregar Modelo\n{DESCRIPTION}")
        error_demo.queue().launch()