import os, io, gc, json, re, ast
from functools import lru_cache

import numpy as np
import pandas as pd
import faiss
import torch
import torch.nn.functional as F
from typing import List, Dict, Any
from PIL import Image, ImageFilter, ImageOps, ImageEnhance
import gradio as gr

from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM

torch.set_num_threads(2)  # free-tier Spaces usually have 2 vCPUs
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# =========================
# Config (override in Space → Settings → Variables & secrets)
# =========================
DATASET_REPO = os.getenv("DATASET_REPO", "ahm1378/NLP-Project")  # <--- CHANGE to your repo
CSV_FILE = os.getenv("CSV_FILE", "final_merged_images.csv")
E5_INDEX_FILE = os.getenv("E5_INDEX_FILE", "faiss_e5_rag_v15.ip")
E5_EMB_FILE = os.getenv("E5_EMB_FILE", "doc_embeds_e5_rag_v15.npy")
FUSION_INDEX_FILE = os.getenv("FUSION_INDEX_FILE", "faiss_fusion.ip")
FUSION_EMB_FILE = os.getenv("FUSION_EMB_FILE", "fusion_doc_emb.npy")
FT_HEAD_FILE = os.getenv("FT_HEAD_FILE", "finetune_clip_fa.pt")  # your finetuned text projection (CLIP space)
HF_TOKEN = os.getenv("HF_TOKEN", None)  # needed if DATASET_REPO is private

# Models (CPU-friendly defaults; override via env if desired)
E5_ID = os.getenv("E5_ID", "intfloat/multilingual-e5-small")
CLIP_TXT_ID = os.getenv("CLIP_TXT_ID", "sentence-transformers/clip-ViT-B-32-multilingual-v1")
LLM_ID = os.getenv("LLM_ID", "Qwen/Qwen2-0.5B-Instruct")

# shorter output by default
MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "96"))  # was 256
# sampling disabled by default (deterministic and faster)
TEMPERATURE_DEFAULT = float(os.getenv("TEMPERATURE_DEFAULT", "0.0"))
TOP_P_DEFAULT = float(os.getenv("TOP_P_DEFAULT", "1.0"))
TOP_K_DEFAULT = int(os.getenv("TOP_K_DEFAULT", "50"))

# =========================
# Helpers
# =========================
def normalize_digits_months(s: str) -> str:
    if not isinstance(s, str):
        s = str(s)
    trans = str.maketrans("۰۱۲۳۴۵۶۷۸۹٠١٢٣٤٥٦٧٨٩", "01234567890123456789")
    s = s.translate(trans).replace("\u200c", " ").strip()
    return s

def _truncate_chars(s: str, limit: int) -> str:
    return s if (limit is None or len(s) <= limit) else s[:limit] + "…"
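# Quick reference for the two helpers above (illustrative values, not executed):
#   normalize_digits_months("۱۳۷۵")  -> "1375"   (Persian/Arabic-Indic digits mapped to ASCII,
#                                                 ZWNJ replaced by a space, ends stripped)
#   _truncate_chars("abcdef", 3)     -> "abc…"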
def _maybe_hub(file, repo=DATASET_REPO, repo_type="dataset") -> str:
    # If present locally, use it. Otherwise download from the Hub.
    if os.path.isfile(file):
        return file
    return hf_hub_download(repo_id=repo, filename=file, repo_type=repo_type, token=HF_TOKEN)

# =========================
# Fetch artifacts
# =========================
CSV_PATH = _maybe_hub(CSV_FILE)
E5_INDEX_PATH = _maybe_hub(E5_INDEX_FILE)
# (E5_EMB_PATH not strictly needed at runtime)

# Fusion index and finetuned head are optional; keep startup alive if they are missing.
FUSION_INDEX_PATH = None
if FUSION_INDEX_FILE:
    try:
        FUSION_INDEX_PATH = _maybe_hub(FUSION_INDEX_FILE)
    except Exception as e:
        print("[WARN] fusion index not available:", e)

FT_HEAD_PATH = None
if FT_HEAD_FILE:
    try:
        FT_HEAD_PATH = _maybe_hub(FT_HEAD_FILE)
    except Exception as e:
        print("[WARN] finetuned head not available:", e)

# =========================
# Load dataframe
# =========================
if not os.path.isfile(CSV_PATH):
    raise FileNotFoundError(f"CSV missing: {CSV_PATH}")
df = pd.read_csv(CSV_PATH)

# Expect columns: 'id', 'bio', 'image_paths_abs' (list or stringified list)
def first_image(x):
    if isinstance(x, list) and x:
        return x[0]
    if isinstance(x, str) and x.strip():
        # try JSON list
        try:
            lst = json.loads(x)
            if isinstance(lst, list) and lst:
                return lst[0]
        except Exception:
            # try Python literal list (handles single quotes)
            try:
                lst = ast.literal_eval(x)
                if isinstance(lst, list) and lst:
                    return lst[0]
            except Exception:
                return x  # treat as a single path
    return ""

if "image_paths_abs" in df.columns:
    df["first_image"] = df["image_paths_abs"].apply(first_image)
else:
    df["first_image"] = ""

if "bio" not in df.columns:
    raise KeyError("Expected 'bio' column in CSV.")
df["bio"] = df["bio"].astype(str)

# =========================
# Indices
# =========================
if not os.path.isfile(E5_INDEX_PATH):
    raise FileNotFoundError(f"E5 index not found: {E5_INDEX_PATH}")
index_e5 = faiss.read_index(E5_INDEX_PATH)

index_fusion = None
if FUSION_INDEX_PATH and os.path.isfile(FUSION_INDEX_PATH):
    index_fusion = faiss.read_index(FUSION_INDEX_PATH)
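# Optional startup sanity check (a small sketch): it assumes the FAISS indices were built
# row-aligned with the CSV, which is what the retrieval code below relies on when it maps a
# hit back to a bio via df.iloc[idx]. It only warns; it does not stop the app.
if index_e5.ntotal != len(df):
    print(f"[WARN] E5 index holds {index_e5.ntotal} vectors but the CSV has {len(df)} rows; "
          "retrieved indices may not point at the intended bios.")
if index_fusion is not None and index_fusion.ntotal != len(df):
    print(f"[WARN] Fusion index holds {index_fusion.ntotal} vectors but the CSV has {len(df)} rows.")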
# =========================
# Models (CPU-only)
# =========================
device = "cpu"
dtype = torch.float32

# Text retrieval encoder (E5)
st_e5 = SentenceTransformer(E5_ID, device=device)

# CLIP multilingual text encoder (also the fallback when no finetuned head is present)
st_clip_txt = SentenceTransformer(CLIP_TXT_ID, device=device).eval()

# Optional: finetuned CLIP text projection head (512->512, bias=False).
# Reuse the encoder loaded above instead of loading the same checkpoint twice.
mclip = st_clip_txt
proj_txt = None
if FT_HEAD_PATH and os.path.isfile(FT_HEAD_PATH):
    try:
        proj_txt = torch.nn.Linear(512, 512, bias=False)
        ckpt = torch.load(FT_HEAD_PATH, map_location="cpu")
        if "proj_txt" in ckpt:
            proj_txt.load_state_dict(ckpt["proj_txt"])
        elif "state_dict" in ckpt:
            proj_txt.load_state_dict(ckpt["state_dict"])
        else:
            raise KeyError("No 'proj_txt' or 'state_dict' key in FT checkpoint.")
        proj_txt.eval()
        print("[OK] loaded finetuned projection head:", FT_HEAD_PATH)
    except Exception as e:
        print("[WARN] failed to load finetuned head:", e)
        proj_txt = None

# Lazy CLIP image encoder (only loaded if the user actually runs fusion)
clip_model = None
clip_preprocess = None

def _ensure_clip_loaded():
    global clip_model, clip_preprocess
    if clip_model is None:
        import open_clip  # lazy import
        model, _, preprocess_val = open_clip.create_model_and_transforms(
            "ViT-B-32", pretrained="laion2b_s34b_b79k", device="cpu"
        )
        clip_model = model.eval()
        clip_preprocess = preprocess_val
        print("[OK] CLIP ViT-B/32 loaded on CPU")

# LLM (small; CPU-friendly)
tokenizer = AutoTokenizer.from_pretrained(LLM_ID, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    LLM_ID,
    torch_dtype=dtype,
).to("cpu").eval()

# =========================
# Retrieval helpers
# =========================
@lru_cache(maxsize=4096)
def _encode_query_e5_cached(q: str) -> np.ndarray:
    qn = "query: " + normalize_digits_months(q)
    v = st_e5.encode([qn], batch_size=1, convert_to_numpy=True, normalize_embeddings=True)[0]
    return v.astype("float32")

# thin wrapper used in place of the old, uncached encoder
def _encode_query_e5(q: str) -> np.ndarray:
    return _encode_query_e5_cached(q)

def _faiss_search(index, q_vec: np.ndarray, k: int):
    if q_vec.ndim == 1:
        q_vec = q_vec[None, :]
    s, I = index.search(q_vec.astype("float32"), k)
    return list(zip(I[0].tolist(), s[0].tolist()))

def search_text_rag(query_text: str, k: int = 5):
    q = _encode_query_e5(query_text)
    return _faiss_search(index_e5, q, k)

# ---- Fusion (CLIP space) ----
def _jpeg(img, quality=40):
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=quality, optimize=False)
    buf.seek(0)
    return Image.open(buf).convert("RGB")

def _rand_resized_crop(img, scale=(0.7, 0.9)):
    w, h = img.size
    s = np.random.uniform(*scale)
    nw, nh = max(1, int(w * s)), max(1, int(h * s))
    left = np.random.randint(0, max(1, w - nw))
    top = np.random.randint(0, max(1, h - nh))
    return img.crop((left, top, left + nw, top + nh)).resize((w, h), Image.BICUBIC)

def _color_jitter(img, b=(0.9, 1.1), c=(0.9, 1.1)):
    img = ImageOps.autocontrast(img)
    img = ImageEnhance.Brightness(img).enhance(np.random.uniform(*b))
    img = ImageEnhance.Contrast(img).enhance(np.random.uniform(*c))
    return img

def augment_once(img: Image.Image, level="medium"):
    if level == "mild":
        img = _rand_resized_crop(img, (0.85, 0.95))
        img = _jpeg(img, 60)
    elif level == "medium":
        img = _rand_resized_crop(img, (0.7, 0.9))
        img = img.filter(ImageFilter.GaussianBlur(1.0))
        img = _color_jitter(img, (0.9, 1.1), (0.9, 1.1))
        img = _jpeg(img, 40)
    else:
        img = _rand_resized_crop(img, (0.6, 0.8))
        img = img.filter(ImageFilter.GaussianBlur(1.2))
        img = _jpeg(img, 30)
    return img

@torch.no_grad()
def _encode_pil_clip(img: Image.Image) -> np.ndarray:
    _ensure_clip_loaded()
    t = clip_preprocess(img).unsqueeze(0)
    feat = clip_model.encode_image(t)
    feat = F.normalize(feat.float(), dim=-1)
    return feat.cpu().numpy().astype("float32")  # (1, 512)

@torch.no_grad()
def _encode_query_text_clipspace(q: str) -> np.ndarray:
    qn = normalize_digits_months(q)
    if proj_txt is not None:
        # mclip raw → proj → normalize
        t = torch.tensor(
            mclip.encode([qn], convert_to_numpy=True, normalize_embeddings=False),
            dtype=torch.float32,
        )
        t = proj_txt(t)
        t = F.normalize(t, dim=-1).cpu().numpy().astype("float32")
        return t
    else:
        # fallback: CLIP multilingual text encoder (already normalized)
        t = st_clip_txt.encode([qn], batch_size=1, convert_to_numpy=True, normalize_embeddings=True)
        return t.astype("float32")

@torch.no_grad()
def make_query_embed(query_text: str, image: Image.Image = None, alpha_q: float = 0.7,
                     use_aug: bool = True, n_aug: int = 3) -> np.ndarray:
    qt = _encode_query_text_clipspace(query_text)  # (1, 512)
    qi = None
    if image is not None:
        if clip_model is None:  # ensure loaded only if needed
            _ensure_clip_loaded()
        if use_aug:
            feats = [_encode_pil_clip(augment_once(image, "medium")) for _ in range(max(1, int(n_aug)))]
            qi = np.mean(np.vstack(feats), axis=0, keepdims=True).astype("float32")
        else:
            qi = _encode_pil_clip(image)
    if qi is not None:
        qv = torch.from_numpy(alpha_q * qt + (1.0 - alpha_q) * qi)
        qv = F.normalize(qv, dim=-1).cpu().numpy().astype("float32")
        return qv
    return qt

def search_fusion(query_text: str, image: Image.Image, k: int = 5, alpha_q: float = 0.7):
    if index_fusion is None:
        raise RuntimeError("Fusion index not available (upload FUSION_INDEX_FILE to the dataset repo).")
    qv = make_query_embed(query_text, image=image, alpha_q=alpha_q, use_aug=False, n_aug=3)
    return _faiss_search(index_fusion, qv, k)
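# How the fusion query is formed (reference for the functions above): with the default
# alpha_q = 0.7, make_query_embed returns normalize(0.7 * text_clip_vec + 0.3 * image_clip_vec),
# so the text dominates; lower alpha_q to weight the image more. A minimal offline sketch,
# assuming a hypothetical local file "example.jpg" and an uploaded fusion index:
#
#   qv = make_query_embed("<question text>", image=Image.open("example.jpg"), alpha_q=0.5)
#   hits = _faiss_search(index_fusion, qv, k=3)   # [(doc_idx, score), ...]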
# =========================
# RAG + LLM
# =========================
def retrieve_context_auto(question: str, k: int = 5, image: Image.Image = None) -> Dict[str, Any]:
    q = normalize_digits_months(question)
    if image is not None:
        route = "fusion"
        try:
            hits = search_fusion(q, image=image, k=k)
        except Exception as e:
            # graceful fallback to text-only retrieval
            print("[WARN] fusion retrieval failed, falling back to text:", e)
            route = "text_e5"
            hits = search_text_rag(q, k=k)
    else:
        route = "text_e5"
        hits = search_text_rag(q, k=k)
    ctxs = []
    for idx, score in hits:
        if 0 <= idx < len(df):
            row = df.iloc[idx]
            ctxs.append({"index": int(idx), "id": row.get("id", idx),
                         "score": float(score), "bio": str(row["bio"])})
    return {"route": route, "contexts": ctxs}

def build_prompt(question: str, contexts: List[Dict[str, Any]], lang="fa", max_chars=1800) -> str:
    sys_fa = "تو یک دستیار پاسخ‌گو هستی که فقط بر اساس متن‌های داده‌شده پاسخ می‌دهی. اگر پاسخی در متن‌ها نبود، صادقانه بگو «در متن‌های بازیابی‌شده پاسخی پیدا نشد.»"
    sys_en = "You are a helpful assistant. Answer only using retrieved passages. If not found, say 'No answer found in retrieved passages.'"
    system_text = sys_fa if lang == "fa" else sys_en
    parts = []
    for i, c in enumerate(contexts, 1):
        bi = c["bio"].strip()
        if bi:
            parts.append(f"[{i}] {bi}")
    joined = _truncate_chars("\n\n".join(parts), max_chars)
    if lang == "fa":
        user = (f"سؤال: {question}\n\nمتون بازیابی‌شده:\n{joined}\n\n"
                f"فقط با اتکا به متون بالا پاسخ بده و منابع را با [1], [2], ... ارجاع بده.")
    else:
        user = (f"Question: {question}\n\nRetrieved passages:\n{joined}\n\n"
                f"Answer only using the passages, cite sources as [1], [2], ...")
    msgs = [{"role": "system", "content": system_text}, {"role": "user", "content": user}]
    return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

@torch.inference_mode()
def llm_generate(prompt: str, max_new_tokens=96, temperature=0.0, top_p=1.0, top_k=50, do_sample=False) -> str:
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,            # deterministic by default
        num_beams=1,                    # no beam search
        use_cache=True,                 # faster
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        # sampling knobs only apply when sampling is actually enabled
        gen_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)
    out = model.generate(**inputs, **gen_kwargs)
    # decode only the newly generated tokens; the chat-templated prompt contains special
    # tokens, so stripping it by string prefix is unreliable
    gen_ids = out[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
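# End-to-end usage sketch (mirrors what the UI handlers below do; shown for reference,
# not executed at import time):
#
#   ret = retrieve_context_auto("<question>", k=3)                 # retrieve passages
#   prompt = build_prompt("<question>", ret["contexts"], lang="fa")
#   answer = llm_generate(prompt, max_new_tokens=96)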
# ---- MCQ helpers ----
def build_mcq_prompt(question: str, options: List[str], contexts: List[Dict[str, Any]],
                     lang="fa", max_chars=1800) -> str:
    sys_fa = (
        "تو یک دستیار پاسخ‌گو هستی که فقط بر اساس متن‌های داده‌شده پاسخ می‌دهی. "
        "باید دقیقاً فقط یک شیء JSON برگردانی و هیچ متن دیگری ننویسی."
    )
    sys_en = (
        "You are a helpful assistant. Answer ONLY using the retrieved passages. "
        "You MUST return a single JSON object and nothing else."
    )
    system_text = sys_fa if lang == "fa" else sys_en
    parts = []
    for i, c in enumerate(contexts, 1):
        bi = c["bio"].strip()
        if bi:
            parts.append(f"[{i}] {bi}")
    joined = _truncate_chars("\n\n".join(parts), max_chars)
    labels = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    opts_str = "\n".join([f"{labels[i]}) {o}" for i, o in enumerate(options)])
    if lang == "fa":
        user = (
            f"سؤال: {question}\n\nگزینه‌ها:\n{opts_str}\n\nمتون بازیابی‌شده:\n{joined}\n\n"
            "دقیقاً و فقط یک JSON برگردان. فرمت اجباری: "
            '{"answer_index": X, "reason": "…"} '
            "که در آن X اندیس گزینه (۰-بِیس) است. هیچ متن دیگری ننویس."
        )
    else:
        user = (
            f"Question: {question}\n\nOptions:\n{opts_str}\n\nRetrieved:\n{joined}\n\n"
            'Return EXACTLY one JSON: {"answer_index": X, "reason": "..."} '
            "where X is the 0-based option index. Do not write anything else."
        )
    msgs = [{"role": "system", "content": system_text}, {"role": "user", "content": user}]
    return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

import json as _json
import re as _re
import numpy as _np

def _strict_json_from_text(text: str):
    # grab only the first {...} block and JSON-parse it
    m = _re.search(r'\{.*\}', text, _re.S)
    if not m:
        return None
    frag = m.group(0)
    try:
        return _json.loads(frag)
    except Exception:
        return None
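# Example of the strict extraction above: for a reply like
#   'Sure! {"answer_index": 2, "reason": "stated in [1]"} hope that helps'
# _strict_json_from_text returns {"answer_index": 2, "reason": "stated in [1]"}.
# The regex is greedy, so a reply containing several {...} blocks usually fails json.loads
# and falls through to the context-matching fallback below.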
def _norm_text_for_match(s: str) -> str:
    # simple normalization: Persian/Arabic digits, ZWNJ, extra whitespace
    s = normalize_digits_months(s or "")
    s = s.replace("\u200c", " ").strip()
    # lowercase and collapse whitespace
    s = _re.sub(r"\s+", " ", s.lower())
    return s

def _find_snippet(hay: str, needle: str, win: int = 60) -> str:
    """Return a short snippet around the first match."""
    try:
        i = hay.index(needle)
        start = max(0, i - win)
        end = min(len(hay), i + len(needle) + win)
        return hay[start:end].replace("\n", " ")
    except ValueError:
        return ""

def score_options_by_context(
    options: List[str],
    contexts: List[Dict[str, Any]],
    return_snippet: bool = False,
):
    """
    Smart fallback:
      1) boundary-aware substring match in every context (high score + occurrence count)
      2) if nothing matches → embedding similarity (mE5) between each option and the joined contexts
    Returns:
      - return_snippet=False → best_idx (int) only
      - return_snippet=True  → (best_idx, snippet)
    """
    # prepare the contexts
    raw_ctxs = [c.get("bio", "") for c in contexts]
    norm_ctxs = [_norm_text_for_match(x) for x in raw_ctxs]
    joined_norm = " \n ".join(norm_ctxs)

    # 1) exact-ish search: word boundaries + occurrence count
    # works for Persian/Arabic too, since the boundaries are whitespace-based.
    best_idx, best_score, best_snip = 0, -1.0, ""
    for i, opt in enumerate(options):
        o_raw = str(opt).strip()
        o = _norm_text_for_match(o_raw)
        if not o:
            continue
        # simple boundary pattern: (start/whitespace) + phrase + (end/whitespace);
        # works for multi-word options as well; a stricter regex can be used if needed.
        pat = r"(?<!\S)" + _re.escape(o) + r"(?!\S)"
        # count matches across all contexts and remember a snippet from the first hit
        total_hits = 0
        first_snip = ""
        for raw, norm in zip(raw_ctxs, norm_ctxs):
            hits = _re.findall(pat, norm)
            if hits and not first_snip:
                first_snip = _find_snippet(raw, o_raw) or _find_snippet(norm, o)
            total_hits += len(hits)
        if total_hits > 0:
            # high score for an explicit match, plus the occurrence count
            score = 10000.0 + total_hits
            if score > best_score:
                best_score, best_idx, best_snip = score, i, first_snip

    if best_score > 0:
        return (best_idx, best_snip) if return_snippet else best_idx

    # 2) no literal match → embedding similarity (mE5)
    try:
        # encode the joined contexts once
        ctx_vec = _encode_query_e5(joined_norm)  # (dim,)
        sims = []
        for opt in options:
            qv = _encode_query_e5(str(opt))
            sims.append(float(_np.dot(qv, ctx_vec)))
        best_idx = int(_np.argmax(sims))
        # snippet on this path: find the closest single context with a separate dot product
        # (fast and good enough; cosine ≈ inner product since the embeddings are normalized)
        best_snip = ""
        try:
            opt_vec = _encode_query_e5(str(options[best_idx]))
            # score each context against the winning option
            c_scores = []
            for raw, norm in zip(raw_ctxs, norm_ctxs):
                c_vec = _encode_query_e5(norm)
                c_scores.append(float(_np.dot(opt_vec, c_vec)))
            j = int(_np.argmax(c_scores))
            best_snip = _find_snippet(raw_ctxs[j], str(options[best_idx])) or raw_ctxs[j][:120].replace("\n", " ")
        except Exception:
            pass
        return (best_idx, best_snip) if return_snippet else best_idx
    except Exception:
        # conservative default
        return (0, "") if return_snippet else 0

def parse_mcq_output_strict(text: str, options: List[str], contexts: List[Dict[str, Any]]) -> Dict[str, Any]:
    obj = _strict_json_from_text(text)
    if obj and "answer_index" in obj:
        idx = obj["answer_index"]
        if isinstance(idx, int) and 0 <= idx < len(options):
            reason = str(obj.get("reason", "")).strip() or "—"
            return {"answer_index": idx, "reason": reason}
    idx, snip = score_options_by_context(options, contexts, return_snippet=True)
    return {"answer_index": idx, "reason": snip or "matched by context"}

def parse_mcq_output(text: str, n: int) -> Dict[str, Any]:
    m = re.search(r'\{\s*"answer_index"\s*:\s*([0-9]+)\s*,\s*"reason"\s*:\s*"(.*?)"\s*\}', text, re.S)
    if m:
        idx = int(m.group(1))
        reason = m.group(2).strip()
        if 0 <= idx < n:
            return {"answer_index": idx, "reason": reason}
    letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    m2 = re.search(r'\b([A-D])\b', text, re.I)
    if m2:
        idx = letters.index(m2.group(1).upper())
        if idx < n:
            return {"answer_index": idx, "reason": text.strip()}
    m3 = re.search(r'\b([1-9])\b', text)
    if m3:
        idx = int(m3.group(1)) - 1
        if 0 <= idx < n:
            return {"answer_index": idx, "reason": text.strip()}
    return {"answer_index": None, "reason": text.strip()}
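# How the MCQ fallback behaves (illustration): when the LLM reply carries no usable JSON,
# parse_mcq_output_strict asks score_options_by_context to pick an option. An option whose
# normalized text literally appears in a retrieved bio wins with a score of 10000 plus the
# hit count; only when no option appears anywhere do the mE5 similarities decide. The plain
# parse_mcq_output above is an older letter/number parser kept for reference; the UI below
# uses the strict variant.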
# =========================
# Gradio UI
# =========================
def ui_answer(question, image, use_fusion, topk, max_tokens, temperature, top_p, top_k):
    if not question or not question.strip():
        return "Please enter a question.", [], ""
    # Retrieve (route through fusion only when the checkbox is on and an image is given)
    ret = retrieve_context_auto(question, k=int(topk), image=image if use_fusion else None)
    prompt = build_prompt(question, ret["contexts"], lang="fa", max_chars=1800)
    ans = llm_generate(prompt, max_new_tokens=int(max_tokens), temperature=float(temperature),
                       top_p=float(top_p), top_k=int(top_k), do_sample=False)
    # Sources
    rows = []
    for i, c in enumerate(ret["contexts"], 1):
        snip = c["bio"][:180] + ("…" if len(c["bio"]) > 180 else "")
        rows.append([i, c["id"], round(c["score"], 4), snip])
    return ans, rows, ret["route"]

def ui_mcq(question, options_txt, image, topk, max_tokens, temperature, top_p, top_k):
    opts = [o.strip() for o in (options_txt or "").splitlines() if o.strip()]
    if not question or len(opts) < 2:
        return "Provide a question and at least 2 options.", "", [], ""
    ret = retrieve_context_auto(question, k=int(topk), image=image)
    prompt = build_mcq_prompt(question, opts, ret["contexts"], lang="fa", max_chars=5000)
    out = llm_generate(prompt, max_new_tokens=int(max_tokens), temperature=float(temperature),
                       top_p=float(top_p), top_k=int(top_k), do_sample=False)  # deterministic on CPU
    parsed = parse_mcq_output_strict(out, opts, ret["contexts"])
    pred = parsed["answer_index"]
    pred_text = opts[pred] if (pred is not None and 0 <= pred < len(opts)) else "N/A"
    rows = []
    for i, c in enumerate(ret["contexts"], 1):
        snip = c["bio"][:180] + ("…" if len(c["bio"]) > 180 else "")
        rows.append([i, c["id"], round(c["score"], 4), snip])
    result = f"Pred: index={pred} text={pred_text}\nReason: {parsed['reason']}"
    return result, out, rows, ret["route"]

with gr.Blocks(title="Multimodal RAG (CPU) • E5 + CLIP Fusion + Qwen 0.5B") as demo:
    gr.Markdown("### Free-tier CPU demo: text RAG (E5) + optional fusion (CLIP) → Qwen 0.5B")

    with gr.Tab("Ask"):
        with gr.Row():
            q = gr.Textbox(label="Question", lines=3)
            img = gr.Image(type="pil", label="Optional image")
        use_fusion = gr.Checkbox(label="Use image fusion (slower on CPU)", value=False)
        with gr.Row():
            topk = gr.Slider(1, 20, value=3, step=1, label="Top-K retrieve")
            max_tokens = gr.Slider(16, 512, value=96, step=16, label="Max new tokens")
        with gr.Row():
            temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=1.0, step=0.05, label="Top-p")
            top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
        btn = gr.Button("Answer")
        ans = gr.Textbox(label="Answer", lines=8)
        route = gr.Textbox(label="Route used (text_e5 or fusion)")
        table = gr.Dataframe(headers=["#", "id", "score", "snippet"], interactive=False)
        btn.click(ui_answer,
                  [q, img, use_fusion, topk, max_tokens, temperature, top_p, top_k],
                  [ans, table, route])

    with gr.Tab("MCQ"):
        with gr.Row():
            q_mcq = gr.Textbox(label="Question", lines=3)
            opts_mcq = gr.Textbox(label="Options (one per line)", lines=8)
            img_mcq = gr.Image(type="pil", label="Optional image (fusion if enabled)")
        with gr.Row():
            topk2 = gr.Slider(1, 20, value=3, step=1, label="Top-K retrieve")
            max_tokens2 = gr.Slider(16, 512, value=96, step=16, label="Max new tokens")
        with gr.Row():
            temperature2 = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
            top_p2 = gr.Slider(0.1, 1.0, value=1.0, step=0.05, label="Top-p")
            top_k2 = gr.Slider(1, 100, value=50, step=1, label="Top-k")
        btn2 = gr.Button("Answer MCQ")
        result = gr.Textbox(label="Prediction", lines=12, max_lines=20)
        raw = gr.Textbox(label="Raw LLM output", lines=12, max_lines=20)
        route2 = gr.Textbox(label="Route used")
        table2 = gr.Dataframe(headers=["#", "id", "score", "snippet"], interactive=False)
        btn2.click(ui_mcq,
                   [q_mcq, opts_mcq, img_mcq, topk2, max_tokens2, temperature2, top_p2, top_k2],
                   [result, raw, table2, route2])

demo.launch()
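# Optional: on a shared free-tier CPU Space it can help to serialize requests, e.g.
# demo.queue(max_size=8).launch() instead of the bare launch() above, so concurrent users
# do not contend for the two vCPUs (max_size=8 is just an illustrative value).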