sagar007's picture
Upload app.py with huggingface_hub
18b63c5 verified
raw
history blame
15.2 kB
#!/usr/bin/env python3
"""
Gradio UI for Multimodal Gemma Model - Hugging Face Space Version
Fixed: Added all missing modules (projectors.py, lightning_module.py, logging.py, data/, training/)
Updated requirements.txt with rich and datasets libraries
"""
import sys
import torch
import gradio as gr
from pathlib import Path
from PIL import Image
import io
import time
import logging
from huggingface_hub import hf_hub_download
# Model imports
from src.models import MultimodalGemmaLightning
from src.utils.config import load_config, merge_configs
# Module-level state shared by the Gradio callbacks in this file.
# `model` is None until download_and_load_model() assigns a loaded module;
# predict_with_image() and chat_with_image() check it before running.
model = None
# `config` is assigned by download_and_load_model() alongside `model`.
config = None
def _default_model_config():
    """Return the config used when the checkpoint carries no usable one."""
    return {
        "model": {
            "gemma_model_name": "microsoft/DialoGPT-medium",  # non-gated model
            "vision_model_name": "openai/clip-vit-large-patch14",
            "use_4bit": False,  # quantization disabled for loading
            "projector_hidden_dim": 2048,
            "lora": {"r": 16, "alpha": 32, "dropout": 0.1},
        },
        "special_tokens": {"image_token": "<image>"},
        "training": {"projector_lr": 1e-3, "lora_lr": 1e-4},
    }


def _small_model_config(with_lora):
    """Return the reduced-size fallback config, optionally with LoRA settings."""
    model_section = {
        "gemma_model_name": "microsoft/DialoGPT-small",
        "vision_model_name": "openai/clip-vit-base-patch32",
        "use_4bit": False,
        "projector_hidden_dim": 512,
    }
    if with_lora:
        model_section["lora"] = {
            "r": 8,
            "alpha": 16,
            "dropout": 0.1,
            "target_modules": ["q_proj", "v_proj"],
        }
    return {
        "model": model_section,
        "special_tokens": {"image_token": "<image>"},
        "training": {"projector_lr": 1e-3, "lora_lr": 1e-4},
    }


def _resolve_checkpoint_config(checkpoint):
    """Choose the config for the first load attempt from the checkpoint contents."""
    saved_config = None
    if "hyper_parameters" in checkpoint:
        saved_config = checkpoint["hyper_parameters"].get("config", {})

    # An absent *or empty* saved config is unusable; fall back to the default
    # instead of handing an empty dict to the model constructor.
    if not saved_config:
        print("No saved config found, creating minimal config")
        return _default_model_config()

    print("Found saved config in checkpoint")
    # Override any gated models in the saved config.
    if "model" in saved_config and "gemma_model_name" in saved_config["model"]:
        if "google/gemma" in saved_config["model"]["gemma_model_name"]:
            print("Replacing gated Gemma model with accessible alternative")
            saved_config["model"]["gemma_model_name"] = "microsoft/DialoGPT-medium"
            saved_config["model"]["use_4bit"] = False  # for compatibility
    return saved_config


def _load_compatible_weights(checkpoint, fallback_config):
    """Build a fresh model and copy in only the name- and shape-matching weights."""
    candidate = MultimodalGemmaLightning(fallback_config)
    target_state = candidate.state_dict()
    compatible = {
        key: tensor
        for key, tensor in checkpoint["state_dict"].items()
        if key in target_state and target_state[key].shape == tensor.shape
    }
    candidate.load_state_dict(compatible, strict=False)
    return candidate, len(compatible)


