borsojj committed
Commit 0fad32a · verified · 1 Parent(s): 88de06b

Update app.py

Files changed (1)
  1. app.py +129 -163
app.py CHANGED
@@ -1,201 +1,167 @@
  import os, random, uuid, json
  import gradio as gr
  import numpy as np
  from PIL import Image
  import spaces
  import torch
- from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler

- DESCRIPTION = None
  if not torch.cuda.is_available():
-     DESCRIPTION = "\nRunning on CPU 🥶 This demo may not work on CPU."

  MAX_SEED = np.iinfo(np.int32).max
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
- USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
- ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"

  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

- pipe = StableDiffusionXLPipeline.from_pretrained(
-     "sd-community/sdxl-flash",
-     torch_dtype=torch.float16,
-     use_safetensors=True,
-     add_watermarker=False
- )
- pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
-
- if torch.cuda.is_available():
-     pipe.to("cuda")
  else:
-     pipe.to("cpu")
-
- def save_image(img):
-     unique_name = str(uuid.uuid4()) + ".png"
-     img.save(unique_name)
-     return unique_name
-
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
      return seed

- @spaces.GPU(duration=30)
  def generate(
      prompt: str,
      negative_prompt: str = "",
-     use_negative_prompt: bool = False,
      seed: int = 0,
      width: int = 1024,
      height: int = 1024,
-     guidance_scale: float = 3,
      num_inference_steps: int = 25,
      randomize_seed: bool = False,
-     use_resolution_binning: bool = True,
      progress=gr.Progress(track_tqdm=True),
  ):
      pipe.to(device)
      seed = int(randomize_seed_fn(seed, randomize_seed))
-     generator = torch.Generator().manual_seed(seed)
-
-     options = {
-         "prompt":prompt,
-         "negative_prompt":negative_prompt,
-         "width":width,
-         "height":height,
-         "guidance_scale":guidance_scale,
-         "num_inference_steps":num_inference_steps,
-         "generator":generator,
-         "use_resolution_binning":use_resolution_binning,
-         "output_type":"pil",
-
-     }
-
-     images = pipe(**options).images
-
-     image_paths = [save_image(img) for img in images]
-     return image_paths, seed
-
-
- examples = [
-     "a cat eating a piece of cheese",
-     "a ROBOT riding a BLUE horse on Mars, photorealistic",
-     "a cartoon of a IRONMAN fighting with HULK, wall painting",
-     "a cute robot artist painting on an easel, concept art",
-     "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
-     "An alien grasping a sign board contain word 'Flash', futuristic, neonpunk, detailed",
-     "Kids going to school, Anime style"
- ]
-
- css = '''
- .gradio-container {
-   max-width: 700px !important;
-   margin: 0 auto !important;
- }
- h1{text-align:left}
- '''
  with gr.Blocks(css=css) as demo:
-     gr.Markdown(f"""# SDXL Flash
-     ### First Image processing takes time then images generate faster.
      {DESCRIPTION}""")
      with gr.Group():
          with gr.Row():
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )
-             run_button = gr.Button("Run", scale=0)
-         result = gr.Gallery(label="Result", columns=1)
-     with gr.Accordion("Advanced options", open=False):
          with gr.Row():
-             use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=5,
-                 lines=4,
-                 placeholder="Enter a negative prompt",
-                 value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW",
-                 visible=True,
-             )
-         seed = gr.Slider(
-             label="Seed",
-             minimum=0,
-             maximum=MAX_SEED,
-             step=1,
-             value=0,
-         )
-         randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-         with gr.Row(visible=True):
-             width = gr.Slider(
-                 label="Width",
-                 minimum=512,
-                 maximum=MAX_IMAGE_SIZE,
-                 step=64,
-                 value=1024,
-             )
-             height = gr.Slider(
-                 label="Height",
-                 minimum=512,
-                 maximum=MAX_IMAGE_SIZE,
-                 step=64,
-                 value=1024,
-             )
          with gr.Row():
-             guidance_scale = gr.Slider(
-                 label="Guidance Scale",
-                 minimum=0.1,
-                 maximum=6,
-                 step=0.1,
-                 value=3.0,
-             )
-             num_inference_steps = gr.Slider(
-                 label="Number of inference steps",
-                 minimum=1,
-                 maximum=15,
-                 step=1,
-                 value=8,
-             )
-
-     gr.Examples(
-         examples=examples,
-         inputs=prompt,
-         outputs=[result, seed],
-         fn=generate,
-         cache_examples=True,
-     )
-
-     use_negative_prompt.change(
-         fn=lambda x: gr.update(visible=x),
-         inputs=use_negative_prompt,
-         outputs=negative_prompt,
-         api_name=False,
-     )
-
-     gr.on(
-         triggers=[
-             prompt.submit,
-             negative_prompt.submit,
-             run_button.click,
-         ],
-         fn=generate,
-         inputs=[
-             prompt,
-             negative_prompt,
-             use_negative_prompt,
-             seed,
-             width,
-             height,
-             guidance_scale,
-             num_inference_steps,
-             randomize_seed,
-         ],
-         outputs=[result, seed],
-         api_name="run",
-     )
-
  if __name__ == "__main__":
-     demo.launch()
+ # app.py - Loads the base model plus a UNet from a separate private repo
+ # Current date and time for reference: Sunday, May 4, 2025 at 8:23:22 PM -03
  import os, random, uuid, json
  import gradio as gr
  import numpy as np
  from PIL import Image
  import spaces
  import torch
+ # Also import the UNet class
+ from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerAncestralDiscreteScheduler
+ import time
+ from huggingface_hub import HfApi

+ # --- Configuration ---
+ base_model_id = "sd-community/sdxl-flash"  # Or whichever base the Space used before
+ # ID of the PRIVATE repository that contains ONLY the trained UNet
+ tuned_unet_repo_id = "borsojj/unet"  # <<< Repo with your UNet
+
+ DESCRIPTION = f"Interface using base `{base_model_id}` with the UNet from `{tuned_unet_repo_id}`."
  if not torch.cuda.is_available():
+     DESCRIPTION += "\n**Warning:** Running on CPU 🥶 - generation may be very slow or fail."

  MAX_SEED = np.iinfo(np.int32).max
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))

  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

+ print(f"Loading base pipeline from: {base_model_id}...")
+ hf_token = os.getenv("HF_TOKEN")  # Reads the token from the Space secrets
+ if hf_token:
+     print("HF_TOKEN secret found.")
  else:
+     # IMPORTANT warning if the UNet repo is private!
+     print("WARNING: HF_TOKEN secret NOT found. Loading the trained UNet will fail if the repository is private.")
+
+ start_time = time.time()
+ pipe = None
+ loading_error_message = ""
+ try:
+     # 1. Load the full base pipeline
+     pipe = StableDiffusionXLPipeline.from_pretrained(
+         base_model_id,
+         torch_dtype=torch_dtype,
+         use_safetensors=True,
+         add_watermarker=False,
+         token=hf_token  # Pass the token in case the base model needs it too
+     )
+     print(f"Base pipeline '{base_model_id}' loaded.")
+
+     # 2. Try to load the UNet from the separate PRIVATE repo and swap it in
+     if hf_token:  # Only attempt the load if a token is available
+         print(f"Trying to load the trained UNet from the private repo: {tuned_unet_repo_id}")
+         try:
+             # Load the UNet from the repo ID, using the token
+             tuned_unet = UNet2DConditionModel.from_pretrained(
+                 tuned_unet_repo_id,
+                 torch_dtype=torch_dtype,  # Load with the same dtype
+                 token=hf_token  # Use the token to access the private repo
+                 # low_cpu_mem_usage=False  # May need to be disabled if this OOMs
+             )
+             print("Trained UNet loaded. Swapping the UNet in the pipeline...")
+             pipe.unet = tuned_unet  # THE SWAP
+             print("UNet swapped successfully.")
+         except Exception as unet_load_e:
+             loading_error_message = f"**<font color='red'>ERROR:</font>** Failed to load the UNet from `{tuned_unet_repo_id}` (check the token and the repo). Using the base UNet. Error: `{unet_load_e}`"
+             print(loading_error_message)
+             DESCRIPTION += "\n" + loading_error_message
+             # Continue with the base UNet on failure
+     else:
+         # Without a token, the private UNet cannot be loaded
+         loading_error_message = f"**<font color='orange'>WARNING:</font>** HF_TOKEN not found. Cannot load the UNet from the private repository `{tuned_unet_repo_id}`. Using the base UNet."
+         print(loading_error_message)
+         DESCRIPTION += "\n" + loading_error_message
+
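+     # Note: the swap assumes the tuned UNet shares the base UNet's
+     # architecture/config; a mismatched UNet would break generation.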
+     # Configure the scheduler (as before)
+     pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
+     print(f"Scheduler set to: {pipe.scheduler.__class__.__name__}")
+
+     print(f"Moving the final pipeline to the device: {device}")
+     pipe.to(device)  # Move to the device AFTER swapping the UNet
+     print("Pipeline ready on the device.")
+
+ except Exception as e:
+     # Error while loading the BASE pipeline
+     print(f"CRITICAL error loading the base pipeline from '{base_model_id}': {e}")
+     loading_error_message = f"**<font color='red'>CRITICAL ERROR:</font>** Could not load the base pipeline `{base_model_id}`. Error: `{e}`."
+     DESCRIPTION += "\n" + loading_error_message
+     pipe = None
+
+
+ # generate function (no significant changes)
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed: seed = random.randint(0, MAX_SEED)
      return seed

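+ # ZeroGPU: allows up to 90 s of GPU time for each generate() call.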
+ @spaces.GPU(duration=90)
  def generate(
      prompt: str,
      negative_prompt: str = "",
+     use_negative_prompt: bool = True,
      seed: int = 0,
      width: int = 1024,
      height: int = 1024,
+     guidance_scale: float = 7.0,
      num_inference_steps: int = 25,
      randomize_seed: bool = False,
      progress=gr.Progress(track_tqdm=True),
  ):
+     if pipe is None: raise gr.Error(f"Pipeline not loaded. {loading_error_message}")
      pipe.to(device)
      seed = int(randomize_seed_fn(seed, randomize_seed))
+     generator = torch.Generator(device=device).manual_seed(seed)
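+     # Seeding happens on the generation device; note that CUDA and CPU
+     # generators produce different random streams for the same seed.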
+     if not use_negative_prompt: negative_prompt = None
+     options = {"prompt":prompt, "negative_prompt":negative_prompt, "width":width, "height":height, "guidance_scale":guidance_scale, "num_inference_steps":num_inference_steps, "generator":generator, "output_type":"pil"}
+     print(f"Generating image with seed: {seed}, Steps: {num_inference_steps}, Guidance: {guidance_scale}")
+     start_gen_time = time.time()
+     try:
+         images = pipe(**options).images
+         print(f"Generated {len(images)} image(s) in {time.time() - start_gen_time:.2f} seconds.")
+         return images, seed
+     except Exception as e:
+         print(f"Error during generation: {e}")
+         raise gr.Error(f"Error during generation: {e}")
+
+
+ # Gradio interface (no significant changes)
+ examples = [ "photo of a futuristic city skyline at sunset, high detail", "an oil painting of a cute cat wearing a wizard hat", "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k", "An alien grasping a sign board contain word 'Flash', futuristic, neonpunk, detailed" ]
+ css = ".gradio-container { max-width: 800px !important; margin: 0 auto !important; } h1{ text-align:center }"
  with gr.Blocks(css=css) as demo:
+     gr.Markdown(f"""# SDXL Base with a Fine-tuned UNet (`{tuned_unet_repo_id}`)
+     Base: `{base_model_id}`
      {DESCRIPTION}""")
      with gr.Group():
          with gr.Row():
+             prompt = gr.Text(label="Prompt", show_label=False, max_lines=3, placeholder="Describe the image...", container=False)
+             run_button = gr.Button("Generate Image", variant="primary", scale=0)
+         result = gr.Gallery(label="Result", show_label=False, elem_id="gallery", columns=1, height=768)
+     with gr.Accordion("Advanced Options", open=False):
          with gr.Row():
+             use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
+             negative_prompt = gr.Text(label="Negative Prompt", max_lines=3, lines=2, placeholder="What NOT to see...", value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW", visible=True)
+         seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+         randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
          with gr.Row():
+             width = gr.Slider(label="Width", minimum=512, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
+             height = gr.Slider(label="Height", minimum=512, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
          with gr.Row():
+             guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=20.0, step=0.5, value=7.0)
+             num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=100, step=1, value=25)
+     gr.Examples(examples=examples, inputs=prompt, outputs=[result, seed], fn=generate, cache_examples=CACHE_EXAMPLES)
+     use_negative_prompt.change(fn=lambda x: gr.update(visible=x), inputs=use_negative_prompt, outputs=negative_prompt, api_name=False)
+     generate_inputs = [prompt, negative_prompt, use_negative_prompt, seed, width, height, guidance_scale, num_inference_steps, randomize_seed]
+     gr.on(triggers=[prompt.submit, run_button.click], fn=generate, inputs=generate_inputs, outputs=[result, seed], api_name="generate_image")
+
+ # Launch the interface in the Space environment
  if __name__ == "__main__":
+     if pipe is not None:
+         print("Launching the Gradio interface in the Space...")
+         demo.queue().launch()  # Important for Spaces
+     else:
+         print("CRITICAL ERROR: Pipeline not loaded or UNet not swapped correctly. Launching error UI.")
+         with gr.Blocks() as error_demo:
+             gr.Markdown(f"# Error Loading the Model\n{DESCRIPTION}")
+             error_demo.queue().launch()
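
Note: since the event is exposed with api_name="generate_image", the Space can also be driven programmatically. A minimal client-side sketch with gradio_client; the Space ID "borsojj/SPACE_NAME" is a placeholder, not taken from this diff (pass hf_token=... to Client if the Space is private):

    from gradio_client import Client

    client = Client("borsojj/SPACE_NAME")  # placeholder Space ID
    gallery, used_seed = client.predict(
        "an oil painting of a cute cat wearing a wizard hat",  # prompt
        "",     # negative_prompt
        True,   # use_negative_prompt
        0,      # seed
        1024,   # width
        1024,   # height
        7.0,    # guidance_scale
        25,     # num_inference_steps
        True,   # randomize_seed
        api_name="/generate_image",
    )
    print(used_seed, gallery)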