Safetensors
daviddongdong committed on
Commit
61fc49d
·
verified ·
1 Parent(s): 0a463e4

Update text_wrapper.py

Browse files
Files changed (1) hide show
  1. text_wrapper.py +81 -41
text_wrapper.py CHANGED
@@ -2,6 +2,15 @@ import torch
2
  import numpy as np
3
  from tqdm import tqdm
4
 
 
 
 
 
 
 
 
 
 
5
  class Sent_Retriever:
6
  def __init__(self, bs=256, use_gpu=True):
7
  self.bs = bs
@@ -19,8 +28,20 @@ class Sent_Retriever:
19
  return embeddings
20
 
21
  def score(self, queries, quotes):
22
- query_emb = np.asarray(self.embed_queries(queries))
23
- quote_emb = np.asarray(self.embed_quotes(quotes))
 
 
 
 
 
 
 
 
 
 
 
 
24
  return (query_emb @ quote_emb.T).tolist()
25
 
26
  def get_tok_len(self, text_input):
@@ -93,11 +114,10 @@ class GTE(Sent_Retriever):
93
  return self.embed_passages(quotes)
94
 
95
 
96
-
97
  class Contriever():
98
- def __init__(self, bs = 256, use_gpu= True):
99
  from transformers import AutoTokenizer, AutoModel
100
- self.model_path = 'checkpoint/contriever-msmarco'
101
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
102
  self.model = AutoModel.from_pretrained(self.model_path)
103
  self.bs = bs
@@ -133,21 +153,32 @@ class Contriever():
133
  quote_embeddings.extend([q.cpu().detach().numpy() for q in batched_quote_embs])
134
  return quote_embeddings
135
 
136
- def score(self, query, quotes):
137
- query_emb = np.asarray(self.embed_queries(query))
138
- quote_emb = np.asarray(self.embed_quotes(quotes))
139
- scores = (query_emb @ quote_emb.T).tolist()
140
- return scores
 
 
 
 
 
 
 
 
 
 
 
141
 
142
 
143
  class DPR():
144
- def __init__(self, bs = 256, use_gpu= True):
145
  from transformers import DPRContextEncoder, DPRContextEncoderTokenizer, DPRQuestionEncoder, DPRQuestionEncoderTokenizer
146
- self.model_path = "checkpoint/"
147
- self.query_tok = DPRQuestionEncoderTokenizer.from_pretrained(self.model_path +"dpr-question_encoder-multiset-base")
148
- self.query_enc = DPRQuestionEncoder.from_pretrained(self.model_path +"dpr-question_encoder-multiset-base")
149
- self.ctx_tok = DPRContextEncoderTokenizer.from_pretrained(self.model_path +"dpr-ctx_encoder-multiset-base")
150
- self.ctx_enc = DPRContextEncoder.from_pretrained(self.model_path +"dpr-ctx_encoder-multiset-base")
151
  self.bs = bs
152
  print("[text_wrapper.py - init] Setting up DPR...")
153
  print("[text_wrapper.py - init] DPR is loaded from '{}'...".format( self.model_path ))
@@ -187,19 +218,30 @@ class DPR():
187
  quote_embeddings.extend(quote_emb)
188
  return quote_embeddings
189
 
190
- def score(self, query, quotes):
191
- query_emb = np.asarray(self.embed_queries(query))
192
- quote_emb = np.asarray(self.embed_quotes(quotes))
193
- scores = (query_emb @ quote_emb.T).tolist()
194
- return scores
 
 
 
 
 
 
 
 
 
 
 
195
 
196
 
197
  class ColBERTReranker:
198
- def __init__(self, bs = 256, use_gpu= True):
199
  from colbert.modeling.colbert import ColBERT
200
  from colbert.infra import ColBERTConfig
201
  from transformers import AutoTokenizer
202
- self.model_path = "checkpoint/colbertv2.0"
203
  self.bs = bs
204
  config = ColBERTConfig(bsize=bs, root='./', query_token_id='[Q]', doc_token_id='[D]')
205
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
@@ -231,8 +273,6 @@ class ColBERTReranker:
231
  length = mask.sum().item() # Number of true tokens in this sequence
