Xlordo committed
Commit bf736ad · verified · 1 Parent(s): 87f8dbb

Update app.py

Files changed (1)
  1. app.py +75 -84
app.py CHANGED
@@ -1,95 +1,86 @@
  import gradio as gr
- import faiss
- import numpy as np
  from datasets import load_dataset
- from sentence_transformers import SentenceTransformer
- from sklearn.metrics import ndcg_score
-
- # ----------------------------
- # Load dataset (MS MARCO v1.1)
- # ----------------------------
- dataset = load_dataset("ms_marco", "v1.1", split="train[:10000]")
- passages = [item["passage"] for item in dataset]
- print(f"Loaded {len(passages)} passages")
+ from sentence_transformers import SentenceTransformer, util
+ import numpy as np

- # ----------------------------
  # Load SBERT model
- # ----------------------------
  model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

- # ----------------------------
- # Build FAISS index
- # ----------------------------
- embeddings = model.encode(passages, convert_to_numpy=True, show_progress_bar=True)
- dimension = embeddings.shape[1]
- index = faiss.IndexFlatL2(dimension)
- index.add(embeddings)
- print("FAISS index built with", index.ntotal, "passages")
-
- # ----------------------------
- # Search function
- # ----------------------------
- def search(query, k=10):
-     query_vec = model.encode([query], convert_to_numpy=True)
-     distances, indices = index.search(query_vec, k)
-     results = [(passages[i], float(dist)) for i, dist in zip(indices[0], distances[0])]
-     return results
-
- # ----------------------------
- # Evaluation metrics
- # ----------------------------
- def evaluate(query, relevant_passages, k=10):
-     """Compute IR metrics for a query given a list of relevant passages (ground truth)."""
-     results = search(query, k)
-     retrieved = [res[0] for res in results]
-
-     # Binary relevance vector
-     y_true = [1 if p in relevant_passages else 0 for p in retrieved]
-     y_true_full = np.array([[1 if passages[i] in relevant_passages else 0 for i in range(len(passages))]])
-     y_scores_full = np.zeros((1, len(passages)))
-     for idx, (res, dist) in enumerate(results):
-         pos = passages.index(res)
-         y_scores_full[0, pos] = 1.0 - dist  # higher score = more relevant
-
-     # Metrics
-     precision = sum(y_true) / k
-     recall = sum(y_true) / len(relevant_passages) if relevant_passages else 0
-     f1 = (2 * precision * recall) / (precision + recall) if (precision+recall) > 0 else 0
-     mrr = 1.0 / (y_true.index(1)+1) if 1 in y_true else 0
-     ndcg = ndcg_score(y_true_full, y_scores_full, k=k)
-
-     return {
-         "Precision@10": round(precision, 3),
-         "Recall@10": round(recall, 3),
-         "F1": round(f1, 3),
-         "MRR": round(mrr, 3),
-         "nDCG@10": round(ndcg, 3)
+ # ✅ Load dataset with passages
+ dataset = load_dataset("sentence-transformers/msmarco", "v1.1", split="train[:10000]")
+ passages = dataset["passage"]
+
+ # Encode passages once for efficiency
+ passage_embeddings = model.encode(passages, convert_to_tensor=True)
+
+ # ---------- Evaluation Metrics ----------
+ def precision_at_k(relevant, retrieved, k):
+     return len(set(relevant) & set(retrieved[:k])) / k
+
+ def recall_at_k(relevant, retrieved, k):
+     return len(set(relevant) & set(retrieved[:k])) / len(relevant) if relevant else 0
+
+ def f1_at_k(relevant, retrieved, k):
+     p = precision_at_k(relevant, retrieved, k)
+     r = recall_at_k(relevant, retrieved, k)
+     return 2*p*r / (p+r) if (p+r) > 0 else 0
+
+ def mrr(relevant, retrieved):
+     for i, r in enumerate(retrieved):
+         if r in relevant:
+             return 1 / (i+1)
+     return 0
+
+ def ndcg_at_k(relevant, retrieved, k):
+     dcg = 0
+     for i, r in enumerate(retrieved[:k]):
+         if r in relevant:
+             dcg += 1 / np.log2(i+2)
+     ideal_dcg = sum(1 / np.log2(i+2) for i in range(min(len(relevant), k)))
+     return dcg / ideal_dcg if ideal_dcg > 0 else 0
+
+ # ---------- Search ----------
+ def semantic_search(query, top_k=10):
+     query_embedding = model.encode(query, convert_to_tensor=True)
+     scores = util.cos_sim(query_embedding, passage_embeddings)[0]
+     top_results = scores.topk(k=top_k)
+     retrieved = [int(idx) for idx in top_results[1]]
+     results = [(passages[idx], float(scores[idx])) for idx in retrieved]
+     return results, retrieved
+
+ # ---------- Interface Logic ----------
+ def search_and_evaluate(query):
+     results, retrieved = semantic_search(query, top_k=10)
+
+     # Example: assume top-3 are relevant (for demo purposes)
+     relevant = set(retrieved[:3])
+
+     metrics = {
+         "Precision@10": precision_at_k(relevant, retrieved, 10),
+         "Recall@10": recall_at_k(relevant, retrieved, 10),
+         "F1@10": f1_at_k(relevant, retrieved, 10),
+         "MRR": mrr(relevant, retrieved),
+         "nDCG@10": ndcg_at_k(relevant, retrieved, 10),
      }

- # ----------------------------
- # Gradio interface
- # ----------------------------
- def gradio_interface(query, relevant_texts):
-     results = search(query, k=10)
-     metrics = {}
-     if relevant_texts.strip():
-         relevant_passages = [t.strip() for t in relevant_texts.split("\n") if t.strip()]
-         metrics = evaluate(query, relevant_passages, k=10)
-     return results, metrics
-
- demo = gr.Interface(
-     fn=gradio_interface,
-     inputs=[
-         gr.Textbox(label="Enter your query"),
-         gr.Textbox(label="Enter relevant passages (ground truth, one per line)", placeholder="Optional")
-     ],
-     outputs=[
-         gr.Dataframe(headers=["Passage", "Distance"], label="Top-10 Results"),
-         gr.Label(label="Evaluation Metrics")
-     ],
-     title="SBERT + FAISS Semantic Search",
-     description="Enter a query to search MS MARCO passages. Optionally provide ground truth passages to compute IR metrics."
+     output_text = "### Search Results:\n"
+     for i, (text, score) in enumerate(results, 1):
+         output_text += f"{i}. {text} (score: {score:.4f})\n\n"
+
+     output_text += "\n### Evaluation Metrics:\n"
+     for k, v in metrics.items():
+         output_text += f"{k}: {v:.4f}\n"
+
+     return output_text
+
+ # ---------- Gradio App ----------
+ iface = gr.Interface(
+     fn=search_and_evaluate,
+     inputs=gr.Textbox(label="Enter your query"),
+     outputs=gr.Textbox(label="Results + Metrics"),
+     title="SBERT Semantic Search + Evaluation Metrics",
+     description="Semantic search on MS MARCO (10,000 sample passages) using all-mpnet-base-v2. Includes Precision@10, Recall@10, F1, MRR, nDCG@10."
  )

  if __name__ == "__main__":
-     demo.launch()
+     iface.launch()
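
Note: because search_and_evaluate hard-codes the top-3 results as relevant, the reported metrics are constants regardless of the query (Precision@10 = 0.3, Recall@10 = 1.0, F1@10 ≈ 0.462, MRR = 1.0, nDCG@10 = 1.0). A minimal standalone sketch that makes this visible; the helper definitions are copied from the updated app.py, and the toy ranking is hypothetical:

import numpy as np

def precision_at_k(relevant, retrieved, k):
    return len(set(relevant) & set(retrieved[:k])) / k

def mrr(relevant, retrieved):
    for i, r in enumerate(retrieved):
        if r in relevant:
            return 1 / (i+1)
    return 0

def ndcg_at_k(relevant, retrieved, k):
    # DCG credits a hit at 0-based rank i with 1/log2(i+2); the ideal DCG
    # assumes all relevant items occupy the top ranks.
    dcg = sum(1 / np.log2(i+2) for i, r in enumerate(retrieved[:k]) if r in relevant)
    ideal_dcg = sum(1 / np.log2(i+2) for i in range(min(len(relevant), k)))
    return dcg / ideal_dcg if ideal_dcg > 0 else 0

retrieved = list(range(10))    # hypothetical top-10 ranking (any ids work)
relevant = set(retrieved[:3])  # the commit's demo assumption
print(precision_at_k(relevant, retrieved, 10))  # 0.3
print(mrr(relevant, retrieved))                 # 1.0
print(ndcg_at_k(relevant, retrieved, 10))       # 1.0

With real ground-truth labels, relevant would instead come from annotations or user input, as the removed evaluate() did with its relevant_passages argument.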