Update app.py
app.py CHANGED
@@ -17,7 +17,7 @@ vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
 )
 vl_gpt = vl_gpt.cuda().eval()
 
-# === Image generation
+# === Image generation ===
 def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt, temperature=1.0, parallel_size=1, cfg_weight=3.0):
     torch.cuda.empty_cache()
 
@@ -52,13 +52,22 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
     for i in range(image_token_num):
         if i == 0:
             outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True)
+            past_key_values = outputs.past_key_values
         else:
+            # 🧠 SAFE DETACH TRICK
+            with torch.no_grad():
+                past_key_values = tuple(
+                    tuple(pkv.detach() for pkv in layer)
+                    for layer in past_key_values
+                )
+
             outputs = vl_gpt.language_model.model(
                 inputs_embeds=img_embeds.unsqueeze(1),
                 use_cache=True,
                 past_key_values=past_key_values
             )
-
+            past_key_values = outputs.past_key_values
+
         hidden_states = outputs.last_hidden_state
         logits = vl_gpt.gen_head(hidden_states[:, -1, :])
 
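A note on the detach block added above: rebuilding the nested past_key_values tuples from detached tensors drops whatever autograd history the cached key/value tensors carry, so graph state cannot accumulate across the image-token loop. The minimal sketch below is not part of app.py; it uses toy tensors in place of a real decoder cache purely to illustrate what the pattern does.

import torch

# Toy stand-in for a decoder cache: per-layer (key, value) tensors that still
# carry autograd history because they came out of a differentiable computation.
past_key_values = tuple(
    (torch.randn(1, 4, 8, requires_grad=True) * 1.0,
     torch.randn(1, 4, 8, requires_grad=True) * 1.0)
    for _ in range(2)
)
print(past_key_values[0][0].grad_fn is not None)  # True: a graph is attached

# Same pattern as in the diff: rebuild the nested tuples from detached tensors
# so graphs from earlier steps can be freed before the next forward pass.
with torch.no_grad():
    past_key_values = tuple(
        tuple(pkv.detach() for pkv in layer)
        for layer in past_key_values
    )
print(past_key_values[0][0].grad_fn is None)  # True: no graph retained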
@@ -87,7 +96,10 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
         Image.fromarray(dec[i]).save(save_path)
         output_images.append(save_path)
 
-
+    torch.cuda.empty_cache()
+    torch.cuda.ipc_collect()
+
+    return output_images[:1]  # return only the first image
 
 
 # === Gradio handler ===
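The two calls added after the save loop are standard PyTorch memory hygiene: torch.cuda.empty_cache() returns cached allocator blocks to the driver, and torch.cuda.ipc_collect() reclaims memory held through CUDA IPC handles. A small sketch of the same cleanup pulled into a reusable helper; free_gpu_memory is a hypothetical name and is not defined in app.py.

import gc
import torch

def free_gpu_memory() -> None:
    """Best-effort release of GPU memory between generation requests (hypothetical helper)."""
    gc.collect()                  # drop unreachable Python objects that still hold tensors
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # hand cached allocator blocks back to the CUDA driver
        torch.cuda.ipc_collect()  # reclaim memory tied up in CUDA IPC handles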
@@ -105,9 +117,9 @@ demo = gr.ChatInterface(
     title="Janus Text-to-Image",
     description="Generate images from natural language prompts using Janus-4o-7B",
     examples=[
-        "a cat
-        "a
-        "a
+        "a cat sitting on a windowsill",
+        "a futuristic city at sunset",
+        "a dragon flying over mountains",
     ],
     theme="soft",
 )
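This hunk only touches the ChatInterface configuration (title, description, example prompts, theme); the chat handler itself is outside the diff. As a rough sketch only, here is one way such a handler could wrap text_to_image_generate, assuming a recent Gradio release in which a ChatInterface function may return a component such as gr.Image; generate_image and the /tmp output path are illustrative and not taken from the Space.

import gradio as gr

def generate_image(message, history):
    """Hypothetical chat handler: run the app's pipeline and return the saved image."""
    # Assumes text_to_image_generate, vl_chat_processor and vl_gpt from app.py are in scope.
    images = text_to_image_generate(
        input_prompt=message,
        output_path="/tmp/janus_outputs",
        vl_chat_processor=vl_chat_processor,
        vl_gpt=vl_gpt,
    )
    # Recent Gradio versions allow returning a component from a ChatInterface handler.
    return gr.Image(value=images[0])

demo = gr.ChatInterface(
    fn=generate_image,
    title="Janus Text-to-Image",
    description="Generate images from natural language prompts using Janus-4o-7B",
    theme="soft",
)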