232
  np_emb = emb[:length].cpu().numpy() # Shape: [L, H]
233
  query_embeddings.append(np_emb) # `L` varies per example
234
-
235
- # torch.cuda.empty_cache()
236
  return query_embeddings
237
 
238
  @staticmethod
@@ -274,32 +314,32 @@ class ColBERTReranker:
274
  return quote_embeddings, quote_masks
275
  return quote_embeddings
276
 
277
-
278
  @staticmethod
279
  def colbert_score(query_embed, quote_embeddings, quote_masks):
280
  Q, H = query_embed.shape # [Q, H]
281
  N, L, _ = quote_embeddings.shape # [N, L, H]
282
- # 1. Compute [Q, N, L] (similarity btw every query token to every quote token)
283
- # Expand query to [Q, 1, 1, H], quote_embeddings to [1, N, L, H]
284
  query_expanded = query_embed[:, np.newaxis, np.newaxis, :] # [Q, 1, 1, H]
285
  quote_expanded = quote_embeddings[np.newaxis, :, :, :] # [1, N, L, H]
286
  sim = np.matmul(query_expanded, np.transpose(quote_expanded, (0 ,1 ,3 ,2))) # (Q, N, 1, L)
287
- # But let's use broadcasting for dot product:
288
- # sim[q, n, l] = np.dot(query_embed[q], quote_embeddings[n,l])
289
  sim = np.einsum('qh,nlh->qnl', query_embed, quote_embeddings) # [Q, N, L]
290
- # 2. Mask invalid tokens
291
- sim = np.where(quote_masks[np.newaxis, :, : ]==1, sim, -1e9) # [Q, N, L]
292
- # 3. MaxSim: For each query token, take max over quote tokens (L dimension)
293
- maxsim = sim.max(-1) # [Q, N]
294
- # 4. Aggregate (sum over query tokens)
295
- scores = maxsim.sum(axis=0) # [N]
296
  return scores
297
 
298
- def score(self, query, quotes):
299
- query_embeddings = self.embed_queries(query)
300
- quote_embeddings, quote_masks = self.embed_quotes(quotes, pad_token_len=True)
 
 
 
 
 
 
 
 
301
  scores_list = []
302
- for query_embed in query_embeddings:
303
- scores = self.colbert_score(query_embed, quote_embeddings, quote_masks)
304
  scores_list.append(scores.tolist())
305
  return scores_list
 
2
  import numpy as np
3
  from tqdm import tqdm
4
 
5
def is_str_list(obj):
    """Return True when ``obj`` is a ``list`` whose items are all ``str``.

    An empty list also returns True, since there is no non-string item in it.
    """
    if not isinstance(obj, list):
        return False
    for element in obj:
        if not isinstance(element, str):
            return False
    return True
7
+
8
def is_np_list(obj):
    """Return True when ``obj`` is a ``list`` whose items are all ``np.ndarray``.

    An empty list also returns True, since there is no non-array item in it.
    """
    if not isinstance(obj, list):
        return False
    for element in obj:
        if not isinstance(element, np.ndarray):
            return False
    return True
10
+
11
def is_np_array(obj):
    """Return True when ``obj`` is itself an ``np.ndarray`` (or a subclass instance)."""
    array_type = np.ndarray
    return isinstance(obj, array_type)
13
+
14
  class Sent_Retriever:
15
  def __init__(self, bs=256, use_gpu=True):
16
  self.bs = bs
 
28
  return embeddings
29
 
30
  def score(self, queries, quotes):
31
+ if is_str_list(queries):
32
+ query_emb = np.asarray(self.embed_queries(queries))
33
+ elif is_np_list(queries):
34
+ query_emb = np.asarray(queries)
35
+ elif is_np_array(queries):
36
+ query_emb = queries
37
+
38
+ if is_str_list(quotes):
39
+ quote_emb = np.asarray(self.embed_quotes(quotes))
40
+ elif is_np_list(quotes):
41
+ quote_emb = np.asarray(quotes)
42
+ elif is_np_array(quotes):
43
+ quote_emb = quotes
44
+
45
  return (query_emb @ quote_emb.T).tolist()
46
 
47
  def get_tok_len(self, text_input):
 
