#!/usr/bin/env python3
"""
Gradio UI for Multimodal Gemma Model - Hugging Face Space Version
Fixed: Added all missing modules (projectors.py, lightning_module.py, logging.py, data/, training/)
Updated requirements.txt with rich and datasets libraries
"""
import sys
import torch
import gradio as gr
from pathlib import Path
from PIL import Image
import io
import time
import logging

from huggingface_hub import hf_hub_download

# Model imports
from src.models import MultimodalGemmaLightning
from src.utils.config import load_config, merge_configs

# Global model variables
model = None
config = None

def download_and_load_model():
    """Download and load the trained multimodal model from HF"""
    global model, config

    if model is not None:
        return "✅ Model already loaded!"

    try:
        print("Downloading multimodal Gemma model from HF...")

        # Download model checkpoint
        checkpoint_path = hf_hub_download(
            repo_id="sagar007/multimodal-gemma-270m-llava",
            filename="final_model.ckpt",
            cache_dir="./model_cache"
        )

        print("Loading model from checkpoint...")

        # Load checkpoint data to inspect what's inside
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        print(f"Checkpoint keys: {list(checkpoint.keys())}")
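        # A Lightning checkpoint normally exposes "state_dict" plus "hyper_parameters"
        # (the kwargs the LightningModule was constructed with); both are used below.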
        # Extract the saved hyperparameters if they exist
        if "hyper_parameters" in checkpoint:
            saved_config = checkpoint["hyper_parameters"].get("config", {})
            print("Found saved config in checkpoint")

            # Override any gated models in the saved config
            if "model" in saved_config and "gemma_model_name" in saved_config["model"]:
                if "google/gemma" in saved_config["model"]["gemma_model_name"]:
                    print("Replacing gated Gemma model with accessible alternative")
                    saved_config["model"]["gemma_model_name"] = "microsoft/DialoGPT-medium"
                    saved_config["model"]["use_4bit"] = False  # Disable quantization for compatibility

            config = saved_config
        else:
            print("No saved config found, creating minimal config")
            # Create minimal config for loading
            config = {
                "model": {
                    "gemma_model_name": "microsoft/DialoGPT-medium",  # Use non-gated model
                    "vision_model_name": "openai/clip-vit-large-patch14",
                    "use_4bit": False,  # Disable quantization for loading
                    "projector_hidden_dim": 2048,
                    "lora": {"r": 16, "alpha": 32, "dropout": 0.1}
                },
                "special_tokens": {"image_token": "<image>"},
                "training": {"projector_lr": 1e-3, "lora_lr": 1e-4}
            }
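        # Fall back through progressively smaller configurations until one of them can
        # load the checkpoint.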
        try:
            # First try: Use the checkpoint's config if available
            model = MultimodalGemmaLightning.load_from_checkpoint(
                checkpoint_path,
                config=config,
                strict=False,
                map_location="cuda" if torch.cuda.is_available() else "cpu"
            )
            print("✅ Loaded with checkpoint config")
        except Exception as e1:
            print(f"Failed with checkpoint config: {e1}")
            try:
                # Second try: Minimal config with no quantization
                minimal_config = {
                    "model": {
                        "gemma_model_name": "microsoft/DialoGPT-small",  # Even smaller model
                        "vision_model_name": "openai/clip-vit-base-patch32",  # Smaller CLIP
                        "use_4bit": False,  # No quantization
                        "projector_hidden_dim": 512,
                        "lora": {"r": 8, "alpha": 16, "dropout": 0.1, "target_modules": ["q_proj", "v_proj"]}
                    },
                    "special_tokens": {"image_token": "<image>"},
                    "training": {"projector_lr": 1e-3, "lora_lr": 1e-4}
                }
                model = MultimodalGemmaLightning.load_from_checkpoint(
                    checkpoint_path,
                    config=minimal_config,
                    strict=False,
                    map_location="cuda" if torch.cuda.is_available() else "cpu"
                )
                print("✅ Loaded with minimal config")
            except Exception as e2:
                print(f"Failed with minimal config: {e2}")
                try:
                    # Third try: Direct state dict loading
                    print("Attempting direct state dict loading...")

                    # Create a dummy model just to get the structure
                    dummy_config = {
                        "model": {
                            "gemma_model_name": "microsoft/DialoGPT-small",
                            "vision_model_name": "openai/clip-vit-base-patch32",
                            "use_4bit": False,
                            "projector_hidden_dim": 512,
                        },
                        "special_tokens": {"image_token": "<image>"},
                        "training": {"projector_lr": 1e-3, "lora_lr": 1e-4}
                    }
                    model = MultimodalGemmaLightning(dummy_config)

                    # Load only compatible weights
                    checkpoint_state = checkpoint['state_dict']
                    model_state = model.state_dict()

                    # Filter and load compatible weights
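                    # Keys whose shapes no longer match (for example, embeddings that were
                    # resized when special tokens such as <image> were added) are skipped so
                    # that load_state_dict never sees a mismatched tensor.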
                    compatible_weights = {}
                    for key, value in checkpoint_state.items():
                        if key in model_state and model_state[key].shape == value.shape:
                            compatible_weights[key] = value

                    model.load_state_dict(compatible_weights, strict=False)
                    print(f"✅ Loaded {len(compatible_weights)} compatible weights")
                except Exception as e3:
                    print(f"All loading methods failed: {e3}")
                    return f"❌ Model loading failed - checkpoint incompatible. Last error: {str(e3)}"
        model.eval()

        # Move to appropriate device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)

        print(f"✅ Model loaded successfully on {device}!")
        return f"✅ Model loaded successfully on {device}!"

    except Exception as e:
        error_msg = f"❌ Error loading model: {str(e)}"
        print(error_msg)
        return error_msg

def predict_with_image(image, question, max_tokens=100, temperature=0.7):
    """Generate response for image + text input"""
    global model, config

    if model is None:
        return "❌ Please load the model first using the 'Load Model' button!"

    if image is None:
        return "❌ Please upload an image!"

    if not question.strip():
        question = "What do you see in this image?"

    try:
        # Get device
        device = next(model.parameters()).device

        # Process image
        if isinstance(image, str):
            image = Image.open(image).convert('RGB')
        elif not isinstance(image, Image.Image):
            image = Image.fromarray(image).convert('RGB')

        # Prepare image for model
        vision_inputs = model.model.vision_processor(
            images=[image],
            return_tensors="pt"
        )
        pixel_values = vision_inputs["pixel_values"].to(device)
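        # For a CLIP ViT processor, pixel_values is typically shaped (batch, 3, 224, 224).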
        # Prepare text prompt
        prompt = f"<image>\nHuman: {question}\nAssistant:"
        # Tokenize text
        text_inputs = model.model.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256
        )
        input_ids = text_inputs["input_ids"].to(device)
        attention_mask = text_inputs["attention_mask"].to(device)

        # Generate response
        with torch.no_grad():
            # Use the full multimodal model with image inputs
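            # With temperature <= 0.1, do_sample is False and decoding is effectively greedy
            # (temperature is then ignored), assuming HF-style generate() semantics.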
            outputs = model.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                images=pixel_values,
                max_new_tokens=min(max_tokens, 150),
                temperature=min(max(temperature, 0.1), 2.0),
                do_sample=temperature > 0.1,
                repetition_penalty=1.1
            )
        # Decode response
        input_length = input_ids.shape[1]
        generated_tokens = outputs[0][input_length:]
        response = model.model.tokenizer.decode(generated_tokens, skip_special_tokens=True)

        # Clean up response
        response = response.strip()
        if not response:
            response = "I can see the image, but I'm having trouble generating a detailed response."

        return response

    except Exception as e:
        error_msg = f"❌ Error during inference: {str(e)}"
        print(error_msg)
        return error_msg

def chat_with_image(image, question, history, max_tokens, temperature):
    """Chat interface function"""
    if model is None:
        response = "❌ Please load the model first!"
    else:
        response = predict_with_image(image, question, max_tokens, temperature)

    # Add to history - using messages format
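    # (Gradio's type="messages" Chatbot expects a list of {"role": ..., "content": ...} dicts.)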
    history.append({"role": "user", "content": question})
    history.append({"role": "assistant", "content": response})

    return history, ""

def create_gradio_interface():
    """Create the Gradio interface"""

    # Custom CSS for better styling
    css = """
    .container {
        max-width: 1200px;
        margin: auto;
        padding: 20px;
    }
    .header {
        text-align: center;
        margin-bottom: 30px;
    }
    .model-info {
        background-color: #f0f8ff;
        padding: 15px;
        border-radius: 10px;
        margin-bottom: 20px;
    }
    """
    with gr.Blocks(css=css, title="Multimodal Gemma Chat") as demo:
        gr.HTML("""
        <div class="header">
            <h1>Multimodal Gemma-270M Chat</h1>
            <p>Upload an image and chat with your trained vision-language model!</p>
            <p><a href="https://huggingface.co/sagar007/multimodal-gemma-270m-llava">🤗 Model</a></p>
        </div>
        """)
        # Model status section
        with gr.Row():
            with gr.Column():
                gr.HTML("""
                <div class="model-info">
                    <h3>Model Info</h3>
                    <ul>
                        <li><strong>Base Model:</strong> Google Gemma-270M</li>
                        <li><strong>Vision:</strong> CLIP ViT-Large</li>
                        <li><strong>Training:</strong> LLaVA-150K + COCO Images</li>
                        <li><strong>Parameters:</strong> 18.6M trainable / 539M total</li>
                    </ul>
                </div>
                """)

                # Model loading
                load_btn = gr.Button("Load Model", variant="primary", size="lg")
                model_status = gr.Textbox(
                    label="Model Status",
                    value="Click 'Load Model' to start",
                    interactive=False
                )
        gr.HTML("<hr>")

        # Main interface
        with gr.Row():
            # Left column - Image and controls
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Upload Image",
                    type="pil",
                    height=300
                )

                # Example images
                gr.HTML("<p><strong>Tip:</strong> Upload any image and ask questions about it</p>")

                # Generation settings
                with gr.Accordion("Generation Settings", open=False):
                    max_tokens = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=100,
                        step=10,
                        label="Max Tokens"
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature"
                    )
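                    # Note: predict_with_image further caps max_new_tokens at 150, so slider
                    # values above 150 have no additional effect.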
            # Right column - Chat interface
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Chat with Image",
                    height=400,
                    show_label=True,
                    type="messages"
                )

                question_input = gr.Textbox(
                    label="Ask about the image",
                    placeholder="What do you see in this image?",
                    lines=2
                )

                with gr.Row():
                    submit_btn = gr.Button("Send", variant="primary")
                    clear_btn = gr.Button("Clear Chat")
        # Example prompts
        with gr.Row():
            gr.HTML("<h3>Example Questions:</h3>")

        example_questions = [
            "What do you see in this image?",
            "Describe the main objects in the picture.",
            "What colors are prominent in this image?",
            "Are there any people in the image?",
            "What's the setting or location?",
            "What objects are in the foreground?"
        ]

        # Lay the example questions out three to a row; clicking a button copies its text
        # into the question box (the default argument binds each button to its own question).
        for i in range(0, len(example_questions), 3):
            with gr.Row():
                for question in example_questions[i:i + 3]:
                    gr.Button(question, size="sm").click(
                        lambda q=question: q,
                        outputs=question_input
                    )
        # Footer
        gr.HTML("""
        <hr>
        <div style="text-align: center; margin-top: 20px;">
            <p><strong>Your Multimodal Gemma Model</strong></p>
            <p>Text-only → Vision-Language Model using LLaVA Architecture</p>
            <p>Model: <a href="https://huggingface.co/sagar007/multimodal-gemma-270m-llava">sagar007/multimodal-gemma-270m-llava</a></p>
        </div>
        """)
        # Event handlers
        load_btn.click(
            fn=download_and_load_model,
            outputs=model_status
        )

        submit_btn.click(
            fn=chat_with_image,
            inputs=[image_input, question_input, chatbot, max_tokens, temperature],
            outputs=[chatbot, question_input]
        )

        question_input.submit(
            fn=chat_with_image,
            inputs=[image_input, question_input, chatbot, max_tokens, temperature],
            outputs=[chatbot, question_input]
        )

        clear_btn.click(
            fn=lambda: ([], ""),
            outputs=[chatbot, question_input]
        )

    return demo

def main():
    """Main function to launch the Gradio app"""
    print("Starting Multimodal Gemma Gradio Space...")

    # Create interface
    demo = create_gradio_interface()

    # Launch
    print("Launching Gradio interface...")
    demo.launch()


if __name__ == "__main__":
    main()