import os
import torch
import numpy as np
from PIL import Image
import gradio as gr
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from dataclasses import dataclass
import spaces

# This dataclass definition is required for the processor
@dataclass
class VLChatProcessorOutput:
    sft_format: str
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    num_image_tokens: torch.IntTensor

    def __len__(self):
        return len(self.input_ids)

def process_image(image_paths, vl_chat_processor):
    """Processes a list of image paths into pixel values."""
    images = [Image.open(image_path).convert("RGB") for image_path in image_paths]
    images_outputs = vl_chat_processor.image_processor(images, return_tensors="pt")
    return images_outputs['pixel_values']

# === Load Janus model and processor ===
# This setup assumes the necessary model files are accessible.
model_path = "FreedomIntelligence/Janus-4o-7B"
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
)
vl_gpt = vl_gpt.cuda().eval()

# === Text-and-Image-to-Image generation ===
def text_and_image_to_image_generate(input_prompt, input_image_path, output_path, vl_chat_processor, vl_gpt,
                                     temperature=1.0, parallel_size=2, cfg_weight=5, cfg_weight2=5):
    """Generates an image from a text prompt and an input image."""
    torch.cuda.empty_cache()
    input_img_tokens = (
        vl_chat_processor.image_start_tag
        + vl_chat_processor.image_tag * vl_chat_processor.num_image_tokens
        + vl_chat_processor.image_end_tag
        + vl_chat_processor.image_start_tag
        + vl_chat_processor.pad_tag * vl_chat_processor.num_image_tokens
        + vl_chat_processor.image_end_tag
    )
    output_img_tokens = vl_chat_processor.image_start_tag
    pre_data = []
    input_images = [input_image_path]
    img_len = len(input_images)
    prompts = input_img_tokens * img_len + input_prompt
    conversation = [
        {"role": "<|User|>", "content": prompts},
        {"role": "<|Assistant|>", "content": ""}
    ]
    sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=vl_chat_processor.sft_format,
        system_prompt="",
    )
    sft_format = sft_format + output_img_tokens
    image_token_num_per_image = 576
    img_size = 384
    patch_size = 16
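
    # The reference image is used in two forms: pixel values for the understanding
    # encoder (via prepare_inputs_embeds) and discrete VQ codes from gen_vision_model,
    # whose embeddings are spliced into the generation slot of the prompt below.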
    with torch.inference_mode():
        input_image_pixel_values = process_image(input_images, vl_chat_processor).to(torch.bfloat16).cuda()
        _, _, info_input = vl_gpt.gen_vision_model.encode(input_image_pixel_values)
        image_tokens_input = info_input[2].detach().reshape(input_image_pixel_values.shape[0], -1)
        image_embeds_input = vl_gpt.prepare_gen_img_embeds(image_tokens_input)
        input_ids = torch.LongTensor(vl_chat_processor.tokenizer.encode(sft_format))
        encoder_pixel_values = process_image(input_images, vl_chat_processor).cuda()
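        # Three prompt rows per sample for two-level CFG: rows 0 and 1 carry the
        # image-conditioned prompt (row 0 additionally gets the VQ image embeddings
        # injected below), row 2 is the unconditional row with its prompt padded out.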
        tokens = torch.zeros((parallel_size * 3, len(input_ids)), dtype=torch.long)
        for i in range(parallel_size * 3):
            tokens[i, :] = input_ids
            if i % 3 == 2:
                tokens[i, 1:-1] = vl_chat_processor.pad_id
                pre_data.append(VLChatProcessorOutput(sft_format=sft_format, pixel_values=encoder_pixel_values, input_ids=tokens[i - 2], num_image_tokens=[vl_chat_processor.num_image_tokens] * img_len))
                pre_data.append(VLChatProcessorOutput(sft_format=sft_format, pixel_values=encoder_pixel_values, input_ids=tokens[i - 1], num_image_tokens=[vl_chat_processor.num_image_tokens] * img_len))
                pre_data.append(VLChatProcessorOutput(sft_format=sft_format, pixel_values=None, input_ids=tokens[i], num_image_tokens=[]))
        prepare_inputs = vl_chat_processor.batchify(pre_data)
        inputs_embeds = vl_gpt.prepare_inputs_embeds(
            input_ids=tokens.cuda(),
            pixel_values=prepare_inputs['pixel_values'].to(torch.bfloat16).cuda(),
            images_emb_mask=prepare_inputs['images_emb_mask'].cuda(),
            images_seq_mask=prepare_inputs['images_seq_mask'].cuda()
        )
        image_gen_indices = (tokens == vl_chat_processor.image_end_id).nonzero()
        for ii, ind in enumerate(image_gen_indices):
            if ii % 4 == 0:
                offset = ind[1] + 2
                inputs_embeds[ind[0], offset: offset + image_embeds_input.shape[1], :] = image_embeds_input[(ii // 2) % img_len]
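        # Autoregressively sample the 576 image tokens. After the first step only the
        # newly sampled token embedding is fed in; past_key_values carries the KV cache.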
        generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
        # --- FIX: Initialize past_key_values for cached generation ---
        past_key_values = None
        for i in range(image_token_num_per_image):
            outputs = vl_gpt.language_model.model(
                inputs_embeds=inputs_embeds,
                use_cache=True,
                past_key_values=past_key_values  # Pass cached values
            )
            hidden_states = outputs.last_hidden_state
            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
            logit_cond_full = logits[0::3, :]
            logit_cond_part = logits[1::3, :]
            logit_uncond = logits[2::3, :]
            logit_cond = (logit_cond_full + cfg_weight2 * logit_cond_part) / (1 + cfg_weight2)
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)
            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)
            # --- FIX: Update past_key_values with the output from the current step ---
            past_key_values = outputs.past_key_values
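        # Decode the sampled VQ token ids back to pixel space with the generation decoder.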
        dec = vl_gpt.gen_vision_model.decode_code(
            generated_tokens.to(dtype=torch.int),
            shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size]
        )
        dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
        dec = np.clip((dec + 1) / 2 * 255, 0, 255)
        visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
        visual_img[:, :, :] = dec
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        output_images = []
        for i in range(parallel_size):
            save_path = output_path.replace('.png', f'_{i}.png')
            Image.fromarray(visual_img[i]).save(save_path)
            output_images.append(save_path)
    torch.cuda.empty_cache()
    return output_images

# === Text-to-Image generation ===
def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt,
                           temperature=1.0, parallel_size=2, cfg_weight=5.0):
    """Generates an image from a text prompt only."""
    torch.cuda.empty_cache()
    conversation = [
        {"role": "<|User|>", "content": input_prompt},
        {"role": "<|Assistant|>", "content": ""},
    ]
    sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=vl_chat_processor.sft_format,
        system_prompt="",
    )
    prompt = sft_format + vl_chat_processor.image_start_tag
    image_token_num_per_image = 576
    img_size = 384
    patch_size = 16
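
    # Standard classifier-free guidance: two prompt rows per sample, one conditional
    # and one unconditional whose prompt tokens are replaced with pad_id.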
    with torch.inference_mode():
        input_ids = vl_chat_processor.tokenizer.encode(prompt)
        input_ids = torch.LongTensor(input_ids)
        tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).cuda()
        for i in range(parallel_size * 2):
            tokens[i, :] = input_ids
            if i % 2 != 0:
                tokens[i, 1:-1] = vl_chat_processor.pad_id
        inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
        generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
        # --- FIX: Initialize past_key_values for cached generation ---
        past_key_values = None
        for i in range(image_token_num_per_image):
            outputs = vl_gpt.language_model.model(
                inputs_embeds=inputs_embeds,
                use_cache=True,
                past_key_values=past_key_values  # Pass cached values
            )
            hidden_states = outputs.last_hidden_state
            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)
            next_token_expanded = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token_expanded)
            inputs_embeds = img_embeds.unsqueeze(dim=1)
            # --- FIX: Update past_key_values with the output from the current step ---
            past_key_values = outputs.past_key_values
        dec = vl_gpt.gen_vision_model.decode_code(
            generated_tokens.to(dtype=torch.int),
            shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size]
        )
        dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
        dec = np.clip((dec + 1) / 2 * 255, 0, 255)
        visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
        visual_img[:, :, :] = dec
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        output_images = []
        for i in range(parallel_size):
            save_path = output_path.replace('.png', f'_{i}.png')
            Image.fromarray(visual_img[i]).save(save_path)
            output_images.append(save_path)
    torch.cuda.empty_cache()
    return output_images

# === Unified Gradio handler for ChatInterface ===
@spaces.GPU  # Request a GPU for each call when running on ZeroGPU hardware
def janus_chat_responder(message, history):
| """ | |
| Handles both text-only and multimodal (text+image) inputs from the ChatInterface. | |
| 'message' is a dictionary with 'text' and 'files' keys. | |
| """ | |
| output_path = "./output/chat_image.png" | |
| prompt = message["text"] | |
| uploaded_files = message["files"] | |
| try: | |
| if uploaded_files: | |
| # Handle text+image to image generation | |
| temp_image_path = uploaded_files[0] | |
| images = text_and_image_to_image_generate( | |
| prompt, temp_image_path, output_path, vl_chat_processor, vl_gpt | |
| ) | |
| else: | |
| # Handle text-to-image generation | |
| images = text_to_image_generate(prompt, output_path, vl_chat_processor, vl_gpt) | |
| # Return a gallery component to display all generated images | |
| return gr.Gallery(value=images, label="Generated Images") | |
    except Exception as e:
        # Raise gr.Error so a user-friendly message is actually shown in the UI
        raise gr.Error(f"An error occurred during generation: {str(e)}")

# === Gradio UI with a single ChatInterface ===
with gr.Blocks(theme="soft", title="Janus Image Generation") as demo:
    gr.Markdown("# Janus Multi-Modal Image Generation")
    gr.Markdown("Generate images from text prompts, or upload an image and a prompt to transform it.")
    # gr.ChatInterface handles the chat history and input box automatically
    gr.ChatInterface(
        fn=janus_chat_responder,
        multimodal=True,  # Enables file uploads
        title="Janus-4o-7B",
        chatbot=gr.Chatbot(height=400, label="Chat", show_label=False),
        textbox=gr.MultimodalTextbox(
            file_types=["image"],
            placeholder="Type a prompt or upload an image...",
            label="Input"
        ),
        examples=[
            {"text": "A cat made of glass, sitting on a table.", "files": []},
            {"text": "A futuristic city at sunset, with flying cars.", "files": []},
            {"text": "A dragon breathing fire over a medieval castle.", "files": []},
            {"text": "Turn this into a watercolor painting.", "files": ["./assets/example_image.jpg"]}
        ]
    )

if __name__ == "__main__":
    # Create a dummy image for the example if it doesn't exist, to prevent errors
    assets_dir = "./assets"
    example_image_path = os.path.join(assets_dir, "example_image.jpg")
    if not os.path.exists(example_image_path):
        os.makedirs(assets_dir, exist_ok=True)
        try:
            dummy_image = Image.new('RGB', (384, 384), color='red')
            dummy_image.save(example_image_path)
            print(f"Created dummy example image at: {example_image_path}")
        except Exception as e:
            print(f"Could not create dummy image: {e}")
    demo.launch()