114
  return self.embed_passages(quotes)
115
 
116
 
 
117
  class Contriever():
118
+ def __init__(self, bs = 256, use_gpu= True, model_path='checkpoint/contriever-msmarco'):
119
  from transformers import AutoTokenizer, AutoModel
120
+ self.model_path = model_path
121
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
122
  self.model = AutoModel.from_pretrained(self.model_path)
123
  self.bs = bs
 
153
  quote_embeddings.extend([q.cpu().detach().numpy() for q in batched_quote_embs])
154
  return quote_embeddings
155
 
156
def score(self, queries, quotes):
    """Score every query against every quote with a dot product.

    Each argument may be given in one of three forms:
      * a list of strings, embedded with ``self.embed_queries`` /
        ``self.embed_quotes``;
      * a list of ``np.ndarray`` embeddings, stacked with ``np.asarray``;
      * an ``np.ndarray`` of embeddings, used as-is.

    Returns:
        ``(query_emb @ quote_emb.T).tolist()``.

    Raises:
        TypeError: if an argument is in none of the accepted forms.
            Without this check the embedding variable would stay unbound
            and the final expression would fail with an unrelated
            UnboundLocalError.
    """
    if is_str_list(queries):
        query_emb = np.asarray(self.embed_queries(queries))
    elif is_np_list(queries):
        query_emb = np.asarray(queries)
    elif is_np_array(queries):
        query_emb = queries
    else:
        raise TypeError(
            "queries must be a list of str, a list of np.ndarray, or an "
            "np.ndarray; got {}".format(type(queries).__name__)
        )

    if is_str_list(quotes):
        quote_emb = np.asarray(self.embed_quotes(quotes))
    elif is_np_list(quotes):
        quote_emb = np.asarray(quotes)
    elif is_np_array(quotes):
        quote_emb = quotes
    else:
        raise TypeError(
            "quotes must be a list of str, a list of np.ndarray, or an "
            "np.ndarray; got {}".format(type(quotes).__name__)
        )

    return (query_emb @ quote_emb.T).tolist()
172
 
173
 
174
  class DPR():
175
+ def __init__(self, bs = 256, use_gpu=True, model_path="checkpoint"):
176
  from transformers import DPRContextEncoder, DPRContextEncoderTokenizer, DPRQuestionEncoder, DPRQuestionEncoderTokenizer
177
+ self.model_path = model_path
178
+ self.query_tok = DPRQuestionEncoderTokenizer.from_pretrained(self.model_path +"/dpr-question_encoder-multiset-base")
179
+ self.query_enc = DPRQuestionEncoder.from_pretrained(self.model_path +"/dpr-question_encoder-multiset-base")
180
+ self.ctx_tok = DPRContextEncoderTokenizer.from_pretrained(self.model_path +"/dpr-ctx_encoder-multiset-base")
181
+ self.ctx_enc = DPRContextEncoder.from_pretrained(self.model_path +"/dpr-ctx_encoder-multiset-base")
182
  self.bs = bs
183
  print("[text_wrapper.py - init] Setting up DPR...")
184
  print("[text_wrapper.py - init] DPR is loaded from '{}'...".format( self.model_path ))
 
218
  quote_embeddings.extend(quote_emb)
219
  return quote_embeddings
220
 
221
def score(self, queries, quotes):
    """Score every query against every quote with a dot product.

    Each argument may be given in one of three forms:
      * a list of strings, embedded with ``self.embed_queries`` /
        ``self.embed_quotes``;
      * a list of ``np.ndarray`` embeddings, stacked with ``np.asarray``;
      * an ``np.ndarray`` of embeddings, used as-is.

    Returns:
        ``(query_emb @ quote_emb.T).tolist()``.

    Raises:
        TypeError: if an argument is in none of the accepted forms.
            Without this check the embedding variable would stay unbound
            and the final expression would fail with an unrelated
            UnboundLocalError.
    """
    if is_str_list(queries):
        query_emb = np.asarray(self.embed_queries(queries))
    elif is_np_list(queries):
        query_emb = np.asarray(queries)
    elif is_np_array(queries):
        query_emb = queries
    else:
        raise TypeError(
            "queries must be a list of str, a list of np.ndarray, or an "
            "np.ndarray; got {}".format(type(queries).__name__)
        )

    if is_str_list(quotes):
        quote_emb = np.asarray(self.embed_quotes(quotes))
    elif is_np_list(quotes):
        quote_emb = np.asarray(quotes)
    elif is_np_array(quotes):
        quote_emb = quotes
    else:
        raise TypeError(
            "quotes must be a list of str, a list of np.ndarray, or an "
            "np.ndarray; got {}".format(type(quotes).__name__)
        )

    return (query_emb @ quote_emb.T).tolist()
