import io
import traceback
from typing import Optional

import gradio as gr
import requests
import torch
from PIL import Image

# tinyllava imports (from the TinyLLaVA_Factory repo). tokenizer_image_token
# and the image-token constants follow the LLaVA-style module layout; adjust
# the import paths if your TinyLLaVA version places them elsewhere.
from tinyllava.model.builder import load_pretrained_model
from tinyllava.mm_utils import get_model_name_from_path, tokenizer_image_token
from tinyllava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX

# ---------- CONFIG ----------
# Choose one of the TinyLLaVA models from the model zoo:
# - "bczhou/TinyLLaVA-3.1B" (best quality, conv_mode="phi", ~3.1B params)
# - "bczhou/TinyLLaVA-2.0B" (conv_mode="phi")
# - "bczhou/TinyLLaVA-1.5B" (conv_mode="v1")
MODEL_PATH = "bczhou/TinyLLaVA-3.1B"

# recommended defaults (you can change these)
DEFAULT_MAX_TOKENS = 256
DEFAULT_CONV_MODE = "phi"
# ----------------------------

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer/model/image_processor via the TinyLLaVA loader.
model_name = get_model_name_from_path(MODEL_PATH)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=MODEL_PATH,
    model_base=None,
    model_name=model_name,
)
model.to(device)
model.eval()


def load_image_from_url(url: str) -> Image.Image:
    # Fetch an image over HTTP and return it as an RGB PIL image.
    resp = requests.get(url, timeout=10)
    resp.raise_for_status()
    return Image.open(io.BytesIO(resp.content)).convert("RGB")


def _prepare_inputs(prompt: str, image: Image.Image):
    # LLaVA-style template; DEFAULT_IMAGE_TOKEN ("<image>") marks where the
    # image features are spliced into the token sequence. A single template is
    # a simplification: for the exact per-model templates, use
    # tinyllava.conversation.conv_templates[conv_mode] instead.
    prompt_text = f"USER: {DEFAULT_IMAGE_TOKEN}\n{(prompt or '').strip()}\nASSISTANT:"
    # tokenizer_image_token replaces <image> with IMAGE_TOKEN_INDEX; a plain
    # tokenizer() call would split "<image>" into ordinary subword tokens and
    # the model would never see the image.
    input_ids = tokenizer_image_token(
        prompt_text, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0).to(device)
    # Preprocess the image into a normalized pixel_values tensor and match the
    # model's dtype (e.g. float16 on GPU) to avoid dtype-mismatch errors.
    image_tensor = image_processor(images=image, return_tensors="pt")["pixel_values"]
    image_tensor = image_tensor.to(device, dtype=model.dtype)
    return input_ids, image_tensor, prompt_text


def generate_text(prompt: str,
                  uploaded_image: Optional[Image.Image],
                  image_url: str,
                  max_new_tokens: int = DEFAULT_MAX_TOKENS,
                  conv_mode: str = DEFAULT_CONV_MODE):
    # conv_mode is exposed in the UI for parity with the model zoo table; this
    # simplified demo uses the same USER/ASSISTANT template for both modes.
    try:
        # Select the image: the upload is preferred, the URL is the fallback.
        if uploaded_image is not None:
            image = uploaded_image
        elif image_url:
            image = load_image_from_url(image_url)
        else:
            return "No image provided. Upload an image or provide an image URL."
        input_ids, image_tensor, _ = _prepare_inputs(prompt, image)

        # TinyLLaVA's LLaVA-style generate() takes the vision tensor via the
        # `images` kwarg. The original temperature=0.0 intent is expressed as
        # greedy decoding (do_sample=False), since HF generate rejects a
        # literal 0.0 temperature when sampling is enabled.
        with torch.inference_mode():
            outputs = model.generate(
                input_ids,
                images=image_tensor,
                do_sample=False,
                num_beams=1,
                max_new_tokens=int(max_new_tokens),
            )

        # outputs may be a tensor of token ids or a list/tuple of tensors,
        # depending on the implementation.
        if isinstance(outputs, torch.Tensor):
            decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        elif isinstance(outputs, (list, tuple)):
            decoded = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)[0]
        else:
            decoded = str(outputs)

        # Extract the assistant reply (some checkpoints echo the full prompt).
        if "ASSISTANT:" in decoded:
            reply = decoded.split("ASSISTANT:")[-1].strip()
        else:
            reply = decoded.strip()
        return reply
    except Exception as e:
        tb = traceback.format_exc()
        return f"Inference error: {e}\n\nTraceback:\n{tb}"


# ------- Gradio UI -------
with gr.Blocks() as demo:
    gr.Markdown("# TinyLLaVA Gradio Demo (upload or URL)")
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(label="Prompt (optional)",
                                      placeholder="Ask about the image...")
            upload = gr.Image(label="Upload Image (preferred)", type="pil")
            url = gr.Textbox(label="Image URL (used if upload is empty)",
                             placeholder="https://...")
            conv_mode = gr.Dropdown(label="Conv mode", choices=["phi", "v1"],
                                    value=DEFAULT_CONV_MODE)
            max_tokens = gr.Slider(minimum=32, maximum=1024, step=32,
                                   value=DEFAULT_MAX_TOKENS,
                                   label="Max new tokens")
            run_btn = gr.Button("Generate")
        with gr.Column(scale=1):
            preview = gr.Image(label="Image preview", type="pil")
            out = gr.Textbox(label="Generated Text", lines=8)

    # Preview update: show the upload if present, else try to load the URL.
    def update_preview(uploaded, url_text):
        if uploaded is not None:
            return uploaded
        if url_text:
            try:
                return load_image_from_url(url_text)
            except Exception:
                return None
        return None

    upload.change(fn=update_preview, inputs=[upload, url], outputs=preview)
    url.change(fn=update_preview, inputs=[upload, url], outputs=preview)

    # Inputs map positionally onto generate_text's parameters:
    # (prompt, uploaded_image, image_url, max_new_tokens, conv_mode).
    run_btn.click(fn=generate_text,
                  inputs=[prompt_input, upload, url, max_tokens, conv_mode],
                  outputs=out)

if __name__ == "__main__":
    demo.launch()
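
# --- Usage sketch (assumptions flagged inline) ---
# Launch the UI with `python app.py` (assumes this file is saved as app.py);
# Gradio serves on http://127.0.0.1:7860 by default, and passing share=True
# to demo.launch() creates a temporary public link.
#
# Programmatic smoke test without the UI; the image URL below is a
# placeholder, so substitute any reachable image:
#
#     from app import generate_text  # import is safe: launch() is guarded
#     print(generate_text(
#         "What is shown in this image?",   # prompt
#         None,                             # no uploaded image; use the URL
#         "https://example.com/cat.jpg",    # placeholder URL
#         max_new_tokens=128,
#     ))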