Update app.py
app.py
CHANGED
@@ -1,3 +1,6 @@
+Of course. Below is the modified `app.py` that merges the "Text-to-Image" and "Text+Image-to-Image" functionalities into a single, unified `gr.ChatInterface`, removing the extra tabs and UI elements as you requested.
+
+```python
import os
import torch
import numpy as np
@@ -19,6 +22,7 @@ class VLChatProcessorOutput():
        return len(self.input_ids)

def process_image(image_paths, vl_chat_processor):
+    """Processes a list of image paths into pixel values."""
    images = [Image.open(image_path).convert("RGB") for image_path in image_paths]
    images_outputs = vl_chat_processor.image_processor(images, return_tensors="pt")
    return images_outputs['pixel_values']
@@ -35,10 +39,11 @@ vl_gpt = vl_gpt.cuda().eval()

# === Text-and-Image-to-Image generation ===
def text_and_image_to_image_generate(input_prompt, input_image_path, output_path, vl_chat_processor, vl_gpt, temperature=1.0, parallel_size=2, cfg_weight=5, cfg_weight2=5):
+    """Generates an image from a text prompt and an input image."""
    torch.cuda.empty_cache()

    input_img_tokens = vl_chat_processor.image_start_tag + vl_chat_processor.image_tag * vl_chat_processor.num_image_tokens + vl_chat_processor.image_end_tag + vl_chat_processor.image_start_tag + vl_chat_processor.pad_tag * vl_chat_processor.num_image_tokens + vl_chat_processor.image_end_tag
-    output_img_tokens = vl_chat_processor.image_start_tag
+    output_img_tokens = vl_chat_processor.image_start_tag

    pre_data = []
    input_images = [input_image_path]
@@ -67,7 +72,7 @@ def text_and_image_to_image_generate(input_prompt, input_image_path, output_path
    image_embeds_input = vl_gpt.prepare_gen_img_embeds(image_tokens_input)

    input_ids = torch.LongTensor(vl_chat_processor.tokenizer.encode(sft_format))
-
+
    encoder_pixel_values = process_image(input_images, vl_chat_processor).cuda()
    tokens = torch.zeros((parallel_size * 3, len(input_ids)), dtype=torch.long)
    for i in range(parallel_size * 3):
@@ -99,8 +104,8 @@ def text_and_image_to_image_generate(input_prompt, input_image_path, output_path

    for i in range(image_token_num_per_image):
        outputs = vl_gpt.language_model.model(
-            inputs_embeds=inputs_embeds,
-            use_cache=True,
+            inputs_embeds=inputs_embeds,
+            use_cache=True,
            past_key_values=past_key_values
        )
        hidden_states = outputs.last_hidden_state
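The hunk above is the autoregressive image-token loop: on each step only the newly sampled token's embedding is fed back in, while `past_key_values` carries the transformer's cached attention states. A minimal, self-contained sketch of the same KV-cache pattern, using GPT-2 from `transformers` purely as a stand-in for the Janus language model and greedy sampling in place of the CFG sampling used in app.py:

```python
# Hedged sketch of incremental decoding with past_key_values; GPT-2 is only a
# stand-in model, not the Janus model from app.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = tokenizer("a cat sitting on a windowsill", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)

past_key_values = None
generated = []
with torch.no_grad():
    for _ in range(8):
        outputs = model(
            inputs_embeds=inputs_embeds,
            use_cache=True,
            past_key_values=past_key_values
        )
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated.append(next_token)
        # Feed back only the new token's embedding; the cache holds the prefix.
        inputs_embeds = model.get_input_embeddings()(next_token)
        past_key_values = outputs.past_key_values

print(tokenizer.decode(torch.cat(generated, dim=1)[0]))
```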
@@ -120,11 +125,11 @@ def text_and_image_to_image_generate(input_prompt, input_image_path, output_path
        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
        img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(dim=1)
-
+
        past_key_values = outputs.past_key_values

    dec = vl_gpt.gen_vision_model.decode_code(
-        generated_tokens.to(dtype=torch.int),
+        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size]
    )
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
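After the loop, `decode_code` maps the generated token grid back to pixel values in roughly the [-1, 1] range, which the surrounding lines rescale to 8-bit RGB. A small stand-alone sketch of just that rescaling and saving step, with random data in place of the decoder output (`img_size` here is an assumed value, not taken from app.py):

```python
# Hedged sketch of the [-1, 1] -> uint8 conversion applied after decode_code;
# random NHWC data stands in for the decoder output.
import numpy as np
from PIL import Image

parallel_size, img_size = 2, 384                      # assumed sizes for illustration
dec = np.random.uniform(-1.0, 1.0, (parallel_size, img_size, img_size, 3))
dec = np.clip((dec + 1) / 2 * 255, 0, 255)            # rescale to [0, 255]

visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
visual_img[:, :, :] = dec                             # cast to uint8 on assignment

Image.fromarray(visual_img[0]).save("sample_0.png")
```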
@@ -149,6 +154,7 @@ def text_and_image_to_image_generate(input_prompt, input_image_path, output_path

# === Text-to-Image generation ===
def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt, temperature=1.0, parallel_size=2, cfg_weight=5.0):
+    """Generates an image from a text prompt only."""
    torch.cuda.empty_cache()

    conversation = [
@@ -179,51 +185,47 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,

    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
-
+
    past_key_values = None
-
+
    for i in range(image_token_num_per_image):
        outputs = vl_gpt.language_model.model(
-            inputs_embeds=inputs_embeds,
-            use_cache=True,
+            inputs_embeds=inputs_embeds,
+            use_cache=True,
            past_key_values=past_key_values
        )
-
+
        hidden_states = outputs.last_hidden_state
        logits = vl_gpt.gen_head(hidden_states[:, -1, :])
-
+
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
-
+
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)
-
-        # Prepare next token for both conditional and unconditional
+
        next_token_expanded = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
        img_embeds = vl_gpt.prepare_gen_img_embeds(next_token_expanded)
        inputs_embeds = img_embeds.unsqueeze(dim=1)
-
-        # Update past_key_values for next iteration
+
        past_key_values = outputs.past_key_values

-    # Decode generated tokens to images
    dec = vl_gpt.gen_vision_model.decode_code(
-        generated_tokens.to(dtype=torch.int),
+        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size]
    )
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
-
+
    visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

-    # Create output directory
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
-
+
    output_images = []
    for i in range(parallel_size):
        save_path = output_path.replace('.png', f'_{i}.png')
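The sampling loop in this hunk applies classifier-free guidance: conditional and unconditional rows are interleaved in one batch, their logits are mixed with `cfg_weight`, and the next token is drawn from the mixed distribution. A stand-alone sketch of just that mixing step, with random logits in place of the output of `gen_head` (the vocabulary size is illustrative):

```python
# Hedged sketch of the classifier-free guidance mixing used in the loop above;
# random logits stand in for the model's gen_head output.
import torch

cfg_weight, temperature, vocab_size = 5.0, 1.0, 16384  # vocab_size is an assumption
logits = torch.randn(2 * 2, vocab_size)                # rows alternate cond / uncond

logit_cond = logits[0::2, :]     # even rows: conditioned on the prompt
logit_uncond = logits[1::2, :]   # odd rows: unconditional (padded prompt)
mixed = logit_uncond + cfg_weight * (logit_cond - logit_uncond)

probs = torch.softmax(mixed / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)   # shape: (parallel_size, 1)
print(next_token.shape)
```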
@@ -233,200 +235,65 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
    torch.cuda.empty_cache()
    return output_images

-# ===
+# === Unified Gradio handler for ChatInterface ===
@spaces.GPU(duration=120)
-def
-
-
-
-
-
-
-
-
-
-
-
-
-        # Clean up temp file
-        if os.path.exists(temp_image_path):
-            os.remove(temp_image_path)
-    else:
-        # Use text-only generation
-        images = text_to_image_generate(message, output_path, vl_chat_processor, vl_gpt)
-
-    return {"role": "assistant", "content": {"path": images[0]}}
+def janus_chat_responder(message, history):
+    """
+    Handles both text-only and multimodal (text+image) inputs from the ChatInterface.
+    'message' is a dictionary with 'text' and 'files' keys.
+    """
+    output_path = "./output/chat_image.png"
+    prompt = message["text"]
+    uploaded_files = message["files"]
+
+    if uploaded_files:
+        # Handle text+image to image generation
+        # Assuming the first uploaded file is the image to process
+        temp_image_path = uploaded_files[0]

-
-
-
-    if input_image is None:
-        return None, "Please upload an image to use text+image generation."
-
-    output_path = "./output/text_image_gen.png"
-
-    # Save uploaded image temporarily
-    temp_image_path = "./temp_input.png"
-    input_image.save(temp_image_path)
-
-    try:
-        images = text_and_image_to_image_generate(
-            prompt, temp_image_path, output_path, vl_chat_processor, vl_gpt
-        )
-        return images[0], "Image generated successfully!"
-    except Exception as e:
-        return None, f"Error generating image: {str(e)}"
-    finally:
-        # Clean up temp file
-        if os.path.exists(temp_image_path):
-            os.remove(temp_image_path)
-
-@spaces.GPU(duration=120)
-def generate_from_text_only(prompt):
-    output_path = "./output/text_only_gen.png"
-
-    try:
-        images = text_to_image_generate(prompt, output_path, vl_chat_processor, vl_gpt)
-        return images[0], "Image generated successfully!"
-    except Exception as e:
-        return None, f"Error generating image: {str(e)}"
-
-# === Enhanced Gradio UI with multiple interfaces ===
-with gr.Blocks(theme="soft", title="Janus Text-to-Image & Text+Image-to-Image") as demo:
-    gr.Markdown("# Janus Multi-Modal Image Generation")
-    gr.Markdown("Generate images from text prompts or transform existing images with text descriptions using Janus-4o-7B")
-
-    with gr.Tabs():
-        # Chat Interface Tab
-        with gr.Tab("Chat Interface"):
-            gr.Markdown("### Interactive Chat with Optional Image Upload")
-            gr.Markdown("You can chat and optionally upload an image to influence the generation")
-
-            # Create a custom chat interface that supports image upload
-            with gr.Row():
-                with gr.Column(scale=3):
-                    chatbot = gr.Chatbot(label="Chat History")
-                    with gr.Row():
-                        msg_input = gr.Textbox(
-                            label="Message",
-                            placeholder="Describe the image you want to generate...",
-                            scale=4
-                        )
-                        image_input = gr.Image(
-                            type="pil",
-                            label="Upload Image (optional)",
-                            scale=1
-                        )
-
-                    with gr.Row():
-                        send_btn = gr.Button("Generate", variant="primary")
-                        clear_btn = gr.Button("Clear Chat")
-
-                    # Example prompts
-                    gr.Examples(
-                        examples=[
-                            ["a cat sitting on a windowsill", None],
-                            ["a futuristic city at sunset", None],
-                            ["a dragon flying over mountains", None],
-                        ],
-                        inputs=[msg_input, image_input]
+        try:
+            images = text_and_image_to_image_generate(
+                prompt, temp_image_path, output_path, vl_chat_processor, vl_gpt
            )
-
-
-
-
-
-
-
-                text_prompt = gr.Textbox(
-                    label="Text Prompt",
-                    placeholder="a beautiful landscape with mountains and a lake",
-                    lines=3
-                )
-                text_generate_btn = gr.Button("Generate Image", variant="primary")
-
-            with gr.Column():
-                text_output_image = gr.Image(label="Generated Image")
-                text_status = gr.Textbox(label="Status", interactive=False)
-
-        # Separate Text+Image-to-Image Tab
-        with gr.Tab("Text+Image-to-Image"):
-            gr.Markdown("### Transform Images with Text Descriptions")
-
-            with gr.Row():
-                with gr.Column():
-                    img_text_prompt = gr.Textbox(
-                        label="Text Prompt",
-                        placeholder="Turn this into a nighttime scene",
-                        lines=3
-                    )
-                    input_image = gr.Image(
-                        type="pil",
-                        label="Input Image"
-                    )
-                    img_generate_btn = gr.Button("Generate Image", variant="primary")
-
-                with gr.Column():
-                    img_output_image = gr.Image(label="Generated Image")
-                    img_status = gr.Textbox(label="Status", interactive=False)
-
-    # Event handlers for the chat interface
-    def chat_respond(message, image, history):
-        if not message.strip():
-            return history, ""
-
-        # Add user message to history
-        if image is not None:
-            history.append([f"{message} [with uploaded image]", None])
-        else:
-            history.append([message, None])
-
-        # Generate response
+            # Return the path to the first generated image to be displayed in the chat
+            return images[0]
+        except Exception as e:
+            return f"Error during image-to-image generation: {str(e)}"
+
+    else:
+        # Handle text-to-image generation
        try:
-
-
-
-            # Add assistant response to history
-            history[-1][1] = (generated_image_path,)
-
+            images = text_to_image_generate(prompt, output_path, vl_chat_processor, vl_gpt)
+            # Return the path to the first generated image
+            return images[0]
        except Exception as e:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        )
-
-    clear_btn.click(
-        clear_chat,
-        outputs=[chatbot, msg_input]
-    )
-
-    # Wire up the separate interfaces
-    text_generate_btn.click(
-        generate_from_text_only,
-        inputs=[text_prompt],
-        outputs=[text_output_image, text_status]
-    )
-
-    img_generate_btn.click(
-        generate_from_text_and_image,
-        inputs=[img_text_prompt, input_image],
-        outputs=[img_output_image, img_status]
+            return f"Error during text-to-image generation: {str(e)}"
+
+
+# === Simplified Gradio UI with a single ChatInterface ===
+with gr.Blocks(theme="soft", title="Janus Image Generation") as demo:
+    gr.Markdown("# Janus Multi-Modal Image Generation")
+    gr.Markdown("Generate images from text prompts, or upload an image and a prompt to transform it.")
+
+    gr.ChatInterface(
+        fn=janus_chat_responder,
+        multimodal=True,
+        title="Janus-4o-7B Chat",
+        examples=[
+            {"text": "a cat sitting on a windowsill", "files": []},
+            {"text": "a futuristic city at sunset", "files": []},
+            {"text": "a dragon flying over mountains", "files": []},
+            {"text": "Turn this into a watercolor painting", "files": ["./assets/example_image.jpg"]}
+        ]
    )

if __name__ == "__main__":
-
+    # Create a dummy image for the example if it doesn't exist
+    if not os.path.exists("./assets"):
+        os.makedirs("./assets")
+    if not os.path.exists("./assets/example_image.jpg"):
+        dummy_image = Image.new('RGB', (100, 100), color = 'red')
+        dummy_image.save("./assets/example_image.jpg")
+
+    demo.launch()
+```
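The new UI relies on `gr.ChatInterface(multimodal=True)`, which hands each user turn to the callback as a dict with `text` and `files` keys, exactly as `janus_chat_responder` assumes. A minimal, model-free sketch of that contract (the echo handler is illustrative only and is not part of app.py):

```python
# Hedged sketch of the multimodal ChatInterface contract; the handler only echoes
# what it would do instead of calling the Janus pipelines.
import gradio as gr

def responder(message, history):
    prompt = message["text"]
    files = message["files"]          # list of uploaded file paths (may be empty)
    if files:
        return f"Would run text+image-to-image on {files[0]!r} with prompt: {prompt}"
    return f"Would run text-to-image with prompt: {prompt}"

demo = gr.ChatInterface(fn=responder, multimodal=True)

if __name__ == "__main__":
    demo.launch()
```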