# app.py - Carrega base + UNet de repo privado separado
# Data e hora atuais para referência: Sunday, May 4, 2025 at 8:23:22 PM -03
import os, random, uuid, json
import gradio as gr
import numpy as np
from PIL import Image
import spaces
import torch
# Importar UNet também
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerAncestralDiscreteScheduler
import time
from huggingface_hub import HfApi
# --- Configurações ---
base_model_id = "sd-community/sdxl-flash" # Ou o base que o Space usava
# ID do Repositório PRIVADO que contém APENAS o UNet treinado
tuned_unet_repo_id = "borsojj/unet" # <<< Repo com seu UNet
DESCRIPTION = f"Interface usando base `{base_model_id}` com UNet de `{tuned_unet_repo_id}`."
if not torch.cuda.is_available():
DESCRIPTION += "\n**Atenção:** Rodando em CPU 🥶 - A geração pode ser muito lenta ou falhar."
MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
print(f"Carregando pipeline base de: {base_model_id}...")
hf_token = os.getenv("HF_TOKEN") # Pega token dos segredos do Space
if hf_token:
print("Segredo HF_TOKEN encontrado.")
else:
# AVISO IMPORTANTE se o repo UNET for privado!
print("AVISO: Segredo HF_TOKEN NÃO encontrado. O carregamento do UNet treinado falhará se o repositório for privado.")
start_time = time.time()
pipe = None
loading_error_message = ""
try:
# 1. Carrega o pipeline base completo
pipe = StableDiffusionXLPipeline.from_pretrained(
base_model_id,
torch_dtype=torch_dtype,
use_safetensors=True,
add_watermarker=False,
token=hf_token # Passa token caso o base precise também
)
print(f"Pipeline base '{base_model_id}' carregado.")
# 2. Tenta carregar e substituir o UNet do repo separado e PRIVADO
if hf_token: # Só tenta carregar se houver token
print(f"Tentando carregar UNet treinado do repo privado: {tuned_unet_repo_id}")
try:
# Carrega o UNet do repo ID, usando o token
tuned_unet = UNet2DConditionModel.from_pretrained(
tuned_unet_repo_id,
torch_dtype=torch_dtype, # Carrega com mesmo dtype
token=hf_token # Usa o token para acessar repo privado
)
print("UNet treinado carregado. Substituindo UNet no pipeline...")
pipe.unet = tuned_unet # A SUBSTITUIÇÃO
print("UNet substituído com sucesso.")
except Exception as unet_load_e:
loading_error_message = f"**ERRO:** Falha ao carregar UNet de `{tuned_unet_repo_id}` (Verifique token e repo). Usando UNet base. Erro: `{unet_load_e}`"
print(loading_error_message)
DESCRIPTION += "\n" + loading_error_message
# Continua com o UNet base se falhar
else:
# Se não há token, não pode carregar UNet privado
loading_error_message = f"**AVISO:** HF_TOKEN não encontrado. Não é possível carregar UNet do repositório privado `{tuned_unet_repo_id}`. Usando UNet base."
print(loading_error_message)
DESCRIPTION += "\n" + loading_error_message
# Configura o scheduler (como antes)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
print(f"Scheduler configurado para: {pipe.scheduler.__class__.__name__}")
print(f"Movendo pipeline final para o device: {device}")
pipe.to(device) # Move para o device APÓS substituir o UNet
print("Pipeline pronto no device.")
except Exception as e:
# Erro ao carregar o pipeline BASE
print(f"Erro CRÍTICO ao carregar o pipeline base de '{base_model_id}': {e}")
loading_error_message = f"**ERRO CRÍTICO:** Não foi possível carregar o pipeline base `{base_model_id}`. Erro: `{e}`."
DESCRIPTION += "\n" + loading_error_message
pipe = None
# Função generate (sem alterações significativas)
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
if randomize_seed: seed = random.randint(0, MAX_SEED)
return seed
@spaces.GPU(duration=90)
def generate(
prompt: str,
negative_prompt: str = "",
use_negative_prompt: bool = False,
seed: int = 0,
width: int = 1024,
height: int = 1024,
guidance_scale: float = 7.0,
num_inference_steps: int = 25,
randomize_seed: bool = False,
progress=gr.Progress(track_tqdm=True),
):
if pipe is None: raise gr.Error(f"Pipeline não carregado. {loading_error_message}")
pipe.to(device)
seed = int(randomize_seed_fn(seed, randomize_seed))
generator = torch.Generator(device=device).manual_seed(seed)
if not use_negative_prompt: negative_prompt = None
options = {"prompt":prompt, "negative_prompt":negative_prompt, "width":width, "height":height, "guidance_scale":guidance_scale, "num_inference_steps":num_inference_steps, "generator":generator, "output_type":"pil"}
print(f"Gerando imagem com seed: {seed}, Steps: {num_inference_steps}, Guidance: {guidance_scale}")
start_gen_time = time.time()
try:
images = pipe(**options).images
print(f"Gerado {len(images)} imagem(s) em {time.time() - start_gen_time:.2f} segundos.")
return images, seed
except Exception as e:
print(f"Erro durante a geração: {e}")
raise gr.Error(f"Erro durante a geração: {e}")
# Interface Gradio
examples = [
"photo of a futuristic city skyline at sunset, high detail",
"an oil painting of a cute cat wearing a wizard hat",
"Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
"An alien grasping a sign board contain word 'Flash', futuristic, neonpunk, detailed"
]
css = ".gradio-container { max-width: 800px !important; margin: 0 auto !important; } h1{ text-align:center }"
with gr.Blocks(css=css) as demo:
gr.Markdown(f"""# SDXL Base com UNet Fine-tuned (`{tuned_unet_repo_id}`)
Base: `{base_model_id}`
{DESCRIPTION}
**Aviso:** O filtro de conteúdo explícito foi desativado. Use prompts com cuidado.""")
with gr.Group():
with gr.Row():
prompt = gr.Text(label="Prompt", show_label=False, max_lines=3, placeholder="Descreva a imagem...", container=False)
run_button = gr.Button("Gerar Imagem", variant="primary", scale=0)
result = gr.Gallery(label="Resultado", show_label=False, elem_id="gallery", columns=1, height=768)
with gr.Accordion("Opções Avançadas", open=False):
with gr.Row():
use_negative_prompt = gr.Checkbox(label="Usar prompt negativo", value=False)
negative_prompt = gr.Text(
label="Prompt Negativo",
max_lines=3,
lines=2,
placeholder="O que evitar na imagem (ex.: blurry, deformed)...",
value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, blurry, amputation",
visible=False
)
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
randomize_seed = gr.Checkbox(label="Seed Aleatória", value=True)
with gr.Row():
width = gr.Slider(label="Largura", minimum=512, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
height = gr.Slider(label="Altura", minimum=512, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
with gr.Row():
guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=20.0, step=0.5, value=7.0)
num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=100, step=1, value=25)
gr.Examples(examples=examples, inputs=prompt, outputs=[result, seed], fn=generate, cache_examples=CACHE_EXAMPLES)
use_negative_prompt.change(fn=lambda x: gr.update(visible=x), inputs=use_negative_prompt, outputs=negative_prompt, api_name=False)
generate_inputs = [prompt, negative_prompt, use_negative_prompt, seed, width, height, guidance_scale, num_inference_steps, randomize_seed]
gr.on(triggers=[prompt.submit, run_button.click], fn=generate, inputs=generate_inputs, outputs=[result, seed], api_name="generate_image")
# Lança a interface no ambiente Space
if __name__ == "__main__":
if pipe is not None:
print("Lançando interface Gradio no Space...")
demo.queue().launch() # Importante para Spaces
else:
print("ERRO CRÍTICO: Pipeline não carregado ou UNet não substituído corretamente. Lançando UI de erro.")
with gr.Blocks() as error_demo:
gr.Markdown(f"# Erro ao Carregar Modelo\n{DESCRIPTION}")
error_demo.queue().launch()