Cudd1es commited on
Commit
900e88e
·
verified ·
1 Parent(s): 674e7e8

Upload 3 files

Browse files
Files changed (3) hide show
  1. embedder.py +241 -0
  2. llm_agent.py +56 -0
  3. retriever.py +174 -0
embedder.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ from sentence_transformers.sparse_encoder import SparseEncoder
3
+ import torch
4
+ import os
5
+ import re
6
+ from tqdm import tqdm
7
+ import numpy as np
8
+ import chromadb
9
+ import json
10
+ import time
11
+
12
+
13
+ """
14
+ Currently only dense encoding is performed.
15
+ The sparse-encoding functions are placeholders.
16
+ """
17
+
18
# Model identifiers handed to the encoder classes in the script entry point.
# Dense model id (used to build DenseTextEncoder).
DENSE_EMBEDDER_MODEL = "BAAI/bge-base-zh-v1.5"
# Sparse model id (used to build SparseTextEncoder).
SPARSE_EMBEDDER_MODEL = "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill"
20
+
21
+
22
class TextCleaner:
    """Light-weight text normalizer applied before encoding.

    :param lowercase: lower-case the text when True
    :param remove_urls: delete http/https URLs when True
    :param normalize_space: collapse each whitespace run into one space when True
    """

    # Only ASCII URL characters are consumed, so CJK text (or full-width
    # punctuation) that directly follows a URL without a space is preserved.
    _URL_PATTERN = re.compile(r"https?://[A-Za-z0-9\-._~:/?#\[\]@!$&'()*+,;=%]+")
    _SPACE_PATTERN = re.compile(r"\s+")

    def __init__(self, lowercase=False, remove_urls=True, normalize_space=True):
        self.lowercase = lowercase
        self.remove_urls = remove_urls
        self.normalize_space = normalize_space

    def clean(self, text: str) -> str:
        """Return *text* with the configured normalizations applied.

        :param text: raw input string
        :return: cleaned string without leading or trailing whitespace
        """
        if self.lowercase:
            text = text.lower()
        if self.remove_urls:
            text = self._URL_PATTERN.sub("", text)
        if self.normalize_space:
            text = self._SPACE_PATTERN.sub(" ", text)
        # Strip last: removing a URL can leave whitespace at either end.
        return text.strip()
37
+
38
+
39
def join_chunk_text(text_chunk):
    """Collapse a chunk given as a list of sentences into one string.

    :param text_chunk: a list of sentences, or any other value
    :return: the sentences joined with newlines when a list is given;
        otherwise the argument itself, untouched
    """
    if not isinstance(text_chunk, list):
        return text_chunk
    return "\n".join(text_chunk)
44
+
45
+
46
class DenseTextEncoder:
    """Dense text encoder built on ``SentenceTransformer``.

    Both encode methods request numpy output (``convert_to_numpy=True``).
    """

    def __init__(self, model_name, normalize=True, device=None):
        """
        :param model_name: model id or path handed to ``SentenceTransformer``
        :param normalize: forwarded as ``normalize_embeddings`` on every encode call
        :param device: device string; defaults to "cuda" when available, else "cpu"
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Hand the resolved device to the model so an explicitly requested
        # device is honoured instead of only being stored on the instance.
        self.model = SentenceTransformer(model_name, device=self.device)
        self.normalize = normalize
        self.cleaner = TextCleaner()

    def _prepare_texts(self, texts):
        """Normalize the input to a list of cleaned strings.

        :param texts: a single string, a list[str], or a list[list[str]]
            (each inner list is joined into one text)
        :return: list of cleaned strings, one per input text
        :raises ValueError: if *texts* is neither str nor list, or if a list
            mixes strings with lists
        """
        if isinstance(texts, str):
            items = [texts]
        elif isinstance(texts, list):
            if all(isinstance(t, str) for t in texts):
                items = texts
            elif all(isinstance(t, list) for t in texts):
                items = [join_chunk_text(t) for t in texts]
            else:
                raise ValueError("Input list must contain only str or list[str].")
        else:
            raise ValueError("Input must be str or list.")
        return [self.cleaner.clean(t) for t in items]

    def encode_document(self, texts):
        """Clean *texts* and encode them with the model's ``encode_document``."""
        cleaned = self._prepare_texts(texts)
        return self.model.encode_document(
            cleaned, convert_to_numpy=True, normalize_embeddings=self.normalize
        )

    def encode_query(self, texts):
        """Clean *texts* and encode them with the model's ``encode_query``."""
        cleaned = self._prepare_texts(texts)
        return self.model.encode_query(
            cleaned, convert_to_numpy=True, normalize_embeddings=self.normalize
        )
