Upload 3 files
- embedder.py +241 -0
- llm_agent.py +56 -0
- retriever.py +174 -0
embedder.py
ADDED
@@ -0,0 +1,241 @@
from sentence_transformers import SentenceTransformer
from sentence_transformers.sparse_encoder import SparseEncoder
import torch
import os
import re
from tqdm import tqdm
import numpy as np
import chromadb
import json
import time


"""
Currently only dense encoding is done.
Sparse-encoding functions are placeholders.
"""

DENSE_EMBEDDER_MODEL = "BAAI/bge-base-zh-v1.5"
SPARSE_EMBEDDER_MODEL = "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill"


class TextCleaner:
    def __init__(self, lowercase=False, remove_urls=True, normalize_space=True):
        self.lowercase = lowercase
        self.remove_urls = remove_urls
        self.normalize_space = normalize_space

    def clean(self, text: str) -> str:
        text = text.strip()
        if self.lowercase:
            text = text.lower()
        if self.remove_urls:
            text = re.sub(r"http\S+", "", text)
        if self.normalize_space:
            text = re.sub(r"\s+", " ", text)
        return text


def join_chunk_text(text_chunk):
    # support chunk (list of sentences) processing
    if isinstance(text_chunk, list):
        return "\n".join(text_chunk)
    return text_chunk


class DenseTextEncoder:
    """
    output: numpy array
    """
    def __init__(self, model_name, normalize=True, device=None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # load the model on the selected device
        self.model = SentenceTransformer(model_name, device=self.device)
        self.normalize = normalize
        self.cleaner = TextCleaner()

    def _prepare_texts(self, texts):
        """Support single string, list[str], or list[list[str]]."""
        if isinstance(texts, str):
            texts = [texts]
        elif isinstance(texts, list):
            if all(isinstance(t, str) for t in texts):
                pass  # already a flat list of strings
            elif all(isinstance(t, list) for t in texts):
                texts = [join_chunk_text(t) for t in texts]
            else:
                raise ValueError("Input list must contain only str or list[str].")
        else:
            raise ValueError("Input must be str or list.")
        cleaned = [self.cleaner.clean(t) for t in texts]
        return cleaned

    def encode_document(self, texts):
        cleaned = self._prepare_texts(texts)
        output = self.model.encode_document(cleaned, convert_to_numpy=True, normalize_embeddings=self.normalize)
        return output

    def encode_query(self, texts):
        cleaned = self._prepare_texts(texts)
        output = self.model.encode_query(cleaned, convert_to_numpy=True, normalize_embeddings=self.normalize)
        return output


class SparseTextEncoder:
    """
    output: torch tensor
    """
    def __init__(self, model_name, device=None):
        self.device = device or "cpu"
        self.encoder = SparseEncoder(model_name, device=self.device)
        self.cleaner = TextCleaner()

    def _prepare_texts(self, texts):
        """Support single string, list[str], or list[list[str]]."""
        if isinstance(texts, str):
            texts = [texts]
        elif isinstance(texts, list):
            if all(isinstance(t, str) for t in texts):
                pass  # already a flat list of strings
            elif all(isinstance(t, list) for t in texts):
                texts = [join_chunk_text(t) for t in texts]
            else:
                raise ValueError("Input list must contain only str or list[str].")
        else:
            raise ValueError("Input must be str or list.")
        cleaned = [self.cleaner.clean(t) for t in texts]
        return cleaned

    def encode_document(self, texts):
        """Encode for corpus indexing."""
        cleaned = self._prepare_texts(texts)
        return self.encoder.encode_document(cleaned)

    def encode_query(self, texts):
        """Encode for query retrieval."""
        cleaned = self._prepare_texts(texts)
        return self.encoder.encode_query(cleaned)


def read_input(source):
    if os.path.exists(source):
        with open(source, "r", encoding="utf-8") as f:
            lines = [line.strip() for line in f if line.strip()]
        return lines
    else:
        return [source]


def encode_chunks_with_metadata(chunks, dense_encoder, sparse_encoder):
    """
    :param chunks: [{'text': [...], 'chunk_id': ..., ...metadata}, ...]
    :param dense_encoder: dense encoder model
    :param sparse_encoder: sparse encoder model
    :return: list of dicts like
        {
            "chunk_id": "9b28e9938292486e9a61f2d1787bb828",
            "dense_embedding": np.array([...]),
            "sparse_embedding": torch.sparse.Tensor(...),
            "text": "友希那: ...\n莉莎: ...",
            "eventName": "连结思绪的未竟之歌",
            "chapterTitle": "序章: 古旧的磁带",
            "story_type": "event",
            # ...other metadata
        }
    """
    text = [join_chunk_text(chunk["text"]) for chunk in chunks]
    dense_vecs = dense_encoder.encode_document(text)
    # placeholder, skip sparse encoding for now
    # sparse_vecs = sparse_encoder.encode_document(text)
    result = []
    for i, chunk in enumerate(chunks):
        # placeholder, skip sparse encoding for now
        # sparse_i = sparse_vecs[i]
        # if isinstance(sparse_i, torch.Tensor) and sparse_i.is_sparse:
        #     sparse_i = sparse_i.coalesce()
        result.append({
            "chunk_id": chunk.get("chunk_id"),
            "dense_embedding": dense_vecs[i],
            "sparse_embedding": None,  # placeholder, skip sparse encoding for now
            "text": text[i],
            "eventName": chunk.get("eventName"),
            "chapterTitle": chunk.get("chapterTitle"),
            "story_type": chunk.get("story_type"),
            "start_idx": chunk.get("start_idx"),
            "end_idx": chunk.get("end_idx"),
        })
    return result


# save dense embeddings to the Chroma vector database
def save_chunks_to_chroma(embedded_chunks, collection):
    ids = []
    documents = []
    embeddings = []
    metadata = []

    for entry in embedded_chunks:
        ids.append(entry["chunk_id"])
        documents.append(entry["text"])
        dense = entry["dense_embedding"]
        embeddings.append(dense.tolist() if isinstance(dense, np.ndarray) else dense)
        # sparse embeddings are currently not stored in Chroma
        meta = {k: v for k, v in entry.items()
                if k not in ["chunk_id", "dense_embedding", "sparse_embedding", "text"]}
        metadata.append(meta)

    batch_size = 64
    for i in range(0, len(ids), batch_size):
        collection.add(
            ids=ids[i:i + batch_size],
            documents=documents[i:i + batch_size],
            embeddings=embeddings[i:i + batch_size],
            metadatas=metadata[i:i + batch_size],
        )
    print(f"saved {len(ids)} chunks to {collection.name}")


def read_jsonl_in_batches(file_path, batch_size=64):
    batch = []
    with open(file_path, "r", encoding="utf8") as f:
        for line in f:
            if line.strip():
                batch.append(json.loads(line))
                if len(batch) == batch_size:
                    yield batch
                    batch = []
    if batch:
        yield batch


if __name__ == "__main__":
    chunk_files = [
        "./chunks/band_chunks.jsonl",
        "./chunks/card_chunks.jsonl",
        "./chunks/event_chunks.jsonl",
        "./chunks/main_chunks.jsonl",
    ]

    dense_encoder = DenseTextEncoder(DENSE_EMBEDDER_MODEL)
    sparse_encoder = SparseTextEncoder(SPARSE_EMBEDDER_MODEL)

    # init databases
    chroma_client = chromadb.PersistentClient(path="./chroma_db")
    chroma_collection = chroma_client.get_or_create_collection("bangdream_dense")

    start_time = time.time()
    for file_path in chunk_files:
        with open(file_path, "r", encoding="utf8") as f:
            total_lines = sum(1 for line in f if line.strip())
        print(f"\nProcessing {file_path} ({total_lines} chunks)")

        pbar = tqdm(total=total_lines, desc=f"Encoding {os.path.basename(file_path)}", unit="chunk")

        for batch in read_jsonl_in_batches(file_path, batch_size=64):
            embedded = encode_chunks_with_metadata(batch, dense_encoder, sparse_encoder)
            save_chunks_to_chroma(embedded, chroma_collection)
            pbar.update(len(batch))
        pbar.close()
    end_time = time.time()
    print(f"Total time used: {end_time - start_time:.2f}s")
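
For reference, a sketch of the chunk record that read_jsonl_in_batches and encode_chunks_with_metadata expect. The chunking script itself is not part of this upload, so the keys are inferred from the code above (chunk_id, text, eventName, chapterTitle, story_type, start_idx, end_idx) and the values shown here are illustrative only:

# one line of ./chunks/event_chunks.jsonl, shown as a Python dict
example_chunk = {
    "chunk_id": "9b28e9938292486e9a61f2d1787bb828",  # unique id, reused as the Chroma id
    "text": ["友希那: ...", "莉莎: ..."],             # sentences; join_chunk_text joins them with "\n"
    "eventName": "连结思绪的未竟之歌",
    "chapterTitle": "序章: 古旧的磁带",
    "story_type": "event",
    "start_idx": 0,    # index of the chunk's first sentence within its chapter
    "end_idx": 19,     # index of its last sentence; see find_adjacent_chunks in retriever.py
}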
llm_agent.py
ADDED
@@ -0,0 +1,56 @@
from openai import OpenAI
from dotenv import load_dotenv
import os
from retriever import (
    load_encoder, load_collection, encode_query, retrieve_docs,
    query_rerank, expand_with_neighbors, dedup_by_chapter_event,
)
from sentence_transformers import CrossEncoder

os.environ["TOKENIZERS_PARALLELISM"] = "false"
# load the LLM API key from .env
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")

client = OpenAI(api_key=api_key)


def build_rag_prompt(query, context):
    # Chinese RAG prompt, roughly: "The known material is given below. ...
    # User question: ... Drawing on all of the material, answer concisely and
    # accurately. If several answers fit, you may state them separately
    # depending on how certain you are. If you cannot determine the answer,
    # explain why honestly instead of making something up."
    prompt = f"""已知资料如下:
{context}

用户提问:{query}
请参考所有已知资料,并结合资料内容,简明、准确地回答问题。如果有多个符合的答案,可以根据你是否确定而决定是否分别陈述这些答案。如果不能确定答案,请如实说明理由,不要凭空编造。"""
    return prompt


def llm_answer(query, expanded_results, model_name="gpt-4o"):
    # concatenate all expanded contexts: the prompt tells the model to use
    # all of the provided material, not just the top hit
    context = "\n\n".join(doc for doc, score, meta in expanded_results)
    prompt = build_rag_prompt(query, context)
    response = client.chat.completions.create(
        model=model_name,
        messages=[
            # system prompt, roughly: "You are a BanG Dream QA assistant (a
            # 'Bang-ologist'); answer only from the provided material."
            {"role": "system", "content": "你是BangDream知识问答助手,也就是邦学家。只能基于提供的资料内容作答。"},
            {"role": "user", "content": prompt},
        ],
        temperature=0.2,
        max_tokens=512,
    )
    return response.choices[0].message.content.strip()


if __name__ == "__main__":
    collection = load_collection()
    encoder = load_encoder()
    reranker = CrossEncoder("BAAI/bge-reranker-large")

    query_text = input("please enter your question: ")
    print("Thinking...\n...")
    query_vec = encode_query(encoder, query_text)
    results = retrieve_docs(collection, query_vec, top_k=50)
    reranked = query_rerank(reranker, query_text, results, top_n=20)
    deduped = dedup_by_chapter_event(reranked, max_per_group=1)
    expanded_results = expand_with_neighbors(deduped[:5], collection)

    answer = llm_answer(query_text, expanded_results)

    print("\n=== Answer ===")
    print(answer)
    print("\n=== retrieved documents ===")
    for idx, (context, score, meta) in enumerate(expanded_results, 1):
        print(f"\n--- document {idx} (Score={score:.4f}) ---\n{context[:200]}...")
        print(meta)
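
The __main__ block above wires retrieval and generation together once; for reuse from other code the same steps can be wrapped in a single helper. answer_question below is not part of the upload; it is a minimal sketch that only repackages the flow above with the same default parameters:

def answer_question(query_text, collection, encoder, reranker,
                    top_k=50, top_n=20, keep=5):
    """Retrieve -> rerank -> dedup -> expand with neighbors -> ask the LLM."""
    query_vec = encode_query(encoder, query_text)
    results = retrieve_docs(collection, query_vec, top_k=top_k)
    reranked = query_rerank(reranker, query_text, results, top_n=top_n)
    deduped = dedup_by_chapter_event(reranked, max_per_group=1)
    expanded = expand_with_neighbors(deduped[:keep], collection)
    # return the answer plus the (context, score, meta) tuples used to build it
    return llm_answer(query_text, expanded), expanded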
retriever.py
ADDED
@@ -0,0 +1,174 @@
import chromadb
from sentence_transformers import SentenceTransformer, CrossEncoder

CHROMA_DB_DIR = "./chroma_db"
COLLECTION_NAME = "bangdream_dense"
MODEL_NAME = "BAAI/bge-base-zh-v1.5"

# NOTE: the CrossEncoder reranker is constructed by the caller (see
# llm_agent.py) so that importing this module does not load the model


def load_collection(db_path=CHROMA_DB_DIR, collection_name=COLLECTION_NAME):
    """Connect to the Chroma persistent DB and load a collection."""
    client = chromadb.PersistentClient(path=db_path)
    collection = client.get_or_create_collection(collection_name)
    return collection


def load_encoder(model_name=MODEL_NAME):
    """Load the dense encoder model."""
    return SentenceTransformer(model_name)


def encode_query(encoder, query_text):
    """Encode query text into a normalized embedding."""
    return encoder.encode_query([query_text], normalize_embeddings=True)


def dedup_by_chapter_event(reranked_docs, max_per_group=1):
    """Keep at most max_per_group docs per identical (chapterTitle, eventName) pair."""
    seen = {}
    deduped = []
    for doc, score, meta in reranked_docs:
        key = (meta.get("chapterTitle", ""), meta.get("eventName", ""))
        if key not in seen:
            seen[key] = 1
            deduped.append((doc, score, meta))
        elif seen[key] < max_per_group:
            seen[key] += 1
            deduped.append((doc, score, meta))
    return deduped


def retrieve_docs(collection, query_vec, top_k=5):
    """Retrieve documents from the Chroma collection."""
    results = collection.query(
        query_embeddings=query_vec,
        n_results=top_k,
        include=["metadatas", "documents", "distances"],
    )
    return results


def query_rerank(reranker, query, results, top_n=3):
    """Use a CrossEncoder to re-rank retrieved results."""
    docs = results["documents"][0]
    pairs = [(query, doc) for doc in docs]

    # CrossEncoder relevance scores
    scores = reranker.predict(pairs)

    # sort by score, descending
    ranked = sorted(zip(docs, scores, results["metadatas"][0]), key=lambda x: x[1], reverse=True)

    # keep top_n
    return ranked[:top_n]


def pretty_print_results(results):
    """Nicely print retrieved results."""
    docs = results["documents"][0]
    dists = results["distances"][0]
    metas = results["metadatas"][0]
    for idx, (doc, dist, meta) in enumerate(zip(docs, dists, metas)):
        print(f"Rank {idx + 1} | Distance: {dist:.4f}")
        print(meta)
        print(doc)
        print("-" * 40)


# context expansion
def get_all_chunks_in_chapter(collection, chapter_title, event_name=None, story_type=None):
    """Fetch every chunk whose metadata matches the given chapter/event/story type."""
    filters = []
    if chapter_title:
        filters.append({"chapterTitle": chapter_title})
    if story_type:
        filters.append({"story_type": story_type})
    if event_name:
        filters.append({"eventName": event_name})
    if len(filters) == 1:
        filter_dict = filters[0]
    elif len(filters) > 1:
        filter_dict = {"$and": filters}
    else:
        filter_dict = None  # Chroma rejects an empty where dict
    results = collection.get(where=filter_dict, include=["documents", "metadatas"])
    chunk_list = []
    for doc, meta in zip(results["documents"], results["metadatas"]):
        chunk_list.append({
            "text": doc,
            **meta,
        })
    return chunk_list


def find_adjacent_chunks(current_chunk, all_chunks):
    """Find the chunks immediately before and after current_chunk by sentence index."""
    start_idx = current_chunk["start_idx"]
    end_idx = current_chunk["end_idx"]
    prev_chunk, next_chunk = None, None
    for chunk in all_chunks:
        if chunk["end_idx"] == start_idx - 1:
            prev_chunk = chunk
        if chunk["start_idx"] == end_idx + 1:
            next_chunk = chunk
    return prev_chunk, next_chunk


def safe_to_list(x):
    if isinstance(x, str):
        return x.split("\n") if "\n" in x else [x]
    return list(x)


def expand_with_neighbors(reranked_docs, collection):
    """Prepend and append each doc's neighboring chunks to give the LLM fuller context."""
    expanded_results = []
    for doc, score, meta in reranked_docs:
        chapter_title = meta.get("chapterTitle", "")
        event_name = meta.get("eventName", "")
        story_type = meta.get("story_type", None)
        all_chunks = get_all_chunks_in_chapter(collection, chapter_title, event_name, story_type)
        prev_chunk, next_chunk = find_adjacent_chunks(meta, all_chunks)
        expanded_text = []
        if prev_chunk:
            expanded_text += safe_to_list(prev_chunk["text"])
        expanded_text += safe_to_list(doc)
        if next_chunk:
            expanded_text += safe_to_list(next_chunk["text"])

        expanded_results.append((
            "\n".join(expanded_text),
            score,
            {**meta},
        ))
    return expanded_results


"""if __name__ == "__main__":
    collection = load_collection()
    encoder = load_encoder()
    reranker = CrossEncoder("BAAI/bge-reranker-large")

    query_text = "乐奈喜欢什么?"  # sample query: "What does 乐奈 like?"
    query_vec = encode_query(encoder, query_text)
    results = retrieve_docs(collection, query_vec, top_k=50)
    reranked = query_rerank(reranker, query_text, results, top_n=20)
    deduped = dedup_by_chapter_event(reranked, max_per_group=1)
    expanded_results = expand_with_neighbors(deduped[:5], collection)

    for doc in expanded_results:
        print("===")
        print(doc[0])
        print("===")"""
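
expand_with_neighbors depends on a convention the chunker must uphold: within one chapter, chunk sentence indices are contiguous, so a neighbor is exactly the chunk whose end_idx (or start_idx) is one off. A quick synthetic check of that convention (data made up for illustration):

# three contiguous chunks: each end_idx + 1 equals the next chunk's start_idx
chunks = [
    {"text": "a", "start_idx": 0, "end_idx": 9},
    {"text": "b", "start_idx": 10, "end_idx": 19},
    {"text": "c", "start_idx": 20, "end_idx": 29},
]
prev_chunk, next_chunk = find_adjacent_chunks(chunks[1], chunks)
assert prev_chunk is chunks[0] and next_chunk is chunks[2]
# if chunks overlapped or skipped indices, both would come back None and
# expand_with_neighbors would silently fall back to the chunk alone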