akhaliq (HF Staff) committed
Commit 5e32fba · verified · 1 Parent(s): 1be336f

Update app.py

Files changed (1)
  1. app.py +25 -26
app.py CHANGED
@@ -1,34 +1,27 @@
 import os
-import PIL.Image
 import torch
 import numpy as np
+from PIL import Image
 import gradio as gr
 from transformers import AutoModelForCausalLM
 from janus.models import MultiModalityCausalLM, VLChatProcessor
 import spaces
 
-# Load model and processor
+# === Load model and processor ===
 model_path = "FreedomIntelligence/Janus-4o-7B"
 vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
+
 vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
     model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
 )
 vl_gpt = vl_gpt.cuda().eval()
 
-# Define image generation function
-@spaces.GPU(duration=120)
-def janus_generate_image(message, history):
-    prompt = message
-    output_path = "./output_image.png"
-    images = text_to_image_generate(prompt, output_path, vl_chat_processor, vl_gpt, parallel_size=1)
-    return {"role": "assistant", "content": gr.Image(images[0])}
-
-# Optimized text-to-image generation
-
+# === Image generation function ===
 def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt, temperature=1.0, parallel_size=1, cfg_weight=5):
     torch.cuda.empty_cache()
 
+    # Apply prompt formatting
     conversation = [
         {"role": "<|User|>", "content": input_prompt},
         {"role": "<|Assistant|>", "content": ""},
@@ -42,28 +35,26 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
 
     prompt = sft_format + vl_chat_processor.image_start_tag
     mmgpt = vl_gpt
+
     image_token_num = 576
     img_size = 384
     patch_size = 16
 
     with torch.inference_mode():
-        input_ids = vl_chat_processor.tokenizer.encode(prompt)
+        input_ids = tokenizer.encode(prompt)
         input_ids = torch.LongTensor(input_ids)
 
         tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).cuda()
         for i in range(parallel_size * 2):
             tokens[i, :] = input_ids
             if i % 2 != 0:
-                tokens[i, 1:-1] = vl_chat_processor.pad_id
+                tokens[i, 1:-1] = tokenizer.pad_token_id  # More robust
 
         inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)
         generated_tokens = torch.zeros((parallel_size, image_token_num), dtype=torch.int).cuda()
 
         for i in range(image_token_num):
-            outputs = mmgpt.language_model.model(
-                inputs_embeds=inputs_embeds,
-                use_cache=False
-            )
+            outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=False)
             hidden_states = outputs.last_hidden_state
             logits = mmgpt.gen_head(hidden_states[:, -1, :])
             logit_cond = logits[0::2]
@@ -78,6 +69,7 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
             img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
             inputs_embeds = img_embeds.unsqueeze(1)
 
+        # Decode image
         dec = mmgpt.gen_vision_model.decode_code(
             generated_tokens.to(dtype=torch.int),
             shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size]
@@ -85,26 +77,33 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
         dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
         dec = np.clip((dec + 1) / 2 * 255, 0, 255).astype(np.uint8)
 
-        visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
-        visual_img[:, :, :] = dec
-
         os.makedirs(os.path.dirname(output_path), exist_ok=True)
         output_images = []
         for i in range(parallel_size):
            save_path = output_path.replace('.png', f'_{i}.png')
-            PIL.Image.fromarray(visual_img[i]).save(save_path)
+            Image.fromarray(dec[i]).save(save_path)
            output_images.append(save_path)
 
     return output_images
 
-# Launch the ChatInterface UI
+# === Gradio handler ===
+@spaces.GPU(duration=120)
+def janus_generate_image(message, history):
+    prompt = message
+    output_path = "./output/image.png"
+    images = text_to_image_generate(prompt, output_path, vl_chat_processor, vl_gpt, parallel_size=1)
+    return {"role": "assistant", "content": images[0]}
+
+# === Gradio UI ===
 demo = gr.ChatInterface(
     fn=janus_generate_image,
     title="Janus Text-to-Image",
     description="Generate images from natural language prompts using Janus-4o-7B",
+    additional_inputs=[],
+    chatbot=gr.Chatbot(show_copy_button=True),
+    examples=["a cat", "a spaceship landing on Mars", "a fantasy castle at sunset"],
     theme="soft",
-    fill_height=True,
-    fill_width=True
 )
 
-demo.launch()
+if __name__ == "__main__":
+    demo.launch()
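
To smoke-test the updated generation path outside the Space UI, here is a minimal sketch (not part of the commit), assuming app.py is importable as a module, a CUDA GPU is available, and the janus package is installed; the prompt and output location are illustrative:

import app  # module import runs the model/processor loading; demo.launch() stays guarded by __main__

# Generate one image. Even rows of the token batch carry the conditional prompt,
# odd rows the padded unconditional prompt used for classifier-free guidance
# (cfg_weight defaults to 5 in the function signature shown above).
paths = app.text_to_image_generate(
    "a fantasy castle at sunset",   # one of the new ChatInterface examples
    "./output/image.png",           # saved as ./output/image_0.png when parallel_size=1
    app.vl_chat_processor,
    app.vl_gpt,
    parallel_size=1,
)
print(paths)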