81
+
82
+
83
class SparseTextEncoder:
    """Sparse text encoder wrapping ``SparseEncoder``.

    NOTE(review): the original note states the output is a torch tensor;
    that depends on ``SparseEncoder`` and is not verified here.
    """

    def __init__(self, model_name, device=None):
        """
        :param model_name: model id or path handed to ``SparseEncoder``
        :param device: device string; falls back to "cpu" when not given
        """
        self.device = device if device else "cpu"
        self.encoder = SparseEncoder(model_name, device=self.device)
        self.cleaner = TextCleaner()

    def _prepare_texts(self, texts):
        """Turn a str, list[str] or list[list[str]] into a list of cleaned strings."""
        if isinstance(texts, str):
            return [self.cleaner.clean(texts)]
        if not isinstance(texts, list):
            raise ValueError("Input must be str or list.")

        if all(isinstance(item, str) for item in texts):
            flattened = list(texts)
        elif all(isinstance(item, list) for item in texts):
            # every element is a list of sentences: merge each into one text
            flattened = [join_chunk_text(item) for item in texts]
        else:
            raise ValueError("Input list must contain only str or list[str].")
        return [self.cleaner.clean(item) for item in flattened]

    def encode_document(self, texts):
        """Encode for corpus indexing."""
        return self.encoder.encode_document(self._prepare_texts(texts))

    def encode_query(self, texts):
        """Encode for query retrieval."""
        return self.encoder.encode_query(self._prepare_texts(texts))
