Xlordo committed
Commit c882be2 · verified · 1 Parent(s): b4f6db8

Update app.py

Files changed (1)
  1. app.py +32 -18
app.py CHANGED
@@ -3,17 +3,28 @@ from datasets import load_dataset
 from sentence_transformers import SentenceTransformer, util
 import numpy as np
 
-# Load SBERT model
+# ---------- Load model ----------
 model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
 
-# ✅ Load dataset with passages
+# ---------- Load MS MARCO dataset ----------
+# 10k sample passages
 dataset = load_dataset("sentence-transformers/msmarco", "v1.1", split="train[:10000]")
 passages = dataset["passage"]
 
-# Encode passages once for efficiency
+# Precompute embeddings
 passage_embeddings = model.encode(passages, convert_to_tensor=True)
 
-# ---------- Evaluation Metrics ----------
+# Map index -> passage
+id_to_passage = {i: passages[i] for i in range(len(passages))}
+
+# ---------- Load queries and qrels ----------
+queries_dataset = load_dataset("sentence-transformers/msmarco", "v1.1", split="validation[:500]")  # small sample
+qrels_dataset = load_dataset("ms_marco", "v1.1", split="validation[:500]")  # contains relevant passage ids
+
+query_id_to_text = {i: q["query"] for i, q in enumerate(queries_dataset)}
+query_id_to_relevant = {i: set(q["positive_passages"]) for i, q in enumerate(qrels_dataset)}
+
+# ---------- Evaluation metrics ----------
 def precision_at_k(relevant, retrieved, k):
     return len(set(relevant) & set(retrieved[:k])) / k
 
@@ -44,23 +55,27 @@ def semantic_search(query, top_k=10):
     query_embedding = model.encode(query, convert_to_tensor=True)
     scores = util.cos_sim(query_embedding, passage_embeddings)[0]
     top_results = scores.topk(k=top_k)
-    retrieved = [int(idx) for idx in top_results[1]]
-    results = [(passages[idx], float(scores[idx])) for idx in retrieved]
-    return results, retrieved
+    retrieved_indices = [int(idx) for idx in top_results[1]]
+    results = [(id_to_passage[idx], float(scores[idx])) for idx in retrieved_indices]
+    return results, retrieved_indices
 
-# ---------- Interface Logic ----------
+# ---------- Gradio interface ----------
 def search_and_evaluate(query):
-    results, retrieved = semantic_search(query, top_k=10)
+    results, retrieved_indices = semantic_search(query, top_k=10)
 
-    # Example: assume top-3 are relevant (for demo purposes)
-    relevant = set(retrieved[:3])
+    # Match against actual relevant passages if available
+    relevant_indices = set()
+    for i, q in query_id_to_text.items():
+        if q.strip().lower() == query.strip().lower():
+            relevant_indices = query_id_to_relevant[i]
+            break
 
     metrics = {
-        "Precision@10": precision_at_k(relevant, retrieved, 10),
-        "Recall@10": recall_at_k(relevant, retrieved, 10),
-        "F1@10": f1_at_k(relevant, retrieved, 10),
-        "MRR": mrr(relevant, retrieved),
-        "nDCG@10": ndcg_at_k(relevant, retrieved, 10),
+        "Precision@10": precision_at_k(relevant_indices, retrieved_indices, 10),
+        "Recall@10": recall_at_k(relevant_indices, retrieved_indices, 10),
+        "F1@10": f1_at_k(relevant_indices, retrieved_indices, 10),
+        "MRR": mrr(relevant_indices, retrieved_indices),
+        "nDCG@10": ndcg_at_k(relevant_indices, retrieved_indices, 10)
    }
 
     output_text = "### Search Results:\n"
@@ -73,13 +88,12 @@ def search_and_evaluate(query):
 
     return output_text
 
-# ---------- Gradio App ----------
 iface = gr.Interface(
     fn=search_and_evaluate,
     inputs=gr.Textbox(label="Enter your query"),
     outputs=gr.Textbox(label="Results + Metrics"),
     title="SBERT Semantic Search + Evaluation Metrics",
-    description="Semantic search on MS MARCO (10,000 sample passages) using all-mpnet-base-v2. Includes Precision@10, Recall@10, F1, MRR, nDCG@10."
+    description="Semantic search on MS MARCO (10,000 sample passages) using all-mpnet-base-v2 with true evaluation metrics."
 )
 
 if __name__ == "__main__":
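Note: the hunks above call recall_at_k, f1_at_k, mrr, and ndcg_at_k, which live in unchanged parts of app.py and are therefore not shown in this diff. Their actual implementations are not visible here; the following is a minimal sketch of standard binary-relevance definitions, written in the same style as the precision_at_k shown above (relevant is a set of passage indices, retrieved is a ranked list). The bodies are assumptions for illustration, not the file's code.

import numpy as np

# Hypothetical sketches of the helpers referenced in the diff.
def recall_at_k(relevant, retrieved, k):
    # Fraction of the relevant items that appear in the top-k results.
    if not relevant:
        return 0.0
    return len(set(relevant) & set(retrieved[:k])) / len(relevant)

def f1_at_k(relevant, retrieved, k):
    # Harmonic mean of precision@k and recall@k (reuses precision_at_k from app.py).
    p = precision_at_k(relevant, retrieved, k)
    r = recall_at_k(relevant, retrieved, k)
    return 2 * p * r / (p + r) if (p + r) > 0 else 0.0

def mrr(relevant, retrieved):
    # Reciprocal rank of the first relevant item; 0 if none is retrieved.
    for rank, idx in enumerate(retrieved, start=1):
        if idx in relevant:
            return 1.0 / rank
    return 0.0

def ndcg_at_k(relevant, retrieved, k):
    # Binary-relevance nDCG: DCG of the ranking divided by the ideal DCG.
    dcg = sum(1.0 / np.log2(rank + 1)
              for rank, idx in enumerate(retrieved[:k], start=1)
              if idx in relevant)
    ideal_dcg = sum(1.0 / np.log2(rank + 1)
                    for rank in range(1, min(len(relevant), k) + 1))
    return dcg / ideal_dcg if ideal_dcg > 0 else 0.0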