# app.py - Gradio UI using vendored tinyllava
import sys
import subprocess
import importlib


def ensure_torch():
    """Install a CPU-only torch wheel at startup if torch is missing."""
    try:
        import torch  # noqa: F401
        return
    except Exception:
        subprocess.check_call([
            sys.executable, "-m", "pip", "install", "--no-cache-dir",
            "torch==2.2.0+cpu",
            "--extra-index-url", "https://download.pytorch.org/whl/cpu",
        ])
        importlib.invalidate_caches()


ensure_torch()

# Now safe to import torch and proceed
import io
import traceback

import requests
import torch
import gradio as gr
from PIL import Image

from tinyllava.model import load_pretrained_model
from tinyllava.mm_utils import get_model_name_from_path

# CPU-friendly TinyLLaVA checkpoint recommended for Spaces
MODEL_PATH = "bczhou/TinyLLaVA-1.5B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEFAULT_MAX_TOKENS = 128

# Lazily loaded globals so the Space starts quickly and the model
# is only pulled in on the first request.
_model = None
_tokenizer = None
_image_processor = None
_context_len = None


def lazy_load():
    """Load the tokenizer, model, and image processor once, on first use."""
    global _model, _tokenizer, _image_processor, _context_len
    if _model is not None:
        return
    model_name = get_model_name_from_path(MODEL_PATH)
    _tokenizer, _model, _image_processor, _context_len = load_pretrained_model(
        MODEL_PATH, model_name=model_name
    )
    _model.to(DEVICE)
    _model.eval()


def load_image_from_url(url: str) -> Image.Image:
    """Fetch an image over HTTP and return it as an RGB PIL image."""
    resp = requests.get(url, timeout=10)
    resp.raise_for_status()
    return Image.open(io.BytesIO(resp.content)).convert("RGB")


def prepare_inputs(prompt: str, image: Image.Image):
    """Tokenize the LLaVA-style prompt and attach the processed image tensors."""
    # <image> is the LLaVA-style placeholder for the visual tokens.
    prompt_text = f"USER: <image>\n{(prompt or '').strip()}\nASSISTANT:"
    inputs = _tokenizer(prompt_text, return_tensors="pt").to(DEVICE)
    if _image_processor is not None:
        img_inputs = _image_processor(images=image, return_tensors="pt").to(DEVICE)
        inputs.update(img_inputs)
    return inputs


def generate_text(prompt, upload, url, max_new_tokens=DEFAULT_MAX_TOKENS):
    """Run one inference pass; prefer the uploaded image, fall back to the URL."""
    try:
        lazy_load()
        if upload is not None:
            image = upload
        elif url:
            image = load_image_from_url(url)
        else:
            return "No image provided."
        inputs = prepare_inputs(prompt, image)
        gen = _model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            num_beams=1,
            do_sample=False,  # greedy decoding for deterministic output
        )
        if isinstance(gen, torch.Tensor):
            decoded = _tokenizer.batch_decode(gen, skip_special_tokens=True)[0]
        elif isinstance(gen, (list, tuple)):
            decoded = _tokenizer.batch_decode(gen[0], skip_special_tokens=True)[0]
        else:
            decoded = str(gen)
        # The decoded text includes the prompt; keep only the assistant reply.
        if "ASSISTANT:" in decoded:
            return decoded.split("ASSISTANT:")[-1].strip()
        return decoded.strip()
    except Exception as e:
        return f"Inference error: {e}\n\n{traceback.format_exc()}"


with gr.Blocks() as demo:
    gr.Markdown("TinyLLaVA (vendored loader) — CPU-friendly")
    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(label="Prompt (optional)")
            upload = gr.Image(label="Upload Image", type="pil")
            url = gr.Textbox(label="Image URL")
            max_tokens = gr.Slider(
                32, 512, value=DEFAULT_MAX_TOKENS, step=32, label="Max new tokens"
            )
            btn = gr.Button("Generate")
        with gr.Column(scale=1):
            preview = gr.Image(label="Preview", type="pil")
            out = gr.Textbox(label="Output", lines=8)

    def update_preview(uploaded, link):
        """Show the uploaded image if present, otherwise try the URL."""
        if uploaded is not None:
            return uploaded
        if link:
            try:
                return load_image_from_url(link)
            except Exception:
                return None
        return None

    upload.change(update_preview, [upload, url], preview)
    url.change(update_preview, [upload, url], preview)
    btn.click(generate_text, [prompt, upload, url, max_tokens], out)

if __name__ == "__main__":
    demo.launch()