117
+
118
def read_input(source):
    """Load input lines from a file, or treat *source* as literal text.

    :param source: path of a UTF-8 text file, or the text itself
    :return: the stripped, non-empty lines of the file when *source* names an
        existing regular file; otherwise ``[source]``
    """
    # isfile rather than exists: a directory that happens to share the name
    # cannot be opened, so it is treated as literal text like any other value.
    if not os.path.isfile(source):
        return [source]
    with open(source, "r", encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        return [line for line in stripped if line]
125
+
126
+
127
def encode_chunks_with_metadata(chunks, dense_encoder, sparse_encoder):
    """Encode a batch of chunks and pair every embedding with its metadata.

    :param chunks: [{'text': [...], 'chunk_id': ..., ...metadata}, ...]
    :param dense_encoder: dense encoder; its ``encode_document`` is called once
        with all chunk texts
    :param sparse_encoder: sparse encoder; currently unused (placeholder)
    :return: one dict per chunk, e.g.
        {
            "chunk_id": "9b28e9938292486e9a61f2d1787bb828",
            "dense_embedding": np.array([...]),
            "sparse_embedding": None,   # placeholder until sparse encoding is enabled
            "text": "友希那: ...\n莉莎: ...",
            "eventName": "连结思绪的未竟之歌",
            "chapterTitle": "序章: 古旧的磁带",
            "story_type": "event",
            "start_idx": ...,
            "end_idx": ...,
        }
        Only the metadata keys listed above are copied from each chunk.
    """
    joined_texts = [join_chunk_text(item["text"]) for item in chunks]
    dense_vectors = dense_encoder.encode_document(joined_texts)

    # Metadata fields carried over verbatim; missing ones become None.
    copied_keys = ("eventName", "chapterTitle", "story_type", "start_idx", "end_idx")

    records = []
    for position, item in enumerate(chunks):
        record = {
            "chunk_id": item.get("chunk_id"),
            "dense_embedding": dense_vectors[position],
            # placeholder: sparse encoding is skipped for now
            "sparse_embedding": None,
            "text": joined_texts[position],
        }
        record.update((key, item.get(key)) for key in copied_keys)
        records.append(record)
    return records
170
+
171
+ # save dense embedding to chroma vector database
172
def save_chunks_to_chroma(embedded_chunk, collection):
    """Add encoded chunks to a Chroma collection, 64 entries per ``add`` call.

    :param embedded_chunk: iterable of dicts carrying "chunk_id", "text" and
        "dense_embedding"; every other key except "sparse_embedding" is
        stored as metadata
    :param collection: object exposing ``add(ids=, documents=, embeddings=,
        metadatas=)`` and a ``name`` attribute
    """
    # Keys with their own column in the collection, plus the sparse embedding,
    # which is currently not stored in Chroma.
    non_metadata_keys = ["chunk_id", "dense_embedding", "sparse_embedding", "text"]

    ids, documents, embeddings, metadata = [], [], [], []
    for entry in embedded_chunk:
        ids.append(entry["chunk_id"])
        documents.append(entry["text"])
        vector = entry["dense_embedding"]
        if isinstance(vector, np.ndarray):
            # convert numpy arrays to plain lists before handing them over
            vector = vector.tolist()
        embeddings.append(vector)
        metadata.append(
            {key: value for key, value in entry.items() if key not in non_metadata_keys}
        )

    batch_size = 64
    for start in range(0, len(ids), batch_size):
        window = slice(start, start + batch_size)
        collection.add(
            ids=ids[window],
            documents=documents[window],
            embeddings=embeddings[window],
            metadatas=metadata[window],
        )
    print(f"saved {len(ids)} chunks to {collection.name}")
197
+
198
def read_jsonl_in_batches(file_path, batch_size=64):
    """Yield lists of parsed JSON records from a JSON Lines file.

    Blank lines are skipped. Every yielded list holds ``batch_size`` records
    except possibly the last one.

    :param file_path: path of the JSONL file (read as UTF-8)
    :param batch_size: maximum number of records per yielded list; must be >= 1
    :raises ValueError: when iteration starts with ``batch_size`` < 1, which
        would otherwise silently put the whole file into a single batch
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer.")

    batch = []
    with open(file_path, "r", encoding="utf8") as f:
        for line in f:
            if not line.strip():
                continue
            batch.append(json.loads(line))
            if len(batch) >= batch_size:
                yield batch
                batch = []
    # flush the final, possibly shorter, batch
    if batch:
        yield batch
209
+
210
+
211
if __name__ == "__main__":
    # Script entry point: encode every chunk file and store the dense
    # embeddings in the persistent Chroma collection.
    chunk_files = [
        "./chunks/band_chunks.jsonl",
        "./chunks/card_chunks.jsonl",
        "./chunks/event_chunks.jsonl",
        "./chunks/main_chunks.jsonl",
    ]

    dense_encoder = DenseTextEncoder(DENSE_EMBEDDER_MODEL)
    sparse_encoder = SparseTextEncoder(SPARSE_EMBEDDER_MODEL)

    # init databases
    chroma_client = chromadb.PersistentClient(path="./chroma_db")
    chroma_collection = chroma_client.get_or_create_collection("bangdream_dense")

    start_time = time.time()
    for file_path in chunk_files:
        # First pass only counts non-blank lines so the progress bar has a total.
        with open(file_path, 'r', encoding='utf8') as f:
            total_lines = sum(1 for line in f if line.strip())
        print(f"\nProcessing {file_path} ({total_lines} chunks)")

        # The context manager closes the bar even when encoding or saving raises.
        with tqdm(
            total=total_lines,
            desc=f"Encoding {os.path.basename(file_path)}",
            unit="chunk",
        ) as pbar:
            for batch in read_jsonl_in_batches(file_path, batch_size=64):
                embedded = encode_chunks_with_metadata(batch, dense_encoder, sparse_encoder)
                save_chunks_to_chroma(embedded, chroma_collection)
                pbar.update(len(batch))
    end_time = time.time()
    print(f"Total time used: {end_time - start_time}")
llm_agent.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
2
+ from dotenv import load_dotenv
3
+ import os
4
+ from retriever import load_encoder, load_collection, encode_query, retrieve_docs, query_rerank, expand_with_neighbors, dedup_by_chapter_event
5
+ from sentence_transformers import CrossEncoder
6
+
7
# Set the TOKENIZERS_PARALLELISM environment variable to "false" before any
# model is used.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Load variables from a local .env file, then read the LLM API key from the
# environment and build the shared OpenAI client with it.
load_dotenv()
api_key = os.environ.get("OPENAI_API_KEY")
client = OpenAI(api_key=api_key)
13
+
14
def build_rag_prompt(query, context):
    """Assemble the user prompt sent to the LLM.

    Layout: a header line, the retrieved context, a blank line, the user
    question, and finally the answering instructions.

    NOTE(review): the prompt lines are assumed to start at column 0 inside the
    original triple-quoted literal -- confirm against the real file.

    :param query: the user's question
    :param context: retrieved reference text
    :return: the complete prompt string
    """
    instructions = (
        "请参考所有已知资料, 并结合资料内容,简明、准确地回答问题。"
        "如果有多个符合的答案, 可以根据你是否确定而决定是否分别陈述这些答案."
        "如果不能确定答案,请如实说明理由,不要凭空编造。"
    )
    sections = (
        "已知资料如下:",
        f"{context}",
        "",
        f"用户提问:{query}",
        instructions,
    )
    return "\n".join(sections)
21
+
22
def llm_answer(query, expanded_results, model_name="gpt-4o"):
    """Ask the chat model to answer *query* from the retrieved passages.

    :param query: the user's question
    :param expanded_results: sequence of ``(text, score, meta)`` tuples; only
        the text of each entry is used here
    :param model_name: chat-completion model name
    :return: the model's answer with surrounding whitespace removed, or an
        empty string when the response carries no text content
    """
    # Pass every retrieved passage to the model, not only the top one: the
    # prompt asks it to consult all of the provided material.
    context = "\n\n".join(
        f"【资料{idx}】\n{item[0]}" for idx, item in enumerate(expanded_results, 1)
    )
    prompt = build_rag_prompt(query, context)
    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": "你是BangDream知识问答助手, 也就是邦学家. 只能基于提供的资料内容作答。"},
            {"role": "user", "content": prompt},
        ],
        temperature=0.2,
        max_tokens=512,
    )
    # message.content can be None, so guard before calling strip().
    content = response.choices[0].message.content
    return (content or "").strip()
35
+
36
if __name__ == "__main__":
    # Interactive entry point: retrieve, rerank, de-duplicate, expand, answer.

    # Reuse the cross-encoder that `retriever` already builds at import time
    # (same "BAAI/bge-reranker-large" model) instead of loading a second copy.
    from retriever import reranker

    collection = load_collection()
    encoder = load_encoder()

    query_text = input("please enter your question:")
    print("Thinking...\n...")
    query_vec = encode_query(encoder, query_text)
    results = retrieve_docs(collection, query_vec, top_k=50)
    reranked = query_rerank(reranker, query_text, results, top_n=20)
    deduped = dedup_by_chapter_event(reranked, max_per_group=1)
    # Keep the five best groups and widen each hit with its neighbouring chunks.
    expanded_results = expand_with_neighbors(deduped[:5], collection)

    answer = llm_answer(query_text, expanded_results)

    print("\n=== Answer ===")
    print(answer)
    print("\n=== retrieved documents ===")
    for idx, (context, score, meta) in enumerate(expanded_results, 1):
        print(f"\n--- document {idx} (Score={score:.4f}) ---\n{context[:200]}...")
        print(meta)
retriever.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chromadb
2
+ from sentence_transformers import SentenceTransformer, CrossEncoder
3
+
4
# Defaults used by load_collection() and load_encoder().
CHROMA_DB_DIR = "./chroma_db"           # path of the persistent Chroma store
COLLECTION_NAME = "bangdream_dense"     # collection name
MODEL_NAME = "BAAI/bge-base-zh-v1.5"    # dense encoder model id

# NOTE: built at import time, so importing this module constructs the
# cross-encoder.
reranker = CrossEncoder("BAAI/bge-reranker-large")
9
+
10
+
11
def load_collection(db_path=CHROMA_DB_DIR, collection_name=COLLECTION_NAME):
    """Open the persistent Chroma store at *db_path* and return a collection.

    The collection is obtained with ``get_or_create_collection``, so a name
    that does not exist yet is created rather than reported as missing.
    """
    return chromadb.PersistentClient(path=db_path).get_or_create_collection(collection_name)
16
+
17
def load_encoder(model_name=MODEL_NAME):
    """Build and return the dense ``SentenceTransformer`` for *model_name*."""
    encoder = SentenceTransformer(model_name)
    return encoder
20
+
21
def encode_query(encoder, query_text):
    """Embed one query string with normalized embeddings.

    The text is wrapped in a one-element list, so the encoder receives a
    batch of size one.
    """
    single_item_batch = [query_text]
    return encoder.encode_query(single_item_batch, normalize_embeddings=True)
24
+
25
def dedup_by_chapter_event(reranked_docs, max_per_group=1):
    """Limit how many results share the same (chapterTitle, eventName) pair.

    Input order is preserved. The first result of every group is always kept;
    later ones are kept only while the group holds fewer than
    ``max_per_group`` entries.

    :param reranked_docs: iterable of ``(doc, score, meta)`` tuples
    :param max_per_group: cap per (chapterTitle, eventName) group
    :return: list of the surviving ``(doc, score, meta)`` tuples
    """
    kept_per_group = {}
    survivors = []
    for doc, score, meta in reranked_docs:
        group = (meta.get("chapterTitle", ""), meta.get("eventName", ""))
        already_kept = kept_per_group.get(group, 0)
        if already_kept == 0 or already_kept < max_per_group:
            kept_per_group[group] = already_kept + 1
            survivors.append((doc, score, meta))
    return survivors
38
+
39
def retrieve_docs(collection, query_vec, top_k=5):
    """Query the collection for the ``top_k`` nearest entries to *query_vec*.

    Metadatas, documents and distances are requested; the raw query result is
    returned unchanged.
    """
    query_args = {
        "query_embeddings": query_vec,
        "n_results": top_k,
        "include": ["metadatas", "documents", "distances"],
    }
    return collection.query(**query_args)
47
+
48
def query_rerank(reranker, query, results, top_n=3):
    """Re-rank retrieved documents with a cross-encoder.

    :param reranker: object whose ``predict(pairs)`` returns one score per
        ``(query, document)`` pair
    :param query: the query text
    :param results: query result holding "documents" and "metadatas", each a
        list whose first element lists the hits for this query
    :param top_n: number of best-scoring results to keep
    :return: list of ``(doc, score, meta)`` tuples, highest score first;
        empty when nothing was retrieved
    """
    docs = results["documents"][0]
    metas = results["metadatas"][0]
    # Nothing retrieved (e.g. an empty collection): skip the model call, which
    # is not guaranteed to accept an empty batch.
    if not docs:
        return []

    scores = reranker.predict([(query, doc) for doc in docs])
    ranked = sorted(zip(docs, scores, metas), key=lambda item: item[1], reverse=True)
    return ranked[:top_n]
73
+
74
def pretty_print_results(results):
    """Print every hit of a query result: rank, distance, metadata, text.

    :param results: query result with "documents", "distances" and
        "metadatas"; the first element of each lists the hits to print
    """
    rows = zip(results["documents"][0], results["distances"][0], results["metadatas"][0])
    separator = "-" * 40
    for rank, (doc, dist, meta) in enumerate(rows, start=1):
        print(f"Rank {rank} | Distance: {dist:.4f}")
        print(meta)
        print(doc)
        print(separator)
84
+
85
+ # expand documents
86
def get_all_chunks_in_chapter(collection, chapter_title, event_name=None, story_type=None):
    """Fetch every stored chunk matching the given chapter/event/story type.

    Falsy arguments are left out of the filter. When no filter remains, the
    lookup is unfiltered.

    :param collection: object exposing ``get(where=, include=)``
    :param chapter_title: value required for the "chapterTitle" metadata field
    :param event_name: value required for "eventName", if given
    :param story_type: value required for "story_type", if given
    :return: list of dicts, each holding "text" plus the chunk's metadata
    """
    candidates = (
        ("chapterTitle", chapter_title),
        ("story_type", story_type),
        ("eventName", event_name),
    )
    conditions = [{field: value} for field, value in candidates if value]

    if not conditions:
        # An empty dict is not a valid where-clause for every Chroma version;
        # None means "no filter".
        where = None
    elif len(conditions) == 1:
        where = conditions[0]
    else:
        where = {"$and": conditions}

    results = collection.get(where=where, include=["documents", "metadatas"])
    return [
        {"text": doc, **meta}
        for doc, meta in zip(results["documents"], results["metadatas"])
    ]
108
+
109
def find_adjacent_chunks(current_chunk, all_chunks):
    """Locate the chunks directly before and after *current_chunk*.

    A chunk counts as previous when its ``end_idx`` equals the current
    ``start_idx - 1`` and as next when its ``start_idx`` equals the current
    ``end_idx + 1``. If several chunks qualify, the last one in *all_chunks*
    wins.

    :return: ``(prev_chunk, next_chunk)``; either may be None
    """
    wanted_end = current_chunk['start_idx'] - 1
    wanted_start = current_chunk['end_idx'] + 1

    before = after = None
    for candidate in all_chunks:
        if candidate['end_idx'] == wanted_end:
            before = candidate
        if candidate['start_idx'] == wanted_start:
            after = candidate
    return before, after
119
+
120
def safe_to_list(x):
    """Return *x* as a list of lines.

    A string is split on newlines (a string without any newline yields a
    one-element list); any other iterable is copied into a list.
    """
    if not isinstance(x, str):
        return list(x)
    return x.split('\n')
124
+
125
def expand_with_neighbors(reranked_docs, collection):
    """Widen every hit with the text of its previous and next chunk.

    For each ``(doc, score, meta)`` the chunks of the same chapter/event/story
    type are fetched, the neighbours are located through their start/end
    indices, and the texts are concatenated in reading order.

    :param reranked_docs: iterable of ``(doc, score, meta)`` tuples
    :param collection: collection queried for the neighbouring chunks
    :return: list of ``(expanded_text, score, meta_copy)`` tuples
    """
    expanded = []
    for doc, score, meta in reranked_docs:
        siblings = get_all_chunks_in_chapter(
            collection,
            meta.get("chapterTitle", ""),
            meta.get("eventName", ""),
            meta.get("story_type", None),
        )
        before, after = find_adjacent_chunks(meta, siblings)

        # previous chunk (if any), then the hit itself, then the next chunk
        lines = safe_to_list(before["text"]) if before else []
        lines += safe_to_list(doc)
        if after:
            lines += safe_to_list(after["text"])

        expanded.append(("\n".join(lines), score, dict(meta)))
    return expanded
158
+
159
+ """if __name__ == "__main__":
160
+ collection = load_collection()
161
+ encoder = load_encoder()
162
+
163
+ query_text = "乐奈喜欢什么?"
164
+ query_vec = encode_query(encoder, query_text)
165
+ results = retrieve_docs(collection, query_vec, top_k=50)
166
+ reranked = query_rerank(reranker, query_text, results, top_n=20)
167
+ deduped = dedup_by_chapter_event(reranked, max_per_group=1)
168
+ expanded_results = expand_with_neighbors(deduped[:5], collection)
169
+
170
+ for doc in expanded_results:
171
+ print("===")
172
+ print(doc)
173
+ print(doc[0])
174
+ print("===")"""