akhaliq HF Staff commited on
Commit
c6b73f1
·
verified ·
1 Parent(s): 7bf9267

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -9
app.py CHANGED
@@ -52,21 +52,13 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
52
  for i in range(image_token_num):
53
  if i == 0:
54
  outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True)
55
- past_key_values = outputs.past_key_values
56
  else:
57
- # 🧠 SAFE DETACH TRICK
58
- with torch.no_grad():
59
- past_key_values = tuple(
60
- tuple(pkv.detach() for pkv in layer)
61
- for layer in past_key_values
62
- )
63
-
64
  outputs = vl_gpt.language_model.model(
65
  inputs_embeds=img_embeds.unsqueeze(1),
66
  use_cache=True,
67
  past_key_values=past_key_values
68
  )
69
- past_key_values = outputs.past_key_values
70
 
71
  hidden_states = outputs.last_hidden_state
72
  logits = vl_gpt.gen_head(hidden_states[:, -1, :])
 
52
  for i in range(image_token_num):
53
  if i == 0:
54
  outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True)
 
55
  else:
 
 
 
 
 
 
 
56
  outputs = vl_gpt.language_model.model(
57
  inputs_embeds=img_embeds.unsqueeze(1),
58
  use_cache=True,
59
  past_key_values=past_key_values
60
  )
61
+ past_key_values = outputs.past_key_values
62
 
63
  hidden_states = outputs.last_hidden_state
64
  logits = vl_gpt.gen_head(hidden_states[:, -1, :])