import os
import shutil
from pathlib import Path
from typing import List, Optional, Tuple

from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

try:
    from semantic_text_splitter import TextSplitter as SemanticTextSplitter  # type: ignore
    _HAS_SEMANTIC = True
except ImportError:  # graceful fallback if the optional package is missing
    _HAS_SEMANTIC = False

from .embeddings import get_embedding_model
from .config import (
    FAISS_INDEX_PATH,
    HF_DATASET_REPO_ID,
    HF_DATASET_REPO_TYPE,
    FAISS_INDEX_REMOTE_DIR,
    FAISS_INDEX_FILES,
)

def _ensure_local_faiss_from_hub(index_dir: str) -> bool:
    """Download FAISS index files from Hugging Face Hub dataset repo if missing.

    Returns True if files are present (downloaded or already existed), False otherwise.
    """
    target = Path(index_dir)
    target.mkdir(parents=True, exist_ok=True)
    faiss_name, pkl_name = FAISS_INDEX_FILES
    faiss_path = target / faiss_name
    pkl_path = target / pkl_name
    if faiss_path.exists() and pkl_path.exists():
        return True
    try:
        from huggingface_hub import hf_hub_download, list_repo_files

        def _download_pair(faiss_fname: str, meta_fname: str, remote_subfolder: Optional[str] = None) -> bool:
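            # Download one (.faiss, .pkl) pair from the dataset repo and copy the files
            # to the names FAISS.load_local expects; return True only if both end up
            # present under `target`.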
            try:
                # Download FAISS file
                local_faiss = hf_hub_download(
                    repo_id=HF_DATASET_REPO_ID,
                    repo_type=HF_DATASET_REPO_TYPE,
                    filename=faiss_fname,
                    subfolder=remote_subfolder or FAISS_INDEX_REMOTE_DIR or None,
                    local_dir=str(target),
                    local_dir_use_symlinks=False,
                )
                # Download metadata file
                local_meta = hf_hub_download(
                    repo_id=HF_DATASET_REPO_ID,
                    repo_type=HF_DATASET_REPO_TYPE,
                    filename=meta_fname,
                    subfolder=remote_subfolder or FAISS_INDEX_REMOTE_DIR or None,
                    local_dir=str(target),
                    local_dir_use_symlinks=False,
                )
                # Normalize file names in target so FAISS.load_local can find them
                try:
                    dst_faiss = target / faiss_name
                    dst_meta = target / pkl_name
                    if Path(local_faiss) != dst_faiss:
                        shutil.copy2(local_faiss, dst_faiss)
                    if Path(local_meta) != dst_meta:
                        shutil.copy2(local_meta, dst_meta)
                except Exception as copy_err:
                    print(f"[FAISS download] Copy to expected names failed: {copy_err}")
                return (target / faiss_name).exists() and (target / pkl_name).exists()
            except Exception:
                return False

        # First try configured names
        if _download_pair(faiss_name, pkl_name, FAISS_INDEX_REMOTE_DIR):
            return True

        # Fallback: auto-discover by listing repository files
        try:
            files = list_repo_files(repo_id=HF_DATASET_REPO_ID, repo_type=HF_DATASET_REPO_TYPE)
        except Exception as e:
            print(f"[FAISS download] list_repo_files failed for {HF_DATASET_REPO_ID}: {e}")
            files = []

        def _in_remote_dir(path: str) -> bool:
            if not FAISS_INDEX_REMOTE_DIR:
                return True
            return path.startswith(f"{FAISS_INDEX_REMOTE_DIR}/") or path == FAISS_INDEX_REMOTE_DIR

        faiss_candidates = [f for f in files if f.lower().endswith('.faiss') and _in_remote_dir(f)]
        meta_candidates = [
            f for f in files if (f.lower().endswith('.pkl') or f.lower().endswith('.pickle')) and _in_remote_dir(f)
        ]
        if faiss_candidates and meta_candidates:
            # Take the first candidates
            cand_faiss_path = faiss_candidates[0]
            cand_meta_path = meta_candidates[0]
            # Split into subfolder + filename
            def _split_path(p: str) -> Tuple[Optional[str], str]:
                if '/' in p:
                    d, b = p.rsplit('/', 1)
                    return d, b
                return None, p
            sub_faiss, base_faiss = _split_path(cand_faiss_path)
            sub_meta, base_meta = _split_path(cand_meta_path)
            # Prefer the shared subfolder if both live under the same dir
            shared_sub = sub_faiss if sub_faiss == sub_meta else sub_faiss or sub_meta
            if _download_pair(base_faiss, base_meta, shared_sub):
                return True

        print(
            f"[FAISS download] Could not find/download FAISS pair in {HF_DATASET_REPO_ID}. "
            f"Looked for {faiss_name} and {pkl_name}, candidates: {faiss_candidates} / {meta_candidates}"
        )
        return False
    except Exception as e:
        print(f"[FAISS download] Could not fetch from Hub ({HF_DATASET_REPO_ID}): {e}")
        return False

def _semantic_chunk_documents(
    documents: List[Document],
    chunk_size: int,
    chunk_overlap: int
) -> List[Document]:
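    """Split documents with the optional semantic_text_splitter package.

    Each input document is split into semantically coherent chunks, and the original
    metadata is copied onto every chunk. Only called when the package imported
    successfully (_HAS_SEMANTIC is True).
    """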
    # Newer package versions expose a factory constructor; otherwise fall back to a
    # direct init. Constructor signatures vary between versions; any failure here is
    # caught by _chunk_documents, which reverts to the recursive splitter.
    if hasattr(SemanticTextSplitter, "from_tiktoken_encoder"):
        splitter = SemanticTextSplitter.from_tiktoken_encoder(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
    else:  # try simple init signature
        splitter = SemanticTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
    semantic_chunks: List[Document] = []
    for d in documents:
        try:
            parts = splitter.chunks(d.page_content)
        except AttributeError:
            # Fallback: naive sentence-ish split
            parts = d.page_content.split('. ')
        for part in parts:
            cleaned = part.strip()
            if cleaned:
                semantic_chunks.append(
                    Document(page_content=cleaned, metadata=d.metadata)
                )
    return semantic_chunks

def _chunk_documents(
    documents: List[Document],
    method: str = "recursive",
    chunk_size: int = 1000,
    chunk_overlap: int = 120
):
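    """Chunk documents with the requested method ("recursive" or "semantic").

    Semantic chunking is used only when the optional dependency is installed;
    otherwise, or if it raises, the recursive character splitter is used.
    """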
    if method == "semantic" and _HAS_SEMANTIC:
        try:
            return _semantic_chunk_documents(documents, chunk_size, chunk_overlap)
        except Exception as e:
            print(f"[semantic chunking fallback] {e}; reverting to recursive splitter.")
    elif method == "semantic" and not _HAS_SEMANTIC:
        print("[semantic chunking fallback] semantic_text_splitter not installed; using recursive splitter.")
    # fallback / default
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=True
    )
    return splitter.split_documents(documents)

def build_or_load_vectorstore(
    documents: List[Document],
    force_rebuild: bool = False,
    chunk_method: str = "recursive",  # or "semantic"
    chunk_size: int = 1000,
    chunk_overlap: int = 120
):
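    """Load the FAISS index from disk (downloading it from the Hub if missing),
    or build it from `documents` when no usable index exists or force_rebuild is set.
    """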
    # Ensure local index exists (download from Hub if needed)
    if not os.path.exists(FAISS_INDEX_PATH):
        fetched = _ensure_local_faiss_from_hub(FAISS_INDEX_PATH)
        if fetched:
            print(f"Downloaded FAISS index from Hub into {FAISS_INDEX_PATH}")

    if os.path.exists(FAISS_INDEX_PATH) and not force_rebuild:
        print(f"Loading existing FAISS index from {FAISS_INDEX_PATH}...")
        try:
            vectorstore = FAISS.load_local(
                FAISS_INDEX_PATH,
                get_embedding_model(),
                allow_dangerous_deserialization=True
            )
            print("Vector store loaded successfully.")
            return vectorstore
        except Exception as e:
            print(f"Failed to load FAISS index due to: {e}")
            if not documents:
                raise RuntimeError(
                    "Existing FAISS index is incompatible with current libraries and no documents were "
                    "provided to rebuild it. Delete 'faiss_index' and rebuild, or pass documents to rebuild."
                ) from e
            print("Rebuilding FAISS index from provided documents...")

    print(f"Building FAISS index (force_rebuild={force_rebuild}, method={chunk_method})...")
    splits = _chunk_documents(
        documents,
        method=chunk_method,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap
    )
    print(f"Split {len(documents)} docs into {len(splits)} chunks (method={chunk_method}).")
    vectorstore = FAISS.from_documents(splits, get_embedding_model())
    vectorstore.save_local(FAISS_INDEX_PATH)
    print(f"Vector store created and saved to {FAISS_INDEX_PATH}")
    return vectorstore

def build_filtered_retriever(vectorstore, primary_category: Optional[str] = None, k: int = 3):
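    """Return a similarity retriever, optionally post-filtered by `primary_category` metadata."""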
    base = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
    if not primary_category:
        return base
    # Simple wrapper applying post-filtering by metadata; could be replaced by a
    # VectorStore-specific filter if supported. Capture the original bound method
    # before patching, otherwise the patched attribute would call itself recursively.
    original_get_relevant_documents = base.get_relevant_documents

    def _get_relevant_documents(query):
        docs = original_get_relevant_documents(query)
        return [d for d in docs if d.metadata.get("primary_category") == primary_category]

    base.get_relevant_documents = _get_relevant_documents  # monkey patch
    return base
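
# The comment in build_filtered_retriever notes that post-filtering could be replaced by
# a vector-store-level filter. A minimal sketch of that alternative, assuming the
# langchain_community FAISS wrapper in use accepts a metadata `filter` dict in
# search_kwargs (and `fetch_k` to widen the candidate pool before filtering). Unlike the
# post-filtering wrapper above, this keeps up to k matching documents per query.
def build_filtered_retriever_native(vectorstore, primary_category: Optional[str] = None, k: int = 3):
    search_kwargs = {"k": k}
    if primary_category:
        # Restrict the similarity search to chunks whose metadata matches the category.
        search_kwargs["filter"] = {"primary_category": primary_category}
        search_kwargs["fetch_k"] = max(20, 4 * k)
    return vectorstore.as_retriever(search_type="similarity", search_kwargs=search_kwargs)

# Hypothetical usage (assumes a saved index already exists locally or on the Hub;
# the category value below is a placeholder):
#
#     vectorstore = build_or_load_vectorstore([])
#     retriever = build_filtered_retriever(vectorstore, primary_category="some-category", k=3)
#     docs = retriever.get_relevant_documents("example query")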