def download_and_load_model():
    """Download the checkpoint from the Hub and load it into the global ``model``.

    Three strategies are tried in order: the config stored in the checkpoint
    (or a default one), a reduced-size config, and finally a freshly built
    model that receives only the weights whose names and shapes match.

    The globals ``model`` and ``config`` are assigned only after a strategy
    has succeeded *and* the model has been switched to eval mode and moved to
    the target device, so a failed attempt never leaves a half-initialised
    model behind that would later be reported as "already loaded".

    Returns:
        A human-readable status string (success or error); exceptions are
        reported through this string rather than raised.
    """
    global model, config

    if model is not None:
        return "βœ… Model already loaded!"

    try:
        print("πŸ”„ Downloading multimodal Gemma model from HF...")
        checkpoint_path = hf_hub_download(
            repo_id="sagar007/multimodal-gemma-270m-llava",
            filename="final_model.ckpt",
            cache_dir="./model_cache",
        )

        print("πŸ“ Loading model from checkpoint...")
        # NOTE(security): torch.load unpickles the file. That is tolerable here
        # only because the checkpoint comes from a fixed repository; do not
        # point this at user-supplied files.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        print(f"Checkpoint keys: {list(checkpoint.keys())}")

        device = "cuda" if torch.cuda.is_available() else "cpu"

        # NOTE(review): the language model named in the config may differ from
        # the one the checkpoint was trained with (gated Gemma is swapped out),
        # so presumably not every checkpoint weight applies -- confirm.
        attempts = (
            (
                "checkpoint config",
                _resolve_checkpoint_config(checkpoint),
                "βœ… Loaded with checkpoint config",
            ),
            (
                "minimal config",
                _small_model_config(with_lora=True),
                "βœ… Loaded with minimal config",
            ),
        )

        loaded = None
        used_config = None
        for label, candidate_config, success_message in attempts:
            try:
                loaded = MultimodalGemmaLightning.load_from_checkpoint(
                    checkpoint_path,
                    config=candidate_config,
                    strict=False,
                    map_location=device,
                )
            except Exception as attempt_error:
                print(f"Failed with {label}: {attempt_error}")
                continue
            print(success_message)
            used_config = candidate_config
            break

        if loaded is None:
            # Last resort: direct state-dict loading of compatible weights.
            try:
                print("Attempting direct state dict loading...")
                used_config = _small_model_config(with_lora=False)
                loaded, loaded_count = _load_compatible_weights(
                    checkpoint, used_config
                )
                print(f"βœ… Loaded {loaded_count} compatible weights")
            except Exception as e3:
                print(f"All loading methods failed: {e3}")
                return f"❌ Model loading failed - checkpoint incompatible. Last error: {str(e3)}"

        loaded.eval()
        loaded = loaded.to(device)

        # Publish only a fully prepared model.
        model = loaded
        config = used_config

        status = f"βœ… Model loaded successfully on {device}!"
        print(status)
        return status

    except Exception as e:
        error_msg = f"❌ Error loading model: {str(e)}"
        print(error_msg)
        return error_msg
