akhaliq HF Staff commited on
Commit
7bf9267
·
verified ·
1 Parent(s): d0ff9d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -6
app.py CHANGED
@@ -17,7 +17,7 @@ vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
17
  )
18
  vl_gpt = vl_gpt.cuda().eval()
19
 
20
- # === Image generation function ===
21
  def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt, temperature=1.0, parallel_size=1, cfg_weight=3.0):
22
  torch.cuda.empty_cache()
23
 
@@ -52,13 +52,22 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
52
  for i in range(image_token_num):
53
  if i == 0:
54
  outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True)
 
55
  else:
 
 
 
 
 
 
 
56
  outputs = vl_gpt.language_model.model(
57
  inputs_embeds=img_embeds.unsqueeze(1),
58
  use_cache=True,
59
  past_key_values=past_key_values
60
  )
61
- past_key_values = outputs.past_key_values
 
62
  hidden_states = outputs.last_hidden_state
63
  logits = vl_gpt.gen_head(hidden_states[:, -1, :])
64
 
@@ -87,7 +96,10 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
87
  Image.fromarray(dec[i]).save(save_path)
88
  output_images.append(save_path)
89
 
90
- return output_images[:1] # just return first image
 
 
 
91
 
92
 
93
  # === Gradio handler ===
@@ -105,9 +117,9 @@ demo = gr.ChatInterface(
105
  title="Janus Text-to-Image",
106
  description="Generate images from natural language prompts using Janus-4o-7B",
107
  examples=[
108
- "a cat wearing a spacesuit on Mars",
109
- "a beautiful sunset over the mountains",
110
- "a photorealistic dog riding a bicycle"
111
  ],
112
  theme="soft",
113
  )
 
17
  )
18
  vl_gpt = vl_gpt.cuda().eval()
19
 
20
+ # === Image generation ===
21
  def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt, temperature=1.0, parallel_size=1, cfg_weight=3.0):
22
  torch.cuda.empty_cache()
23
 
 
52
  for i in range(image_token_num):
53
  if i == 0:
54
  outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True)
55
+ past_key_values = outputs.past_key_values
56
  else:
57
+ # 🧠 SAFE DETACH TRICK
58
+ with torch.no_grad():
59
+ past_key_values = tuple(
60
+ tuple(pkv.detach() for pkv in layer)
61
+ for layer in past_key_values
62
+ )
63
+
64
  outputs = vl_gpt.language_model.model(
65
  inputs_embeds=img_embeds.unsqueeze(1),
66
  use_cache=True,
67
  past_key_values=past_key_values
68
  )
69
+ past_key_values = outputs.past_key_values
70
+
71
  hidden_states = outputs.last_hidden_state
72
  logits = vl_gpt.gen_head(hidden_states[:, -1, :])
73
 
 
96
  Image.fromarray(dec[i]).save(save_path)
97
  output_images.append(save_path)
98
 
99
+ torch.cuda.empty_cache()
100
+ torch.cuda.ipc_collect()
101
+
102
+ return output_images[:1] # return only the first image
103
 
104
 
105
  # === Gradio handler ===
 
117
  title="Janus Text-to-Image",
118
  description="Generate images from natural language prompts using Janus-4o-7B",
119
  examples=[
120
+ "a cat sitting on a windowsill",
121
+ "a futuristic city at sunset",
122
+ "a dragon flying over mountains",
123
  ],
124
  theme="soft",
125
  )