# app.py (fixed, no concurrency_count) import os, sys, time, traceback, subprocess from typing import Tuple, Optional from PIL import Image try: import gradio as gr except ImportError: subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "gradio"]) import gradio as gr def _make_fallback(): def _fallback_answer_with_controller(image, question, source="auto", distilled_model="auto"): return "Placeholder answer (wire your models in controller.py).", "baseline", 0 return _fallback_answer_with_controller try: from controller import answer_with_controller except Exception as e: print(f"[WARN] Using fallback controller because import failed: {e}", flush=True) answer_with_controller = _make_fallback() TITLE = "VQA — Memory + RL Controller" DESCRIPTION = "Upload an image, enter a question, and the controller will choose the best decoding strategy." CONTROLLER_SOURCES = ["auto", "distilled", "ppo", "baseline"] DISTILLED_CHOICES = ["auto", "logreg", "mlp32"] def vqa_demo_fn(image: Optional[Image.Image], question: str, source: str, distilled_model: str) -> Tuple[str, str, float]: if image is None: return "Please upload an image.", "", 0.0 question = (question or "").strip() if not question: return "Please enter a question.", "", 0.0 t0 = time.perf_counter() try: image_rgb = image.convert("RGB") pred, strategy_name, action_id = answer_with_controller( image_rgb, question, source=source, distilled_model=distilled_model ) latency_ms = (time.perf_counter() - t0) * 1000.0 return str(pred), f"{action_id} → {strategy_name}", round(latency_ms, 1) except Exception as err: latency_ms = (time.perf_counter() - t0) * 1000.0 print("[ERROR] Inference failed:\n" + "".join(traceback.format_exc()), flush=True) return f"Error: {err}", "error", round(latency_ms, 1) with gr.Blocks(title=TITLE, analytics_enabled=False) as demo: gr.Markdown(f"### {TITLE}\n{DESCRIPTION}") with gr.Row(): with gr.Column(): img_in = gr.Image( type="pil", label="Image", height=320, sources=["upload", "webcam", "clipboard"], # valid ) q_in = gr.Textbox(label="Question", placeholder="e.g., What colour is the bus?", lines=2, max_lines=4) source_in = gr.Radio(CONTROLLER_SOURCES, value="auto", label="Controller Source") dist_in = gr.Radio(DISTILLED_CHOICES, value="auto", label="Distilled Gate (if used)") run_btn = gr.Button("Predict", variant="primary") with gr.Column(): ans_out = gr.Textbox(label="Answer", interactive=False, lines=3, max_lines=6) strat_out = gr.Textbox(label="Chosen Strategy", interactive=False) lat_out = gr.Number(label="Latency (ms)", precision=1, interactive=False) run_btn.click( vqa_demo_fn, inputs=[img_in, q_in, source_in, dist_in], outputs=[ans_out, strat_out, lat_out], api_name="predict", ) if __name__ == "__main__": port = int(os.getenv("PORT", "7860")) demo.queue() # no concurrency_count demo.launch(server_name="0.0.0.0", server_port=port, share=False, show_error=True)