def predict_with_image(image, question, max_tokens=100, temperature=0.7):
    """Generate the model's answer to a question about an image.

    Args:
        image: A PIL image, a file-path string, or an array accepted by
            ``Image.fromarray``. Every form is normalised to RGB.
        question: The user's question. ``None`` or a blank string falls back
            to a generic "What do you see in this image?" prompt.
        max_tokens: Requested number of new tokens; capped at 150.
        temperature: Sampling temperature, clamped to [0.1, 2.0]. Sampling is
            enabled only when the requested value is above 0.1.

    Returns:
        The decoded response text, or an error message string when the model
        is not loaded, no image was given, or inference raised.
    """
    global model

    if model is None:
        return "❌ Please load the model first using the 'Load Model' button!"

    if image is None:
        return "❌ Please upload an image!"

    # Treat None the same as a blank question instead of crashing on .strip().
    if not question or not question.strip():
        question = "What do you see in this image?"

    try:
        device = next(model.parameters()).device

        # Normalise every accepted input form to an RGB PIL image.
        if isinstance(image, str):
            # Context manager closes the file; convert() returns a loaded copy.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        elif isinstance(image, Image.Image):
            image = image.convert("RGB")
        else:
            image = Image.fromarray(image).convert("RGB")

        vision_inputs = model.model.vision_processor(
            images=[image],
            return_tensors="pt",
        )
        pixel_values = vision_inputs["pixel_values"].to(device)

        # Real newlines separate the turns; the previous "\\n" escape sent a
        # literal backslash followed by "n" to the tokenizer.
        prompt = f"<image>\nHuman: {question}\nAssistant:"

        text_inputs = model.model.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
        )
        input_ids = text_inputs["input_ids"].to(device)
        attention_mask = text_inputs["attention_mask"].to(device)

        with torch.no_grad():
            # Use the full multimodal model with image inputs.
            outputs = model.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                images=pixel_values,
                # int() because UI sliders may deliver floats.
                max_new_tokens=min(int(max_tokens), 150),
                temperature=min(max(temperature, 0.1), 2.0),
                do_sample=temperature > 0.1,
                repetition_penalty=1.1,
            )

        # Drop the prompt tokens before decoding.
        # NOTE(review): assumes generate() returns prompt + continuation --
        # confirm against the model's generate implementation.
        input_length = input_ids.shape[1]
        generated_tokens = outputs[0][input_length:]
        response = model.model.tokenizer.decode(
            generated_tokens, skip_special_tokens=True
        )

        response = response.strip()
        if not response:
            response = "I can see the image, but I'm having trouble generating a detailed response."

        return response

    except Exception as e:
        error_msg = f"❌ Error during inference: {str(e)}"
        print(error_msg)
        return error_msg
def chat_with_image(image, question, history, max_tokens, temperature):
    """Answer one chat turn and return the extended history.

    Args:
        image: Image forwarded to ``predict_with_image``.
        question: The user's message for this turn (``None`` is shown as "").
        history: Previous messages as a list of ``{"role", "content"}`` dicts;
            ``None`` is treated as an empty conversation.
        max_tokens: Forwarded to ``predict_with_image``.
        temperature: Forwarded to ``predict_with_image``.

    Returns:
        A tuple ``(history, "")``: a new list holding the prior messages plus
        the user and assistant messages of this turn, and an empty string
        used to clear the question textbox.
    """
    if model is None:
        response = "❌ Please load the model first!"
    else:
        response = predict_with_image(image, question, max_tokens, temperature)

    # Work on a copy so the caller's list is never mutated, and tolerate a
    # missing history.
    updated_history = list(history) if history else []
    user_text = question if question is not None else ""

    # Add to history - using messages format.
    updated_history.append({"role": "user", "content": user_text})
    updated_history.append({"role": "assistant", "content": response})

    return updated_history, ""
def create_gradio_interface():
    """Build the Gradio Blocks UI and wire its event handlers.

    Layout, top to bottom: header, model info with the "Load Model" button
    and status box, the main row (image upload and generation settings on the
    left, chat on the right), example-question shortcut buttons laid out
    three per row, and a footer.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    # Static CSS/HTML is kept up front so the layout tree below stays readable.
    css = """
.container {
max-width: 1200px;
margin: auto;
padding: 20px;
}
.header {
text-align: center;
margin-bottom: 30px;
}
.model-info {
background-color: #f0f8ff;
padding: 15px;
border-radius: 10px;
margin-bottom: 20px;
}
"""
    header_html = """
<div class="header">
<h1>πŸŽ‰ Multimodal Gemma-270M Chat</h1>
<p>Upload an image and chat with your trained vision-language model!</p>
<p><a href="https://huggingface.co/sagar007/multimodal-gemma-270m-llava">πŸ€— Model</a></p>
</div>
"""
    model_info_html = """
<div class="model-info">
<h3>πŸ“Š Model Info</h3>
<ul>
<li><strong>Base Model:</strong> Google Gemma-270M</li>
<li><strong>Vision:</strong> CLIP ViT-Large</li>
<li><strong>Training:</strong> LLaVA-150K + COCO Images</li>
<li><strong>Parameters:</strong> 18.6M trainable / 539M total</li>
</ul>
</div>
"""
    footer_html = """
