Spaces:
Runtime error
Runtime error
| import json | |
| from pyserini.search.lucene import LuceneSearcher | |
| from tqdm import tqdm | |
def convert_unicode_to_normal(data):
    """Round-trip text (or a list of texts) through UTF-8 encode/decode.

    For any ``str`` that is encodable as UTF-8 the result compares equal to
    the input; the only observable effect of the round trip is that text
    which cannot be encoded as UTF-8 (e.g. containing lone surrogates)
    raises ``UnicodeEncodeError`` here instead of later.

    Args:
        data: A ``str``, or a ``list`` of ``str`` (an empty list is allowed).

    Returns:
        The round-tripped string, or a new list of round-tripped strings.

    Raises:
        ValueError: If ``data`` is neither a ``str`` nor a list of ``str``.
        UnicodeEncodeError: If some text cannot be encoded as UTF-8.
    """
    if isinstance(data, str):
        return data.encode('utf-8').decode('utf-8')
    if isinstance(data, list):
        # Check every element (not only data[0]) with a real exception so the
        # validation survives `python -O` and an empty list does not IndexError.
        if not all(isinstance(sample, str) for sample in data):
            raise ValueError("expected a list of str")
        return [sample.encode('utf-8').decode('utf-8') for sample in data]
    raise ValueError(f"expected str or list of str, got {type(data).__name__}")
# --- Experiment configuration -------------------------------------------------
# Depth of the BM25 run: selects the run file below and labels the printed summary.
K: int = 30

# Local Lucene index that passages are fetched from (by pid).
index_dir: str = "/root/indexes/index-wikipedia-dpr-20210120"

# BM25 run file; each line holds six space-separated fields
# (qid q_type pid k score engine).
runfile_path: str = f"runs/q=NQtest_c=wikidpr_m=bm25_k={K}.run"

# Test questions with their answer strings, one JSON object per line.
qafile_path: str = (
    "/root/nota-fairseq/examples/information_retrieval/"
    "open_domain_data/NQ/qa_pairs/test.jsonl"
)

# Output JSONL: one record per retrieved passage that contains an answer.
logging_path: str = "logging_q=NQ_c=wiki_including_ans.jsonl"
# Open the local Lucene index at `index_dir`; it is used below only to look up
# stored passages by their pid (not to run queries).
searcher = LuceneSearcher(
    index_dir=index_dir,
)
# Read the whole QA file first and index it by qid. The run file is ordered by
# query id, so it cannot be walked in lockstep with the QA file.
print("read qa file")
pair_by_qid = {}  # qid (str) -> {'query': str, 'answers': list[str]}
with open(qafile_path, 'r', encoding='utf-8') as fr_qa:
    for pair in tqdm(fr_qa):
        if not pair.strip():
            continue  # tolerate blank lines, e.g. a trailing newline at EOF
        pair_data = json.loads(pair)
        qid, query, answers = pair_data["qid"], pair_data["query"], pair_data["answers"]  # str, str, list
        # Run-file qids are split out of text lines and are always str; coerce
        # here so lookups still match if the QA file stores qid as a number.
        qid = str(qid)
        pair_by_qid[qid] = {'query': query, 'answers': answers}
# For every (qid, pid) line of the run file, fetch the passage text and check
# whether any gold answer occurs in it (exact, case-sensitive substring match).
# Matching passages are logged; a qid counts as a hit if ANY of its retrieved
# passages matches. Recall = hits / all QA qids.
print("check retrieved passage include answer")
qid_with_ans_in_retrieval = []  # qids with >=1 answer-bearing passage, in first-hit order
_seen_qids = set()              # mirror of the list for O(1) membership tests
# UTF-8 on both files: the log is written with ensure_ascii=False, which fails
# under a non-UTF-8 locale default encoding.
with open(runfile_path, 'r', encoding='utf-8') as fr_run, \
        open(logging_path, 'w', encoding='utf-8') as fw_log:
    for result in tqdm(fr_run):
        fields = result.split()  # whitespace split; also drops the trailing '\n'
        if not fields:
            continue  # skip blank lines
        # Expected layout: qid q_type pid k score engine
        if len(fields) != 6:
            raise ValueError(f"malformed run line (expected 6 fields): {result!r}")
        qid_, pid = fields[0], fields[2]
        if qid_ not in pair_by_qid:
            raise KeyError(f"qid {qid_!r} from run file is missing from the QA file")
        query, answers = pair_by_qid[qid_]['query'], pair_by_qid[qid_]['answers']

        # Fetch the stored document; its raw form is JSON with a 'contents' field.
        doc = searcher.doc(pid)
        if doc is None:
            raise KeyError(f"pid {pid!r} from run file is not in the index")
        psg_txt = json.loads(doc.raw())['contents'].strip()
        psg_txt = convert_unicode_to_normal(psg_txt)

        # Log only the first answer found; one match is enough for this passage.
        ans = next((a for a in answers if a in psg_txt), None)
        if ans is not None:
            log_w = {
                "qid": qid_,
                "pid": pid,
                "query": query,
                "answer": ans,
                "passage": psg_txt
            }
            fw_log.write(json.dumps(log_w, ensure_ascii=False) + '\n')
            if qid_ not in _seen_qids:
                _seen_qids.add(qid_)
                qid_with_ans_in_retrieval.append(qid_)

num_queries = len(pair_by_qid)
num_hits = len(qid_with_ans_in_retrieval)
# Same expression order as before so the printed value is unchanged; guard the
# empty-QA-file case instead of raising ZeroDivisionError.
recall = num_hits / num_queries * 100 if num_queries else float('nan')
print(f"#qid in test set: {num_queries}, #qid having answer with retrieval(BM25, K={K}): {num_hits}, Recall = {recall}")
# v1 (removed): iterated the QA file and read exactly K lines from the run file
# per question, asserting that the run-file qid matched. That only works when
# both files list queries in the same order, which the run file's qid ordering
# does not guarantee; the current code instead indexes QA pairs by qid first.
# The old code was kept as a no-op string literal and still contained a
# leftover `pdb.set_trace()` breakpoint, so it has been dropped.