akhaliq HF Staff commited on
Commit
13d07d6
·
verified ·
1 Parent(s): c6b73f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -47,18 +47,18 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
47
 
48
  inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
49
  generated_tokens = torch.zeros((parallel_size, image_token_num), dtype=torch.int).cuda()
50
- past_key_values = None
51
 
52
  for i in range(image_token_num):
53
  if i == 0:
54
- outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True)
55
  else:
56
- outputs = vl_gpt.language_model.model(
57
- inputs_embeds=img_embeds.unsqueeze(1),
58
- use_cache=True,
59
- past_key_values=past_key_values
60
- )
61
- past_key_values = outputs.past_key_values
 
62
 
63
  hidden_states = outputs.last_hidden_state
64
  logits = vl_gpt.gen_head(hidden_states[:, -1, :])
@@ -117,4 +117,4 @@ demo = gr.ChatInterface(
117
  )
118
 
119
  if __name__ == "__main__":
120
- demo.launch()
 
47
 
48
  inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
49
  generated_tokens = torch.zeros((parallel_size, image_token_num), dtype=torch.int).cuda()
 
50
 
51
  for i in range(image_token_num):
52
  if i == 0:
53
+ current_inputs = inputs_embeds
54
  else:
55
+ current_inputs = img_embeds.unsqueeze(1)
56
+
57
+ # ✅ No past_key_values, crash-safe
58
+ outputs = vl_gpt.language_model.model(
59
+ inputs_embeds=current_inputs,
60
+ use_cache=False
61
+ )
62
 
63
  hidden_states = outputs.last_hidden_state
64
  logits = vl_gpt.gen_head(hidden_states[:, -1, :])
 
117
  )
118
 
119
  if __name__ == "__main__":
120
+ demo.launch()