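"""Vector store utilities: fetch a prebuilt FAISS index from a Hugging Face Hub
dataset repo, chunk documents (recursive or semantic), and build or load the
FAISS index used for retrieval.

Typical usage (a sketch; assumes ``documents`` is a list of LangChain Documents):

    vectorstore = build_or_load_vectorstore(documents, chunk_method="recursive")
    retriever = build_filtered_retriever(vectorstore, primary_category=None, k=3)
    docs = retriever.get_relevant_documents("some query")
"""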
import os
from typing import List, Optional, Tuple
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
try:
from semantic_text_splitter import TextSplitter as SemanticTextSplitter # type: ignore
_HAS_SEMANTIC = True
except ImportError: # graceful fallback if package missing
_HAS_SEMANTIC = False
from langchain_core.documents import Document
from .embeddings import get_embedding_model
from .config import (
FAISS_INDEX_PATH,
HF_DATASET_REPO_ID,
HF_DATASET_REPO_TYPE,
FAISS_INDEX_REMOTE_DIR,
FAISS_INDEX_FILES,
)
from pathlib import Path
import shutil


def _ensure_local_faiss_from_hub(index_dir: str) -> bool:
"""Download FAISS index files from Hugging Face Hub dataset repo if missing.
Returns True if files are present (downloaded or already existed), False otherwise.
"""
target = Path(index_dir)
target.mkdir(parents=True, exist_ok=True)
faiss_name, pkl_name = FAISS_INDEX_FILES
faiss_path = target / faiss_name
pkl_path = target / pkl_name
if faiss_path.exists() and pkl_path.exists():
return True
try:
from huggingface_hub import hf_hub_download, list_repo_files
def _download_pair(faiss_fname: str, meta_fname: str, remote_subfolder: Optional[str] = None) -> bool:
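            """Fetch one (index, metadata) file pair from the dataset repo and copy it to the expected local names."""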
try:
# Download FAISS file
local_faiss = hf_hub_download(
repo_id=HF_DATASET_REPO_ID,
repo_type=HF_DATASET_REPO_TYPE,
filename=faiss_fname,
subfolder=remote_subfolder or FAISS_INDEX_REMOTE_DIR or None,
local_dir=str(target),
local_dir_use_symlinks=False,
)
# Download metadata file
local_meta = hf_hub_download(
repo_id=HF_DATASET_REPO_ID,
repo_type=HF_DATASET_REPO_TYPE,
filename=meta_fname,
subfolder=remote_subfolder or FAISS_INDEX_REMOTE_DIR or None,
local_dir=str(target),
local_dir_use_symlinks=False,
)
# Normalize file names in target so FAISS.load_local can find them
try:
dst_faiss = target / faiss_name
dst_meta = target / pkl_name
if Path(local_faiss) != dst_faiss:
shutil.copy2(local_faiss, dst_faiss)
if Path(local_meta) != dst_meta:
shutil.copy2(local_meta, dst_meta)
except Exception as copy_err:
print(f"[FAISS download] Copy to expected names failed: {copy_err}")
return (target / faiss_name).exists() and (target / pkl_name).exists()
except Exception:
return False
# First try configured names
if _download_pair(faiss_name, pkl_name, FAISS_INDEX_REMOTE_DIR):
return True
# Fallback: auto-discover by listing repository files
try:
files = list_repo_files(repo_id=HF_DATASET_REPO_ID, repo_type=HF_DATASET_REPO_TYPE)
except Exception as e:
print(f"[FAISS download] list_repo_files failed for {HF_DATASET_REPO_ID}: {e}")
files = []
def _in_remote_dir(path: str) -> bool:
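            """Return True when ``path`` lies under the configured remote directory (or when none is configured)."""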
if not FAISS_INDEX_REMOTE_DIR:
return True
return path.startswith(f"{FAISS_INDEX_REMOTE_DIR}/") or path == FAISS_INDEX_REMOTE_DIR
faiss_candidates = [f for f in files if f.lower().endswith('.faiss') and _in_remote_dir(f)]
meta_candidates = [
f for f in files if (f.lower().endswith('.pkl') or f.lower().endswith('.pickle')) and _in_remote_dir(f)
]
if faiss_candidates and meta_candidates:
# Take the first candidates
cand_faiss_path = faiss_candidates[0]
cand_meta_path = meta_candidates[0]
# Split into subfolder + filename
def _split_path(p: str) -> Tuple[Optional[str], str]:
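                """Split a repo path into (subfolder, filename); subfolder is None for root-level files."""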
if '/' in p:
d, b = p.rsplit('/', 1)
return d, b
return None, p
sub_faiss, base_faiss = _split_path(cand_faiss_path)
sub_meta, base_meta = _split_path(cand_meta_path)
            # Prefer the shared subfolder when both files live under the same dir;
            # otherwise fall back to whichever subfolder is set.
shared_sub = sub_faiss if sub_faiss == sub_meta else sub_faiss or sub_meta
if _download_pair(base_faiss, base_meta, shared_sub):
return True
print(
f"[FAISS download] Could not find/download FAISS pair in {HF_DATASET_REPO_ID}. "
f"Looked for {faiss_name} and {pkl_name}, candidates: {faiss_candidates} / {meta_candidates}"
)
return False
except Exception as e:
print(f"[FAISS download] Could not fetch from Hub ({HF_DATASET_REPO_ID}): {e}")
return False
def _semantic_chunk_documents(
documents: List[Document],
chunk_size: int,
chunk_overlap: int
) -> List[Document]:
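    """Split documents with the semantic text splitter, preserving each document's metadata on every chunk."""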
    # Newer versions expose a factory constructor; otherwise fall back to direct init
if hasattr(SemanticTextSplitter, "from_tiktoken_encoder"):
splitter = SemanticTextSplitter.from_tiktoken_encoder(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
else: # try simple init signature
splitter = SemanticTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
semantic_chunks: List[Document] = []
for d in documents:
try:
parts = splitter.chunks(d.page_content)
except AttributeError:
            # Fallback: naive sentence-level split when the splitter lacks a .chunks() API
parts = d.page_content.split('. ')
for part in parts:
cleaned = part.strip()
if cleaned:
semantic_chunks.append(
Document(page_content=cleaned, metadata=d.metadata)
)
return semantic_chunks
def _chunk_documents(
documents: List[Document],
method: str = "recursive",
chunk_size: int = 1000,
chunk_overlap: int = 120
):
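    """Chunk documents with the requested method; "semantic" requires semantic_text_splitter, otherwise the recursive character splitter is used."""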
if method == "semantic" and _HAS_SEMANTIC:
try:
return _semantic_chunk_documents(documents, chunk_size, chunk_overlap)
except Exception as e:
print(f"[semantic chunking fallback] {e}; reverting to recursive splitter.")
    # Fallback / default: recursive character splitter
splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=True
)
return splitter.split_documents(documents)
def build_or_load_vectorstore(
documents: List[Document],
force_rebuild: bool = False,
chunk_method: str = "recursive", # or "semantic"
chunk_size: int = 1000,
chunk_overlap: int = 120
):
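    """Load the FAISS index from FAISS_INDEX_PATH (downloading it from the Hub if missing),
    or (re)build it from ``documents`` when no usable index exists, loading fails, or
    ``force_rebuild`` is set. Returns a FAISS vector store.
    """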
# Ensure local index exists (download from Hub if needed)
if not os.path.exists(FAISS_INDEX_PATH):
fetched = _ensure_local_faiss_from_hub(FAISS_INDEX_PATH)
if fetched:
print(f"Downloaded FAISS index from Hub into {FAISS_INDEX_PATH}")
if os.path.exists(FAISS_INDEX_PATH) and not force_rebuild:
print(f"Loading existing FAISS index from {FAISS_INDEX_PATH}...")
try:
vectorstore = FAISS.load_local(
FAISS_INDEX_PATH,
get_embedding_model(),
allow_dangerous_deserialization=True
)
print("Vector store loaded successfully.")
return vectorstore
except Exception as e:
print(f"Failed to load FAISS index due to: {e}")
if not documents:
raise RuntimeError(
"Existing FAISS index is incompatible with current libraries and no documents were "
"provided to rebuild it. Delete 'faiss_index' and rebuild, or pass documents to rebuild."
) from e
print("Rebuilding FAISS index from provided documents...")
print("Building FAISS index (force_rebuild=%s, method=%s)..." % (force_rebuild, chunk_method))
splits = _chunk_documents(
documents,
method=chunk_method,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
print(f"Split {len(documents)} docs into {len(splits)} chunks (method={chunk_method}).")
vectorstore = FAISS.from_documents(splits, get_embedding_model())
vectorstore.save_local(FAISS_INDEX_PATH)
print(f"Vector store created and saved to {FAISS_INDEX_PATH}")
return vectorstore
def build_filtered_retriever(vectorstore, primary_category: Optional[str] = None, k: int = 3):
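    """Return a similarity retriever over ``vectorstore``; if ``primary_category`` is given,
    post-filter retrieved documents by their ``primary_category`` metadata field.
    """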
base = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
if not primary_category:
return base
# Simple wrapper applying post-filtering by metadata; could be replaced by a VectorStore-specific filter if supported
def _get_relevant_documents(query):
docs = base.get_relevant_documents(query)
return [d for d in docs if d.metadata.get("primary_category") == primary_category]
base.get_relevant_documents = _get_relevant_documents # monkey patch
return base