from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import os
import tempfile
import torch
import gradio as gr
import traceback
import sys
import logging
from PIL import Image
from models.llava import LLaVA
from typing import Dict, Any, Optional, Union

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('app.log')
    ]
)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="LLaVA Web Interface")

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global state
model = None
model_status: Dict[str, Any] = {
    "initialized": False,
    "device": None,
    "error": None,
    "last_error": None
}
# Register the handler so unhandled errors are logged and returned as JSON
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Global exception handler to catch and log all unhandled exceptions."""
    error_msg = f"Unhandled error: {str(exc)}\n{traceback.format_exc()}"
    logger.error(error_msg)
    model_status["last_error"] = error_msg
    return JSONResponse(
        status_code=500,
        content={"error": "Internal Server Error", "details": str(exc)}
    )

# Expose the status check as a GET endpoint (path chosen here; adjust as needed)
@app.get("/status")
async def get_status():
    """Endpoint to check model and application status."""
    return {
        "model_initialized": model is not None,
        "model_status": model_status,
        "memory_usage": {
            "cuda_available": torch.cuda.is_available(),
            "cuda_memory_allocated": torch.cuda.memory_allocated() if torch.cuda.is_available() else 0,
            "cuda_memory_reserved": torch.cuda.memory_reserved() if torch.cuda.is_available() else 0
        }
    }
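# As a quick sanity check (assuming the route registered above stays at /status
# and the server is running on port 7860, as configured at the bottom of this
# file), the status payload can be fetched with, for example:
#
#     curl http://localhost:7860/status
#
# which returns the model_status dict plus basic CUDA memory counters.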
def initialize_model():
    """Initialize the LLaVA model with proper error handling."""
    global model, model_status
    try:
        logger.info("Starting model initialization...")
        model_status["initialized"] = False
        model_status["error"] = None

        # Clear any existing model and memory
        if model is not None:
            del model
            torch.cuda.empty_cache()

        # Set device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")

        # Initialize new model with basic parameters
        model = LLaVA(
            vision_model_path="openai/clip-vit-base-patch32",
            language_model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            projection_hidden_dim=2048,
            device=device
        )

        # Configure the language model for inference
        if hasattr(model, 'language_model'):
            # Set model to evaluation mode
            model.language_model.eval()

            # Disable cached key/values in the model config
            if hasattr(model.language_model, 'config'):
                model.language_model.config.use_cache = False

            # Move model to device
            model = model.to(device)

            # Set generation config if available
            if hasattr(model.language_model, 'generation_config'):
                model.language_model.generation_config.do_sample = True
                model.language_model.generation_config.max_new_tokens = 256
                model.language_model.generation_config.temperature = 0.7
                model.language_model.generation_config.top_p = 0.9
                if hasattr(model.language_model.config, 'eos_token_id'):
                    model.language_model.generation_config.pad_token_id = model.language_model.config.eos_token_id

        model_status.update({
            "initialized": True,
            "device": str(model.device),
            "error": None,
            "model_info": {
                "vision_model": "openai/clip-vit-base-patch32",
                "language_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                "device": str(model.device)
            }
        })
        logger.info(f"Model successfully initialized on {model.device}")
        return True

    except Exception as e:
        error_msg = f"Model initialization failed: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        model = None
        model_status.update({
            "initialized": False,
            "error": error_msg,
            "last_error": traceback.format_exc()
        })
        return False
def process_image(
    image: Optional[Image.Image],
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    top_p: float = 0.9
) -> str:
    """Process an image with the LLaVA model with comprehensive error handling."""
    global model_status
    logger.info("Starting image processing...")

    # Validate model state
    if model is None:
        logger.error("Model not initialized")
        if not initialize_model():
            model_status["last_error"] = "Model initialization failed during processing"
            return "Error: Model initialization failed. Please try again later."

    # Validate inputs
    if image is None:
        logger.error("No image provided")
        return "Error: Please upload an image first."
    if not isinstance(image, Image.Image):
        logger.error(f"Invalid image type: {type(image)}")
        return "Error: Invalid image format. Please upload a valid image."
    if not prompt or not isinstance(prompt, str) or not prompt.strip():
        logger.error("Invalid prompt")
        return "Error: Please enter a valid prompt."

    # Validate parameters
    try:
        max_new_tokens = int(max_new_tokens)
        temperature = float(temperature)
        top_p = float(top_p)
    except (ValueError, TypeError) as e:
        logger.error(f"Invalid parameters: {str(e)}")
        return "Error: Invalid generation parameters."

    temp_path = None
    try:
        logger.info(f"Processing image with prompt: {prompt[:100]}...")

        # Save image with explicit format
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
            image.save(temp_file.name, format='PNG')
            temp_path = temp_file.name
        logger.info(f"Saved temporary image to {temp_path}")

        # Clear memory
        torch.cuda.empty_cache()

        # Process image with Hugging Face specific settings
        with torch.inference_mode():
            try:
                logger.info("Generating response...")

                # Update generation config if available
                if hasattr(model, 'language_model') and hasattr(model.language_model, 'generation_config'):
                    model.language_model.generation_config.max_new_tokens = max_new_tokens
                    model.language_model.generation_config.temperature = temperature
                    model.language_model.generation_config.top_p = top_p

                response = model.generate_from_image(
                    image_path=temp_path,
                    prompt=prompt,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    num_beams=1,
                    pad_token_id=model.language_model.config.eos_token_id if hasattr(model, 'language_model') else None
                )

                if not response:
                    raise ValueError("Empty response from model")
                if not isinstance(response, str):
                    raise ValueError(f"Invalid response type: {type(response)}")

                logger.info("Successfully generated response")
                model_status["last_error"] = None
                return response

            except Exception as model_error:
                error_msg = f"Model inference error: {str(model_error)}"
                logger.error(error_msg)
                logger.error(traceback.format_exc())
                model_status["last_error"] = error_msg
                return f"Error during model inference: {str(model_error)}"

    except Exception as e:
        error_msg = f"Processing error: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        model_status["last_error"] = error_msg
        return f"Error processing image: {str(e)}"

    finally:
        # Cleanup
        if temp_path and os.path.exists(temp_path):
            try:
                os.unlink(temp_path)
                logger.info("Cleaned up temporary file")
            except Exception as e:
                logger.warning(f"Failed to clean up temporary file: {str(e)}")
        try:
            torch.cuda.empty_cache()
        except Exception as e:
            logger.warning(f"Failed to clear CUDA cache: {str(e)}")
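# For a quick smoke test outside the UI, process_image can be called directly.
# A minimal sketch ("test.png" and the prompt are placeholders; the model
# weights are downloaded on the first initialize_model() call):
#
#     if initialize_model():
#         print(process_image(Image.open("test.png"),
#                             "What can you see in this image?"))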
def create_interface():
    """Create a simplified Gradio interface."""
    try:
        with gr.Blocks(title="LLaVA Chat", theme=gr.themes.Soft()) as demo:
            gr.Markdown("""
            # LLaVA Chat
            Upload an image and chat with LLaVA about it. This model can understand and describe images, answer questions about them, and engage in visual conversations.

            ## Example Prompts
            Try these prompts to get started:
            - "What can you see in this image?"
            - "Describe this scene in detail"
            - "What emotions does this image convey?"
            - "What's happening in this picture?"
            - "Can you identify any objects or people in this image?"
            """)

            with gr.Row():
                with gr.Column():
                    # Input components
                    image_input = gr.Image(type="pil", label="Upload Image")
                    prompt_input = gr.Textbox(
                        label="Ask about the image",
                        placeholder="What can you see in this image?",
                        lines=3
                    )

                    with gr.Accordion("Advanced Settings", open=False):
                        max_tokens = gr.Slider(
                            minimum=32,
                            maximum=512,
                            value=256,
                            step=32,
                            label="Max New Tokens"
                        )
                        temperature = gr.Slider(
                            minimum=0.1,
                            maximum=1.0,
                            value=0.7,
                            step=0.1,
                            label="Temperature"
                        )
                        top_p = gr.Slider(
                            minimum=0.1,
                            maximum=1.0,
                            value=0.9,
                            step=0.1,
                            label="Top P"
                        )

                    submit_btn = gr.Button("Generate Response", variant="primary")

                with gr.Column():
                    output = gr.Textbox(
                        label="Model Response",
                        lines=10,
                        show_copy_button=True
                    )

            # Set up event handler
            submit_btn.click(
                fn=process_image,
                inputs=[
                    image_input,
                    prompt_input,
                    max_tokens,
                    temperature,
                    top_p
                ],
                outputs=output
            )

        logger.info("Successfully created Gradio interface")
        return demo

    except Exception as e:
        logger.error(f"Failed to create interface: {str(e)}")
        logger.error(traceback.format_exc())
        raise
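# Note: for standalone local testing the Blocks object could also be served with
# demo.launch() instead of being mounted into FastAPI; mounting is used here so
# the JSON status endpoint and the UI share a single server and port.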
# Create and mount Gradio app
try:
    logger.info("Creating Gradio interface...")
    demo = create_interface()
    app = gr.mount_gradio_app(app, demo, path="/")
    logger.info("Successfully mounted Gradio app")
except Exception as e:
    logger.error(f"Failed to mount Gradio app: {str(e)}")
    logger.error(traceback.format_exc())
    raise

if __name__ == "__main__":
    try:
        # Initialize model
        logger.info("Starting application...")
        if not initialize_model():
            logger.error("Model initialization failed. Exiting...")
            sys.exit(1)

        # Start server
        import uvicorn
        logger.info("Starting server...")
        uvicorn.run(
            app,
            host="0.0.0.0",
            port=7860,
            log_level="info"
        )
    except Exception as e:
        logger.error(f"Application startup failed: {str(e)}")
        logger.error(traceback.format_exc())
        sys.exit(1)
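# To run locally (a sketch; assumes this file is saved as app.py, the models.llava
# package is importable, and the dependencies imported above are installed):
#
#     python app.py    # serves the UI and API on http://localhost:7860
#
# On a Hugging Face Space the file is launched automatically; 7860 is the port
# Spaces expects the app to listen on.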