akhaliq (HF Staff) committed
Commit ad9c7fd · verified · 1 Parent(s): a4e17a0

Update app.py

Files changed (1):
  app.py (+68 -57)
app.py CHANGED
@@ -18,7 +18,7 @@ vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
 vl_gpt = vl_gpt.cuda().eval()
 
 # === Image generation ===
-def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt, temperature=1.0, parallel_size=1, cfg_weight=3.0):
+def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt, temperature=1.0, parallel_size=2, cfg_weight=5.0):
     torch.cuda.empty_cache()
 
     conversation = [
@@ -26,72 +26,82 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
         {"role": "<|Assistant|>", "content": ""},
     ]
 
-    prompt = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
+    sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
         conversations=conversation,
         sft_format=vl_chat_processor.sft_format,
         system_prompt="",
-    ) + vl_chat_processor.image_start_tag
+    )
+    prompt = sft_format + vl_chat_processor.image_start_tag
 
-    image_token_num = 576
+    image_token_num_per_image = 576
     img_size = 384
     patch_size = 16
 
-    input_ids = tokenizer.encode(prompt)
-    input_ids = torch.LongTensor(input_ids)
-    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).cuda()
-
-    for i in range(parallel_size * 2):
-        tokens[i, :] = input_ids
-        if i % 2 != 0:
-            tokens[i, 1:-1] = tokenizer.pad_token_id
-
-    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
-    generated_tokens = torch.zeros((parallel_size, image_token_num), dtype=torch.int).cuda()
-
-    for i in range(image_token_num):
-        if i == 0:
-            current_inputs = inputs_embeds
-        else:
-            current_inputs = img_embeds.unsqueeze(1)
-
-        # ✅ No past_key_values, crash-safe
-        outputs = vl_gpt.language_model.model(
-            inputs_embeds=current_inputs,
-            use_cache=False
+    with torch.inference_mode():
+        input_ids = vl_chat_processor.tokenizer.encode(prompt)
+        input_ids = torch.LongTensor(input_ids)
+        tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).cuda()
+
+        for i in range(parallel_size * 2):
+            tokens[i, :] = input_ids
+            if i % 2 != 0:
+                tokens[i, 1:-1] = vl_chat_processor.pad_id
+
+        inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
+        generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
+
+        past_key_values = None
+
+        for i in range(image_token_num_per_image):
+            outputs = vl_gpt.language_model.model(
+                inputs_embeds=inputs_embeds,
+                use_cache=True,
+                past_key_values=past_key_values
+            )
+
+            hidden_states = outputs.last_hidden_state
+            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
+
+            logit_cond = logits[0::2, :]
+            logit_uncond = logits[1::2, :]
+
+            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
+            probs = torch.softmax(logits / temperature, dim=-1)
+            next_token = torch.multinomial(probs, num_samples=1)
+            generated_tokens[:, i] = next_token.squeeze(dim=-1)
+
+            # Prepare next token for both conditional and unconditional
+            next_token_expanded = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
+            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token_expanded)
+            inputs_embeds = img_embeds.unsqueeze(dim=1)
+
+            # Update past_key_values for next iteration
+            past_key_values = outputs.past_key_values
+
+        # Decode generated tokens to images
+        dec = vl_gpt.gen_vision_model.decode_code(
+            generated_tokens.to(dtype=torch.int),
+            shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size]
         )
-
-        hidden_states = outputs.last_hidden_state
-        logits = vl_gpt.gen_head(hidden_states[:, -1, :])
-
-        logit_cond = logits[0::2]
-        logit_uncond = logits[1::2]
-        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
-        probs = torch.softmax(logits / temperature, dim=-1)
-
-        next_token = torch.multinomial(probs, num_samples=1)
-        generated_tokens[:, i] = next_token.squeeze(dim=-1)
-
-        next_token = torch.cat([next_token, next_token], dim=1).reshape(-1)
-        img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
-
-    dec = vl_gpt.gen_vision_model.decode_code(
-        generated_tokens.to(dtype=torch.int),
-        shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size]
-    )
-    dec = dec.to(torch.float32).detach().cpu().numpy().transpose(0, 2, 3, 1)
-    dec = np.clip((dec + 1) / 2 * 255, 0, 255).astype(np.uint8)
-
-    os.makedirs(os.path.dirname(output_path), exist_ok=True)
-    output_images = []
-    for i in range(parallel_size):
-        save_path = output_path.replace('.png', f'_{i}.png')
-        Image.fromarray(dec[i]).save(save_path)
-        output_images.append(save_path)
+        dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
+        dec = np.clip((dec + 1) / 2 * 255, 0, 255)
+
+        visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
+        visual_img[:, :, :] = dec
+
+        # Create output directory
+        output_dir = os.path.dirname(output_path)
+        if output_dir:
+            os.makedirs(output_dir, exist_ok=True)
+
+        output_images = []
+        for i in range(parallel_size):
+            save_path = output_path.replace('.png', f'_{i}.png')
+            Image.fromarray(visual_img[i]).save(save_path)
+            output_images.append(save_path)
 
     torch.cuda.empty_cache()
-    torch.cuda.ipc_collect()
-
-    return output_images[:1]  # return only the first image
 
 
 # === Gradio handler ===
@@ -99,6 +109,7 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
 def janus_generate_image(message, history):
     output_path = "./output/image.png"
     images = text_to_image_generate(message, output_path, vl_chat_processor, vl_gpt)
+    # Return the first generated image
     return {"role": "assistant", "content": {"path": images[0]}}
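
The cond/uncond batch built near the top of the function is worth a gloss: rows are interleaved, and odd rows have the prompt blanked out so that only the first and last tokens survive, which makes them the unconditional stream for classifier-free guidance. A toy illustration with made-up token ids:

import torch

input_ids = torch.tensor([100, 5, 6, 7, 8, 101])  # [start, prompt tokens..., image-start tag]
pad_id = 0                                        # placeholder pad id
parallel_size = 2

tokens = input_ids.repeat(parallel_size * 2, 1)
for i in range(parallel_size * 2):
    if i % 2 != 0:              # odd rows: blank the prompt, keep the ends
        tokens[i, 1:-1] = pad_id

print(tokens)
# tensor([[100,   5,   6,   7,   8, 101],
#         [100,   0,   0,   0,   0, 101],
#         [100,   5,   6,   7,   8, 101],
#         [100,   0,   0,   0,   0, 101]])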
 
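
The substantive fix is in the decoding loop. The old code called the language model with use_cache=False and, after the first step, fed it only the latest image-token embedding, so every later step ran without any context. The new code threads past_key_values through the loop, the standard incremental-decoding pattern. A minimal sketch of that pattern, with GPT-2 standing in for the Janus language model (the model name and prompt are placeholders, not from this repo):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

ids = tok.encode("a photo of", return_tensors="pt")
inputs = ids                 # first step: the full prompt
past_key_values = None

with torch.inference_mode():
    for _ in range(8):
        out = model(input_ids=inputs, use_cache=True, past_key_values=past_key_values)
        next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        inputs = next_id                       # later steps: only the new token...
        past_key_values = out.past_key_values  # ...the cache carries the rest
        ids = torch.cat([ids, next_id], dim=-1)

print(tok.decode(ids[0]))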
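
The sampling step recombines the two logit streams with the guidance weight before drawing a token, then duplicates each draw so the conditional and unconditional rows see the same history. The same step in isolation, with random logits and a placeholder vocabulary size:

import torch

parallel_size, vocab_size = 2, 16384      # placeholders
cfg_weight, temperature = 5.0, 1.0

logits = torch.randn(parallel_size * 2, vocab_size)  # stand-in for gen_head output

logit_cond = logits[0::2, :]    # even rows: conditional stream
logit_uncond = logits[1::2, :]  # odd rows: unconditional stream

# Push the distribution away from the unconditional prediction
guided = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
probs = torch.softmax(guided / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # (parallel_size, 1)

# Duplicate each draw so both streams are fed the same token
interleaved = next_token.repeat_interleave(2, dim=0).view(-1)
print(next_token.view(-1), interleaved)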
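
janus_generate_image returns a messages-format dict whose content is a FileData-style {"path": ...} dict, which is how a messages-type Gradio chat displays a file. The ChatInterface wiring is not shown in this diff; a plausible sketch, assuming the handler above is in scope:

import gradio as gr

demo = gr.ChatInterface(
    fn=janus_generate_image,  # returns {"role": "assistant", "content": {"path": ...}}
    type="messages",          # messages format accepts file-path dicts as content
    title="Janus text-to-image",
)

if __name__ == "__main__":
    demo.launch()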