Update app.py
app.py (CHANGED)
@@ -8,6 +8,7 @@ from janus.models import MultiModalityCausalLM, VLChatProcessor
 from dataclasses import dataclass
 import spaces
 
+# This dataclass definition is required for the processor
 @dataclass
 class VLChatProcessorOutput():
     sft_format: str
@@ -24,9 +25,8 @@ def process_image(image_paths, vl_chat_processor):
     images_outputs = vl_chat_processor.image_processor(images, return_tensors="pt")
     return images_outputs['pixel_values']
 
-# === Load Janus model ===
-#
-# In a local environment, you might need to adjust paths or download assets.
+# === Load Janus model and processor ===
+# This setup assumes the necessary model files are accessible.
 model_path = "FreedomIntelligence/Janus-4o-7B"
 vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
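Note: this hunk loads only the processor and tokenizer; the vl_gpt model used by the generation
functions is not shown in the diff. A minimal sketch of how such a model is typically loaded,
assuming the standard Janus pattern (AutoModelForCausalLM with trust_remote_code, bfloat16 on GPU
are assumptions, not taken from this diff):

    import torch
    from transformers import AutoModelForCausalLM

    # Hypothetical loading of the multi-modal LM referred to as vl_gpt elsewhere in app.py.
    vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()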
@@ -66,7 +66,7 @@ def text_and_image_to_image_generate(input_prompt, input_image_path, output_path
 
     with torch.inference_mode():
         input_image_pixel_values = process_image(input_images, vl_chat_processor).to(torch.bfloat16).cuda()
-
+        _, _, info_input = vl_gpt.gen_vision_model.encode(input_image_pixel_values)
         image_tokens_input = info_input[2].detach().reshape(input_image_pixel_values.shape[0], -1)
         image_embeds_input = vl_gpt.prepare_gen_img_embeds(image_tokens_input)
 
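Note: the added encode call is what defines info_input, which the following context line indexes.
A shape-annotated sketch of the conditioning path, assuming the VQ encoder returns its code
indices as the third element of its info tuple (the tensor shapes are assumptions inferred from
the reshape):

    # pixel_values: (B, 3, H, W) image batch, bfloat16, on GPU
    _, _, info = vl_gpt.gen_vision_model.encode(pixel_values)
    codes = info[2].detach().reshape(pixel_values.shape[0], -1)  # (B, num_image_tokens) code indices
    cond_embeds = vl_gpt.prepare_gen_img_embeds(codes)           # (B, num_image_tokens, hidden_size)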
@@ -99,13 +99,15 @@ def text_and_image_to_image_generate(input_prompt, input_image_path, output_path
             inputs_embeds[ind[0], offset: offset + image_embeds_input.shape[1], :] = image_embeds_input[(ii // 2) % img_len]
 
         generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
+
+        # --- FIX: Initialize past_key_values for cached generation ---
         past_key_values = None
 
         for i in range(image_token_num_per_image):
             outputs = vl_gpt.language_model.model(
                 inputs_embeds=inputs_embeds,
                 use_cache=True,
-                past_key_values=past_key_values
+                past_key_values=past_key_values # Pass cached values
             )
             hidden_states = outputs.last_hidden_state
 
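Note: together with the next hunk, this restores incremental (KV-cached) decoding: each step
passes past_key_values back into the decoder and feeds only the newest token embedding. A
self-contained illustration of the same pattern, using GPT-2 purely as a stand-in decoder (GPT-2
and greedy argmax sampling are assumptions for the demo; the app uses vl_gpt.language_model.model
with its own sampling):

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tok = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

    input_ids = tok("A photo of", return_tensors="pt").input_ids
    generated = input_ids
    past_key_values = None  # same initialization the fix adds

    with torch.inference_mode():
        for _ in range(8):
            # After the first step, only the newest token is fed; the cache supplies the rest.
            step_ids = generated if past_key_values is None else generated[:, -1:]
            outputs = model(input_ids=step_ids, use_cache=True, past_key_values=past_key_values)
            past_key_values = outputs.past_key_values  # carry the cache forward (the second fix)
            next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=-1)

    print(tok.decode(generated[0]))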
@@ -124,7 +126,8 @@ def text_and_image_to_image_generate(input_prompt, input_image_path, output_path
             next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
             img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
             inputs_embeds = img_embeds.unsqueeze(dim=1)
-
+
+            # --- FIX: Update past_key_values with the output from the current step ---
             past_key_values = outputs.past_key_values
 
         dec = vl_gpt.gen_vision_model.decode_code(
@@ -184,13 +187,14 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
     generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
 
+    # --- FIX: Initialize past_key_values for cached generation ---
     past_key_values = None
 
     for i in range(image_token_num_per_image):
         outputs = vl_gpt.language_model.model(
             inputs_embeds=inputs_embeds,
             use_cache=True,
-            past_key_values=past_key_values
+            past_key_values=past_key_values # Pass cached values
         )
 
         hidden_states = outputs.last_hidden_state
@@ -208,6 +212,7 @@ def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
         img_embeds = vl_gpt.prepare_gen_img_embeds(next_token_expanded)
         inputs_embeds = img_embeds.unsqueeze(dim=1)
 
+        # --- FIX: Update past_key_values with the output from the current step ---
         past_key_values = outputs.past_key_values
 
     dec = vl_gpt.gen_vision_model.decode_code(
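Note: the decode_code call (its arguments are cut off in the context above) turns the sampled
image tokens back into pixels. A hedged sketch of the usual post-processing into PIL images,
assuming decode_code yields a (B, 3, H, W) float tensor roughly in [-1, 1] (that range and the
helper name are assumptions):

    import numpy as np
    import torch
    from PIL import Image

    def decoded_tensor_to_pil(dec: torch.Tensor) -> list:
        # (B, 3, H, W) -> (B, H, W, 3), then rescale from roughly [-1, 1] to [0, 255]
        arr = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
        arr = np.clip((arr + 1) / 2 * 255, 0, 255).astype(np.uint8)
        return [Image.fromarray(a) for a in arr]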
@@ -244,53 +249,62 @@ def janus_chat_responder(message, history):
     prompt = message["text"]
     uploaded_files = message["files"]
 
-
-
-
-
-
-    try:
+    try:
+        if uploaded_files:
+            # Handle text+image to image generation
+            temp_image_path = uploaded_files[0]
             images = text_and_image_to_image_generate(
                 prompt, temp_image_path, output_path, vl_chat_processor, vl_gpt
             )
-
-
-    except Exception as e:
-        return f"Error during image-to-image generation: {str(e)}"
-
-    else:
-        # Handle text-to-image generation
-        try:
+        else:
+            # Handle text-to-image generation
             images = text_to_image_generate(prompt, output_path, vl_chat_processor, vl_gpt)
-
-
-
-
+
+        # Return a gallery component to display all generated images
+        return gr.Gallery(value=images, label="Generated Images")
+
+    except Exception as e:
+        # Return a user-friendly error message
+        gr.Error(f"An error occurred during generation: {str(e)}")
+        # Return None or an empty list for the gallery to clear it
+        return None
 
 
-# ===
+# === Gradio UI with a single ChatInterface ===
 with gr.Blocks(theme="soft", title="Janus Image Generation") as demo:
     gr.Markdown("# Janus Multi-Modal Image Generation")
     gr.Markdown("Generate images from text prompts, or upload an image and a prompt to transform it.")
 
+    # Using gr.ChatInterface which handles the chat history and input box automatically
     gr.ChatInterface(
         fn=janus_chat_responder,
-        multimodal=True,
-        title="Janus-4o-7B
+        multimodal=True, # Enables file uploads
+        title="Janus-4o-7B",
+        chatbot=gr.Chatbot(height=400, label="Chat", show_label=False),
+        textbox=gr.MultimodalTextbox(
+            file_types=["image"],
+            placeholder="Type a prompt or upload an image...",
+            label="Input"
+        ),
         examples=[
-            {"text": "
-            {"text": "
-            {"text": "
-            {"text": "Turn this into a watercolor painting", "files": ["./assets/example_image.jpg"]}
+            {"text": "A cat made of glass, sitting on a table.", "files": []},
+            {"text": "A futuristic city at sunset, with flying cars.", "files": []},
+            {"text": "A dragon breathing fire over a medieval castle.", "files": []},
+            {"text": "Turn this into a watercolor painting.", "files": ["./assets/example_image.jpg"]}
         ]
     )
 
 if __name__ == "__main__":
-    # Create a dummy image for the example if it doesn't exist
-
-
-    if not os.path.exists(
-
-
+    # Create a dummy image for the example if it doesn't exist to prevent errors
+    assets_dir = "./assets"
+    example_image_path = os.path.join(assets_dir, "example_image.jpg")
+    if not os.path.exists(example_image_path):
+        os.makedirs(assets_dir, exist_ok=True)
+        try:
+            dummy_image = Image.new('RGB', (384, 384), color = 'red')
+            dummy_image.save(example_image_path)
+            print(f"Created dummy example image at: {example_image_path}")
+        except Exception as e:
+            print(f"Could not create dummy image: {e}")
 
     demo.launch()
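Note: with multimodal=True, ChatInterface passes its fn a dict with "text" and "files" keys,
which is why janus_chat_responder reads message["text"] and message["files"]. A minimal
standalone sketch of that contract, assuming a recent Gradio 4.x release (echo_responder is a
hypothetical stand-in for the real handler):

    import gradio as gr

    def echo_responder(message, history):
        # message looks like {"text": "a prompt", "files": ["/tmp/gradio/.../image.png"]}
        return f"Got prompt {message['text']!r} with {len(message['files'])} file(s)."

    demo = gr.ChatInterface(fn=echo_responder, multimodal=True)

    if __name__ == "__main__":
        demo.launch()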