237
 
238
 
239
  class ColBERTReranker:
240
+ def __init__(self, bs = 256, use_gpu= True, model_path="checkpoint/colbertv2.0"):
241
  from colbert.modeling.colbert import ColBERT
242
  from colbert.infra import ColBERTConfig
243
  from transformers import AutoTokenizer
244
+ self.model_path = model_path
245
  self.bs = bs
246
  config = ColBERTConfig(bsize=bs, root='./', query_token_id='[Q]', doc_token_id='[D]')
247
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
 
273
  length = mask.sum().item() # Number of true tokens in this sequence
274
  np_emb = emb[:length].cpu().numpy() # Shape: [L, H]
275
  query_embeddings.append(np_emb) # `L` varies per example
 
 
276
  return query_embeddings
277
 
278
  @staticmethod
 
314
  return quote_embeddings, quote_masks
315
  return quote_embeddings
316
 
 
317
  @staticmethod
318
  def colbert_score(query_embed, quote_embeddings, quote_masks):
319
  Q, H = query_embed.shape # [Q, H]
320
  N, L, _ = quote_embeddings.shape # [N, L, H]
 
 
321
  query_expanded = query_embed[:, np.newaxis, np.newaxis, :] # [Q, 1, 1, H]
322
  quote_expanded = quote_embeddings[np.newaxis, :, :, :] # [1, N, L, H]
323
  sim = np.matmul(query_expanded, np.transpose(quote_expanded, (0 ,1 ,3 ,2))) # (Q, N, 1, L)
 
 
324
  sim = np.einsum('qh,nlh->qnl', query_embed, quote_embeddings) # [Q, N, L]
325
+ sim = np.where(quote_masks[np.newaxis, :, : ]==1, sim, -1e9) # Mask invalid tokens [Q, N, L]
326
+ maxsim = sim.max(-1) # MaxSim: For each query token, take max over quote tokens [Q, N]
327
+ scores = maxsim.sum(axis=0) # Aggregate (sum over query tokens) [N]
 
 
 
328
  return scores
329
 
330
def score(self, queries, quotes):
    """Score each query against all quotes with ``colbert_score``.

    Args:
        queries: a list of strings (embedded with ``self.embed_queries``)
            or a list of ``np.ndarray`` per-query token embeddings, used
            as-is.
        quotes: a list of strings (embedded with
            ``self.embed_quotes(..., pad_token_len=True)``) or a list of
            ``np.ndarray`` per-quote token embeddings, passed to
            ``self.pad_tok_len`` to obtain embeddings and masks.

    Returns:
        A list with one entry per query; each entry is
        ``colbert_score(...).tolist()`` for that query.

    Raises:
        TypeError: if an argument is in neither accepted form. Without
            this check the embedding variables would stay unbound and the
            scoring loop would fail with an unrelated UnboundLocalError.
    """
    if is_str_list(queries):
        query_embed = self.embed_queries(queries)
    elif is_np_list(queries):
        query_embed = queries
    else:
        raise TypeError(
            "queries must be a list of str or a list of np.ndarray; "
            "got {}".format(type(queries).__name__)
        )

    if is_str_list(quotes):
        quote_embed, quote_masks = self.embed_quotes(quotes, pad_token_len=True)
    elif is_np_list(quotes):
        quote_embed, quote_masks = self.pad_tok_len(quotes)
    else:
        raise TypeError(
            "quotes must be a list of str or a list of np.ndarray; "
            "got {}".format(type(quotes).__name__)
        )

    return [
        self.colbert_score(q_embed, quote_embed, quote_masks).tolist()
        for q_embed in query_embed
    ]