akhaliq HF Staff committed
Commit 9f63005 · verified · Parent(s): a13563f

Update app.py

Files changed (1): app.py +78 -211
app.py CHANGED
@@ -1,3 +1,6 @@
+Of course. Below is the modified `app.py` that merges the "Text-to-Image" and "Text+Image-to-Image" functionalities into a single, unified `gr.ChatInterface`, removing the extra tabs and UI elements as you requested.
+
+```python
 import os
 import torch
 import numpy as np
@@ -19,6 +22,7 @@ class VLChatProcessorOutput():
         return len(self.input_ids)
 
 def process_image(image_paths, vl_chat_processor):
+    """Processes a list of image paths into pixel values."""
     images = [Image.open(image_path).convert("RGB") for image_path in image_paths]
     images_outputs = vl_chat_processor.image_processor(images, return_tensors="pt")
     return images_outputs['pixel_values']
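The `process_image` helper above follows the usual Hugging Face image-processor pattern: open each path as RGB, run the processor, keep `pixel_values`. A self-contained sketch of the same pattern, with a placeholder ViT checkpoint standing in for the Janus processor:

```python
# Sketch of the open-as-RGB -> processor -> pixel_values pattern used by
# process_image(). The checkpoint here is a placeholder, not the one this
# Space loads.
from PIL import Image
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

def pixel_values_from_paths(image_paths):
    images = [Image.open(p).convert("RGB") for p in image_paths]
    return processor(images, return_tensors="pt")["pixel_values"]

# pixel_values_from_paths(["cat.png"]).shape -> torch.Size([1, 3, 224, 224])
```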
@@ -35,10 +39,11 @@ vl_gpt = vl_gpt.cuda().eval()
 
 # === Text-and-Image-to-Image generation ===
 def text_and_image_to_image_generate(input_prompt, input_image_path, output_path, vl_chat_processor, vl_gpt, temperature=1.0, parallel_size=2, cfg_weight=5, cfg_weight2=5):
+    """Generates an image from a text prompt and an input image."""
     torch.cuda.empty_cache()
 
     input_img_tokens = vl_chat_processor.image_start_tag + vl_chat_processor.image_tag * vl_chat_processor.num_image_tokens + vl_chat_processor.image_end_tag + vl_chat_processor.image_start_tag + vl_chat_processor.pad_tag * vl_chat_processor.num_image_tokens + vl_chat_processor.image_end_tag
-    output_img_tokens = vl_chat_processor.image_start_tag
+    output_img_tokens = vl_chat_processor.image_start_tag
 
     pre_data = []
     input_images = [input_image_path]
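`input_img_tokens` reserves two placeholder runs between image-start/end tags, one of `image_tag` slots and one of `pad_tag` slots, which are later replaced by actual image embeddings. A toy illustration of the resulting layout; the tag strings and the token count are assumptions chosen for readability (Janus-style processors use several hundred image tokens):

```python
# Toy illustration of the placeholder layout built for input_img_tokens.
# Tag strings and num_image_tokens are assumed values for this sketch.
image_start_tag, image_end_tag = "<begin_of_image>", "<end_of_image>"
image_tag, pad_tag = "<image_placeholder>", "<pad>"
num_image_tokens = 3  # far smaller than the real value, to keep output short

input_img_tokens = (
    image_start_tag + image_tag * num_image_tokens + image_end_tag
    + image_start_tag + pad_tag * num_image_tokens + image_end_tag
)
print(input_img_tokens)
# <begin_of_image><image_placeholder><image_placeholder><image_placeholder><end_of_image><begin_of_image><pad><pad><pad><end_of_image>
```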
@@ -67,7 +72,7 @@ def text_and_image_to_image_generate(input_prompt, input_image_path, output_path
     image_embeds_input = vl_gpt.prepare_gen_img_embeds(image_tokens_input)
 
     input_ids = torch.LongTensor(vl_chat_processor.tokenizer.encode(sft_format))
-
+
     encoder_pixel_values = process_image(input_images, vl_chat_processor).cuda()
     tokens = torch.zeros((parallel_size * 3, len(input_ids)), dtype=torch.long)
     for i in range(parallel_size * 3):
@@ -99,8 +104,8 @@ def text_and_image_to_image_generate(input_prompt, input_image_path, output_path
 
     for i in range(image_token_num_per_image):
         outputs = vl_gpt.language_model.model(
-            inputs_embeds=inputs_embeds,
-            use_cache=True,
+            inputs_embeds=inputs_embeds,
+            use_cache=True,
             past_key_values=past_key_values
         )
         hidden_states = outputs.last_hidden_state
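The `use_cache=True` / `past_key_values` pair is standard KV-cached incremental decoding: the prompt is processed once, then each step feeds only the newly produced embedding and reuses the cached attention keys/values. A minimal sketch of the same loop with a small off-the-shelf causal LM (gpt2 is just a stand-in; the Janus code feeds image-token embeddings instead):

```python
# KV-cached incremental decoding, the same pattern as the loop above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = tok("a cat sitting on", return_tensors="pt").input_ids
past_key_values = None
with torch.no_grad():
    for _ in range(8):
        out = model(input_ids=input_ids, use_cache=True, past_key_values=past_key_values)
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        past_key_values = out.past_key_values  # reuse the cache next step
        input_ids = next_token                 # feed only the new token
        print(tok.decode(next_token[0]), end="")
```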
@@ -120,11 +125,11 @@ def text_and_image_to_image_generate(input_prompt, input_image_path, output_path
         next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
         img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
         inputs_embeds = img_embeds.unsqueeze(dim=1)
-
+
         past_key_values = outputs.past_key_values
 
     dec = vl_gpt.gen_vision_model.decode_code(
-        generated_tokens.to(dtype=torch.int),
+        generated_tokens.to(dtype=torch.int),
         shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size]
     )
     dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
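Each sampled image token is copied three times so every row of the `parallel_size * 3` batch (the two conditional variants plus the unconditional one) advances with the same token, and the decoder output in `[-1, 1]` is later rescaled to 8-bit pixels. A small sketch of both steps with made-up shapes and values:

```python
# Sketch: triple the sampled tokens across guidance rows, then map decoder
# output from [-1, 1] to uint8 pixels. All shapes and values are made up.
import numpy as np
import torch

parallel_size = 2
next_token = torch.tensor([[101], [202]])             # one token per sample
tripled = torch.cat([next_token] * 3, dim=1).view(-1)
print(tripled.tolist())                               # [101, 101, 101, 202, 202, 202]

dec = np.random.uniform(-1, 1, size=(parallel_size, 384, 384, 3))  # fake decoder output
pixels = np.clip((dec + 1) / 2 * 255, 0, 255).astype(np.uint8)
print(pixels.shape, pixels.dtype)                     # (2, 384, 384, 3) uint8
```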
@@ -149,6 +154,7 @@ def text_and_image_to_image_generate(input_prompt, input_image_path, output_path
 
 # === Text-to-Image generation ===
 def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt, temperature=1.0, parallel_size=2, cfg_weight=5.0):
+    """Generates an image from a text prompt only."""
     torch.cuda.empty_cache()
 
     conversation = [
@@ -179,51 +185,47 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
 
     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
     generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
-
+
     past_key_values = None
-
+
     for i in range(image_token_num_per_image):
         outputs = vl_gpt.language_model.model(
-            inputs_embeds=inputs_embeds,
-            use_cache=True,
+            inputs_embeds=inputs_embeds,
+            use_cache=True,
             past_key_values=past_key_values
         )
-
+
         hidden_states = outputs.last_hidden_state
         logits = vl_gpt.gen_head(hidden_states[:, -1, :])
-
+
         logit_cond = logits[0::2, :]
         logit_uncond = logits[1::2, :]
-
+
         logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
         probs = torch.softmax(logits / temperature, dim=-1)
         next_token = torch.multinomial(probs, num_samples=1)
         generated_tokens[:, i] = next_token.squeeze(dim=-1)
-
-        # Prepare next token for both conditional and unconditional
+
         next_token_expanded = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
         img_embeds = vl_gpt.prepare_gen_img_embeds(next_token_expanded)
        inputs_embeds = img_embeds.unsqueeze(dim=1)
-
-        # Update past_key_values for next iteration
+
         past_key_values = outputs.past_key_values
 
-    # Decode generated tokens to images
     dec = vl_gpt.gen_vision_model.decode_code(
-        generated_tokens.to(dtype=torch.int),
+        generated_tokens.to(dtype=torch.int),
         shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size]
     )
     dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
     dec = np.clip((dec + 1) / 2 * 255, 0, 255)
-
+
     visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
     visual_img[:, :, :] = dec
 
-    # Create output directory
     output_dir = os.path.dirname(output_path)
     if output_dir:
         os.makedirs(output_dir, exist_ok=True)
-
+
     output_images = []
     for i in range(parallel_size):
         save_path = output_path.replace('.png', f'_{i}.png')
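The `logit_uncond + cfg_weight * (logit_cond - logit_uncond)` line is classifier-free guidance: the batch interleaves conditional (even) and unconditional (odd) rows, and the guided logits extrapolate away from the unconditional prediction by a factor of `cfg_weight`. The arithmetic on a made-up 2-sample batch:

```python
# Classifier-free guidance over interleaved cond/uncond rows; values made up.
import torch

cfg_weight = 5.0
logits = torch.tensor([
    [2.0, 0.5],   # conditional, sample A
    [1.0, 1.0],   # unconditional, sample A
    [0.2, 3.0],   # conditional, sample B
    [0.0, 2.0],   # unconditional, sample B
])
logit_cond = logits[0::2, :]
logit_uncond = logits[1::2, :]
guided = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
print(guided)  # tensor([[ 6.0000, -1.5000],
               #         [ 1.0000,  7.0000]])
```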
@@ -233,200 +235,65 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
     torch.cuda.empty_cache()
     return output_images
 
-# === Enhanced Gradio handlers ===
+# === Unified Gradio handler for ChatInterface ===
 @spaces.GPU(duration=120)
-def janus_generate_image(message, history, uploaded_image=None):
-    output_path = "./output/image.png"
-
-    if uploaded_image is not None:
-        # Save uploaded image temporarily
-        temp_image_path = "./temp_input.png"
-        uploaded_image.save(temp_image_path)
-
-        # Use text+image to image generation
-        images = text_and_image_to_image_generate(
-            message, temp_image_path, output_path, vl_chat_processor, vl_gpt
-        )
-
-        # Clean up temp file
-        if os.path.exists(temp_image_path):
-            os.remove(temp_image_path)
-    else:
-        # Use text-only generation
-        images = text_to_image_generate(message, output_path, vl_chat_processor, vl_gpt)
-
-    return {"role": "assistant", "content": {"path": images[0]}}
-
-# === Alternative interface for explicit text+image input ===
-@spaces.GPU(duration=120)
-def generate_from_text_and_image(prompt, input_image):
-    if input_image is None:
-        return None, "Please upload an image to use text+image generation."
-
-    output_path = "./output/text_image_gen.png"
-
-    # Save uploaded image temporarily
-    temp_image_path = "./temp_input.png"
-    input_image.save(temp_image_path)
-
-    try:
-        images = text_and_image_to_image_generate(
-            prompt, temp_image_path, output_path, vl_chat_processor, vl_gpt
-        )
-        return images[0], "Image generated successfully!"
-    except Exception as e:
-        return None, f"Error generating image: {str(e)}"
-    finally:
-        # Clean up temp file
-        if os.path.exists(temp_image_path):
-            os.remove(temp_image_path)
-
-@spaces.GPU(duration=120)
-def generate_from_text_only(prompt):
-    output_path = "./output/text_only_gen.png"
-
-    try:
-        images = text_to_image_generate(prompt, output_path, vl_chat_processor, vl_gpt)
-        return images[0], "Image generated successfully!"
-    except Exception as e:
-        return None, f"Error generating image: {str(e)}"
-
-# === Enhanced Gradio UI with multiple interfaces ===
-with gr.Blocks(theme="soft", title="Janus Text-to-Image & Text+Image-to-Image") as demo:
-    gr.Markdown("# Janus Multi-Modal Image Generation")
-    gr.Markdown("Generate images from text prompts or transform existing images with text descriptions using Janus-4o-7B")
-
-    with gr.Tabs():
-        # Chat Interface Tab
-        with gr.Tab("Chat Interface"):
-            gr.Markdown("### Interactive Chat with Optional Image Upload")
-            gr.Markdown("You can chat and optionally upload an image to influence the generation")
-
-            # Create a custom chat interface that supports image upload
-            with gr.Row():
-                with gr.Column(scale=3):
-                    chatbot = gr.Chatbot(label="Chat History")
-                    with gr.Row():
-                        msg_input = gr.Textbox(
-                            label="Message",
-                            placeholder="Describe the image you want to generate...",
-                            scale=4
-                        )
-                        image_input = gr.Image(
-                            type="pil",
-                            label="Upload Image (optional)",
-                            scale=1
-                        )
-
-            with gr.Row():
-                send_btn = gr.Button("Generate", variant="primary")
-                clear_btn = gr.Button("Clear Chat")
-
-            # Example prompts
-            gr.Examples(
-                examples=[
-                    ["a cat sitting on a windowsill", None],
-                    ["a futuristic city at sunset", None],
-                    ["a dragon flying over mountains", None],
-                ],
-                inputs=[msg_input, image_input]
+def janus_chat_responder(message, history):
+    """
+    Handles both text-only and multimodal (text+image) inputs from the ChatInterface.
+    'message' is a dictionary with 'text' and 'files' keys.
+    """
+    output_path = "./output/chat_image.png"
+    prompt = message["text"]
+    uploaded_files = message["files"]
+
+    if uploaded_files:
+        # Handle text+image to image generation
+        # Assuming the first uploaded file is the image to process
+        temp_image_path = uploaded_files[0]
+
+        try:
+            images = text_and_image_to_image_generate(
+                prompt, temp_image_path, output_path, vl_chat_processor, vl_gpt
             )
-
-        # Separate Text-to-Image Tab
-        with gr.Tab("Text-to-Image"):
-            gr.Markdown("### Generate Images from Text Only")
-
-            with gr.Row():
-                with gr.Column():
-                    text_prompt = gr.Textbox(
-                        label="Text Prompt",
-                        placeholder="a beautiful landscape with mountains and a lake",
-                        lines=3
-                    )
-                    text_generate_btn = gr.Button("Generate Image", variant="primary")
-
-                with gr.Column():
-                    text_output_image = gr.Image(label="Generated Image")
-                    text_status = gr.Textbox(label="Status", interactive=False)
-
-        # Separate Text+Image-to-Image Tab
-        with gr.Tab("Text+Image-to-Image"):
-            gr.Markdown("### Transform Images with Text Descriptions")
-
-            with gr.Row():
-                with gr.Column():
-                    img_text_prompt = gr.Textbox(
-                        label="Text Prompt",
-                        placeholder="Turn this into a nighttime scene",
-                        lines=3
-                    )
-                    input_image = gr.Image(
-                        type="pil",
-                        label="Input Image"
-                    )
-                    img_generate_btn = gr.Button("Generate Image", variant="primary")
-
-                with gr.Column():
-                    img_output_image = gr.Image(label="Generated Image")
-                    img_status = gr.Textbox(label="Status", interactive=False)
-
-    # Event handlers for the chat interface
-    def chat_respond(message, image, history):
-        if not message.strip():
-            return history, ""
-
-        # Add user message to history
-        if image is not None:
-            history.append([f"{message} [with uploaded image]", None])
-        else:
-            history.append([message, None])
-
-        # Generate response
+            # Return the path to the first generated image to be displayed in the chat
+            return images[0]
+        except Exception as e:
+            return f"Error during image-to-image generation: {str(e)}"
+
+    else:
+        # Handle text-to-image generation
         try:
-            result = janus_generate_image(message, history, image)
-            generated_image_path = result["content"]["path"]
-
-            # Add assistant response to history
-            history[-1][1] = (generated_image_path,)
-
+            images = text_to_image_generate(prompt, output_path, vl_chat_processor, vl_gpt)
+            # Return the path to the first generated image
+            return images[0]
         except Exception as e:
-            history[-1][1] = f"Error: {str(e)}"
-
-        return history, ""
-
-    def clear_chat():
-        return [], ""
-
-    # Wire up the chat interface
-    send_btn.click(
-        chat_respond,
-        inputs=[msg_input, image_input, chatbot],
-        outputs=[chatbot, msg_input]
-    )
-
-    msg_input.submit(
-        chat_respond,
-        inputs=[msg_input, image_input, chatbot],
-        outputs=[chatbot, msg_input]
-    )
-
-    clear_btn.click(
-        clear_chat,
-        outputs=[chatbot, msg_input]
-    )
-
-    # Wire up the separate interfaces
-    text_generate_btn.click(
-        generate_from_text_only,
-        inputs=[text_prompt],
-        outputs=[text_output_image, text_status]
-    )
-
-    img_generate_btn.click(
-        generate_from_text_and_image,
-        inputs=[img_text_prompt, input_image],
-        outputs=[img_output_image, img_status]
+            return f"Error during text-to-image generation: {str(e)}"
+
+
+# === Simplified Gradio UI with a single ChatInterface ===
+with gr.Blocks(theme="soft", title="Janus Image Generation") as demo:
+    gr.Markdown("# Janus Multi-Modal Image Generation")
+    gr.Markdown("Generate images from text prompts, or upload an image and a prompt to transform it.")
+
+    gr.ChatInterface(
+        fn=janus_chat_responder,
+        multimodal=True,
+        title="Janus-4o-7B Chat",
+        examples=[
+            {"text": "a cat sitting on a windowsill", "files": []},
+            {"text": "a futuristic city at sunset", "files": []},
+            {"text": "a dragon flying over mountains", "files": []},
+            {"text": "Turn this into a watercolor painting", "files": ["./assets/example_image.jpg"]}
+        ]
     )
 
 if __name__ == "__main__":
-    demo.launch()
+    # Create a dummy image for the example if it doesn't exist
+    if not os.path.exists("./assets"):
+        os.makedirs("./assets")
+    if not os.path.exists("./assets/example_image.jpg"):
+        dummy_image = Image.new('RGB', (100, 100), color = 'red')
+        dummy_image.save("./assets/example_image.jpg")
+
+    demo.launch()
+```
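For context on the new handler's contract: with `multimodal=True`, `gr.ChatInterface` delivers each user turn as a dict with `"text"` and `"files"` keys, which is exactly what `janus_chat_responder` unpacks. A minimal sketch of that contract with a stub responder (no model, echo only):

```python
# Minimal multimodal ChatInterface; the echo logic is a stand-in for
# janus_chat_responder and involves no model.
import gradio as gr

def echo_responder(message, history):
    # With multimodal=True, message is {"text": str, "files": [path, ...]}.
    if message["files"]:
        return gr.Image(message["files"][0])  # show the uploaded image back
    return f"You said: {message['text']}"

demo = gr.ChatInterface(fn=echo_responder, multimodal=True)

if __name__ == "__main__":
    demo.launch()
```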