<hr>
<div style="text-align: center; margin-top: 20px;">
<p><strong>🎯 Your Multimodal Gemma Model</strong></p>
<p>Text-only β†’ Vision-Language Model using LLaVA Architecture</p>
<p>Model: <a href="https://huggingface.co/sagar007/multimodal-gemma-270m-llava">sagar007/multimodal-gemma-270m-llava</a></p>
</div>
"""
    example_questions = [
        "What do you see in this image?",
        "Describe the main objects in the picture.",
        "What colors are prominent in this image?",
        "Are there any people in the image?",
        "What's the setting or location?",
        "What objects are in the foreground?",
    ]
    buttons_per_row = 3

    with gr.Blocks(css=css, title="Multimodal Gemma Chat") as demo:
        gr.HTML(header_html)

        # Model status section
        with gr.Row():
            with gr.Column():
                gr.HTML(model_info_html)

                # Model loading
                load_btn = gr.Button("πŸš€ Load Model", variant="primary", size="lg")
                model_status = gr.Textbox(
                    label="Model Status",
                    value="Click 'Load Model' to start",
                    interactive=False,
                )

        gr.HTML("<hr>")

        # Main interface
        with gr.Row():
            # Left column - image and controls
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="πŸ“Έ Upload Image",
                    type="pil",
                    height=300,
                )

                gr.HTML("<p><strong>πŸ’‘ Tip:</strong> Upload any image and ask questions about it</p>")

                # Generation settings
                with gr.Accordion("βš™οΈ Generation Settings", open=False):
                    # NOTE: predict_with_image caps generation at 150 new
                    # tokens, so values above 150 have no additional effect.
                    max_tokens = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=100,
                        step=10,
                        label="Max Tokens",
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature",
                    )

            # Right column - chat interface
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="πŸ’¬ Chat with Image",
                    height=400,
                    show_label=True,
                    type="messages",
                )
                question_input = gr.Textbox(
                    label="❓ Ask about the image",
                    placeholder="What do you see in this image?",
                    lines=2,
                )
                with gr.Row():
                    submit_btn = gr.Button("πŸ’¬ Send", variant="primary")
                    clear_btn = gr.Button("πŸ—‘οΈ Clear Chat")

        # Example prompts
        with gr.Row():
            gr.HTML("<h3>πŸ’‘ Example Questions:</h3>")

        # One Row per group of three buttons. The earlier loop opened an empty
        # Row every third item and created the buttons outside of it.
        for start in range(0, len(example_questions), buttons_per_row):
            with gr.Row():
                for example in example_questions[start:start + buttons_per_row]:
                    # Default argument binds the current text to this button.
                    gr.Button(example, size="sm").click(
                        lambda x=example: x,
                        outputs=question_input,
                    )

        # Footer
        gr.HTML(footer_html)

        # Event handlers
        load_btn.click(
            fn=download_and_load_model,
            outputs=model_status,
        )

        # The Send button and pressing Enter in the textbox share one handler.
        chat_inputs = [image_input, question_input, chatbot, max_tokens, temperature]
        chat_outputs = [chatbot, question_input]
        submit_btn.click(
            fn=chat_with_image,
            inputs=chat_inputs,
            outputs=chat_outputs,
        )
        question_input.submit(
            fn=chat_with_image,
            inputs=chat_inputs,
            outputs=chat_outputs,
        )

        # Reset both the conversation and the pending question.
        clear_btn.click(
            fn=lambda: ([], ""),
            outputs=[chatbot, question_input],
        )

    return demo
def main():
    """Entry point: build the Gradio Blocks app and start serving it."""
    print("πŸš€ Starting Multimodal Gemma Gradio Space...")
    # Assemble the UI first so construction errors surface before launch.
    app = create_gradio_interface()

    print("🌐 Launching Gradio interface...")
    # Default launch settings.
    app.launch()
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()