Wasifjafri committed
Commit ad1095f · 1 Parent(s): 7bfa83f
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  src/faiss_index/index.faiss filter=lfs diff=lfs merge=lfs -text
+ *.faiss filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,105 @@
+ # Python cache/bytecode
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # Unit test / coverage / tools
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .pytest_cache/
+ pytestdebug.log
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pyre/
+ .mypy_cache/
+ .pytype/
+ .pyright/
+ .ruff_cache/
+
+ # Jupyter
+ .ipynb_checkpoints/
+ **/.ipynb_checkpoints/
+
+ # Environments
+ .env
+ .env.*
+ !.env.example
+ .venv/
+ venv/
+ env/
+ ENV/
+ .python-version
+
+ # Logs
+ *.log
+ logs/
+
+ # IDE / Editor
+ .vscode/
+ .vscode-test/
+ .idea/
+ *.sublime-workspace
+ *.sublime-project
+
+ # OS files
+ .DS_Store
+ Thumbs.db
+ ehthumbs.db
+ Desktop.ini
+
+ # Streamlit
+ .streamlit/secrets.toml
+
+ # Hugging Face / model caches
+ .huggingface/
+ **/.cache/
+ transformers_cache/
+ hf_cache/
+ sentence_transformers_cache/
+ torch_cache/
+
+ # Data and artifacts
+ data/
+ .data/
+ datasets/
+ outputs/
+ artifacts/
+ checkpoints/
+ runs/
+ wandb/
+ mlruns/
+
+ # Vector store / indexes
+ vectorstore/
+ .vectorstore/
+ # faiss_index/
+ # indexes/
+
+ # Docker (local overrides)
+ docker-compose.override.yml
Dockerfile CHANGED
@@ -1,20 +1,28 @@
- FROM python:3.13.5-slim
+ FROM python:3.11-slim

  WORKDIR /app

- RUN apt-get update && apt-get install -y \
-     build-essential \
+ # Install only what you need for build/runtime
+ RUN apt-get update && apt-get install -y --no-install-recommends \
      curl \
      git \
      && rm -rf /var/lib/apt/lists/*

- COPY requirements.txt ./
- COPY src/ ./src/
+ # Copy all sources at repo root into image (no src/ subpackage anymore)
+ COPY requirements.txt ./requirements.txt
+ COPY . .

- RUN pip3 install -r requirements.txt
+ # Put project root on PYTHONPATH
+ ENV PYTHONPATH=/app
+
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Spaces sets $PORT; default locally to 8501
+ ENV PORT=8501

  EXPOSE 8501

- HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
+ HEALTHCHECK CMD curl --fail http://localhost:${PORT}/_stcore/health || exit 1

- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
+ # Use sh so $PORT expands at runtime
+ ENTRYPOINT ["/bin/sh", "-c", "streamlit run streamlit_app.py --server.port ${PORT} --server.address 0.0.0.0"]
app.py CHANGED
@@ -1,3 +1,3 @@
  # Entrypoint renamed for Hugging Face Spaces Streamlit detection.
  # Currently imports the original app content.
- from src.streamlit_app import * # noqa
+ from streamlit_app import * # noqa
{src/faiss_index → faiss_index}/index.faiss RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:286e3c5e6d3b5a8b8a642fb64a5363ec608dde23b197374bb73dc912fae06013
- size 123574317
+ oid sha256:839024862e5b0a77cb65312c01ec88994ea1104a457fb58ca535976d5c79d934
+ size 15392301
{src/faiss_index → faiss_index}/index.pkl RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:8afc6f5341df0262d4275907748004ea3cc27897c241e21b674486d58ca0bd69
- size 59597386
+ oid sha256:3edc9118bedd5886f40903221b3286af4c23bc05fcfd47fac55c2dbdfd852d8a
+ size 6473293
requirements.txt CHANGED
@@ -2,6 +2,7 @@ streamlit==1.38.0
  langchain==0.2.14
  langchain-community==0.2.12
  langchain-core==0.2.33
+ langchain-text-splitters==0.2.2
  langchain-groq
  sentence-transformers==3.0.1
  faiss-cpu>=1.7.4
@@ -10,4 +11,5 @@ huggingface-hub>=0.23.0
  python-dotenv==1.0.1
  requests==2.32.3
  numpy<2.0.0
+ dotenv
  # Removed tiktoken (unused) to avoid rust build on HF base image.
src/__init__.py CHANGED
@@ -1,6 +0,0 @@
- """
- Local package for the Research RAG Chatbot app.
-
- This file makes the `src` directory a Python package so absolute imports
- like `from src.vector_store import ...` work both locally and in Docker.
- """
src/api.py ADDED
@@ -0,0 +1,21 @@
+ from fastapi import FastAPI
+ from .vector_store import build_or_load_vectorstore
+ from .retriever import get_retriever
+ from .rag_pipeline import build_rag_chain
+ from .ingestion import df_to_documents, preprocess_dataframe, load_data_subset
+ from .config import DATA_PATH
+ import os
+
+ app = FastAPI()
+
+ # Load documents and vectorstore at startup
+ df = load_data_subset(os.path.join(DATA_PATH, "arxiv-metadata-oai-snapshot.json"))
+ df = preprocess_dataframe(df)
+ docs = df_to_documents(df)
+ vectorstore = build_or_load_vectorstore(docs)
+ retriever = get_retriever(vectorstore)
+ rag_chain = build_rag_chain(retriever)
+
+ @app.get("/query")
+ def query_rag(q: str):
+     return {"answer": rag_chain.invoke(q).content}
src/config.py CHANGED
@@ -1,17 +1,15 @@
1
  import os
 
 
2
 
3
- # Base paths
4
- BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
5
- DATA_PATH = os.path.join(BASE_DIR, "data")
6
 
7
- # FAISS index location (folder);
8
- # streamlit_app expects a file check but we'll place the file path here.
9
- FAISS_DIR = os.path.join(os.path.dirname(__file__), "faiss_index")
10
- FAISS_INDEX_PATH = os.path.join(FAISS_DIR, "index.faiss")
11
 
12
- # API Keys / Env
13
- GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")
14
 
15
- # Ensure directories exist at runtime
16
- os.makedirs(FAISS_DIR, exist_ok=True)
17
- os.makedirs(DATA_PATH, exist_ok=True)
 
 
1
  import os
2
+ from dotenv import load_dotenv
3
+ load_dotenv()
4
 
5
+ DATA_PATH = "data"
6
+ FAISS_INDEX_PATH = "faiss_index"
 
7
 
8
+ EMBEDDING_MODEL = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
 
 
 
9
 
10
+ DEVICE = "cuda" if os.environ.get("CUDA_AVAILABLE", "0") == "1" else "cpu"
 
11
 
12
+ GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
13
+
14
+ # Cross-encoder model for reranking
15
+ CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-12-v2"
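Configuration now comes from load_dotenv() plus os.environ, so a quick sanity check of the resolved settings might look like this (the key itself is expected in a local .env or in the Space secrets; nothing secret is printed):

from src.config import EMBEDDING_MODEL, CROSS_ENCODER_MODEL, DEVICE, GROQ_API_KEY

print(EMBEDDING_MODEL, CROSS_ENCODER_MODEL, DEVICE)
print("GROQ_API_KEY set:", bool(GROQ_API_KEY))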
src/embeddings.py ADDED
@@ -0,0 +1,8 @@
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from .config import EMBEDDING_MODEL, DEVICE
+
+ def get_embedding_model():
+     return HuggingFaceEmbeddings(
+         model_name=EMBEDDING_MODEL,
+         model_kwargs={"device": DEVICE}
+     )
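A minimal smoke test for the embedding wrapper (the query string is made up; the model is downloaded from the Hub on first use):

from src.embeddings import get_embedding_model

emb = get_embedding_model()
vec = emb.embed_query("graph neural networks for molecules")
print(len(vec))  # multi-qa-MiniLM-L6-cos-v1 returns 384-dimensional vectors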
src/ingestion.py CHANGED
@@ -1,73 +1,46 @@
- from __future__ import annotations
+ """Data loading, cleaning and preprocessing for ArXiv dataset."""

- import json
  import os
- from typing import List, Dict, Any, Optional
-
+ import json
+ import pandas as pd
  from langchain_core.documents import Document
+ from .config import DATA_PATH
+ from .text_processing import clean_text

-
- def load_data_subset(json_path: str, num_records: int = 50000) -> List[Dict[str, Any]]:
-     """
-     Load a subset of records from an arXiv metadata JSON lines file.
-     Returns a list of dicts (not a pandas DataFrame) to keep dependencies minimal.
-     """
-     if not os.path.exists(json_path):
-         raise FileNotFoundError(f"Data file not found: {json_path}")
-
-     rows: List[Dict[str, Any]] = []
-     with open(json_path, "r", encoding="utf-8") as f:
+ def load_data_subset(file_path, num_records=50000):
+     records = []
+     with open(file_path, 'r') as f:
          for i, line in enumerate(f):
              if i >= num_records:
                  break
-             try:
-                 rows.append(json.loads(line))
-             except json.JSONDecodeError:
-                 continue
-     return rows
-
-
- def preprocess_dataframe(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-     """
-     Light preprocessing: ensure keys exist, fill defaults, and derive a 'year' field.
-     Works on list-of-dicts to avoid pandas dependency.
-     """
-     def extract_year(versions: Any) -> Optional[int]:
-         if isinstance(versions, list) and versions:
-             created = versions[0].get("created")
-             if isinstance(created, str) and len(created) >= 4 and created[:4].isdigit():
-                 return int(created[:4])
-         return None
-
-     norm: List[Dict[str, Any]] = []
-     for r in rows:
-         title = str(r.get("title", "") or "")
-         abstract = str(r.get("abstract", "") or "")
-         categories = str(r.get("categories", "") or "")
-         versions = r.get("versions")
-         year = r.get("year")
-         if not isinstance(year, int):
-             year = extract_year(versions)
-         norm.append({
-             **r,
-             "title": title,
-             "abstract": abstract,
-             "categories": categories,
-             "year": year,
-         })
-     return norm
-
-
- def df_to_documents(rows: List[Dict[str, Any]]) -> List[Document]:
-     """
-     Convert rows (list-of-dicts) to LangChain Documents with metadata.
-     """
-     docs: List[Document] = []
-     for r in rows:
-         content = f"Title: {r.get('title','')}\n\n{r.get('abstract','')}"
-         meta = {
-             "categories": r.get("categories", ""),
-             "year": r.get("year", None),
+             records.append(json.loads(line))
+     return pd.DataFrame(records)
+
+ def preprocess_dataframe(df: pd.DataFrame) -> pd.DataFrame:
+     df['update_date'] = pd.to_datetime(df['update_date'])
+     df['year'] = df['update_date'].dt.year
+     df = df.dropna(subset=['abstract'])
+     df = df[df['abstract'].str.strip() != '']
+     return df
+
+ def df_to_documents(
+     df: pd.DataFrame,
+     lowercase: bool = False,
+     remove_stopwords: bool = False
+ ):
+     documents = []
+     for _, row in df.iterrows():
+         title_clean = clean_text(str(row['title']), lowercase=lowercase, remove_stopwords=remove_stopwords)
+         abstract_clean = clean_text(str(row['abstract']), lowercase=lowercase, remove_stopwords=remove_stopwords)
+         page_content = f"Title: {title_clean}\n\nAbstract: {abstract_clean}"
+         categories_raw = row.get('categories', 'N/A') or 'N/A'
+         primary_category = categories_raw.split()[0] if isinstance(categories_raw, str) else 'N/A'
+         metadata = {
+             "id": row.get('id', 'N/A'),
+             "authors": row.get('authors', 'N/A'),
+             "year": int(row.get('year')) if not pd.isna(row.get('year')) else None,
+             "categories": categories_raw,
+             "primary_category": primary_category
          }
-         docs.append(Document(page_content=content, metadata=meta))
-     return docs
+         documents.append(Document(page_content=page_content, metadata=metadata))
+     return documents
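To see the new pandas-based path end to end, here is a sketch on a single fabricated record (the values below are not from the arXiv dump):

import pandas as pd
from src.ingestion import preprocess_dataframe, df_to_documents

df = pd.DataFrame([{
    "id": "1234.5678",              # illustrative values only
    "title": "An Example Paper",
    "abstract": "We study a toy problem.",
    "authors": "A. Author",
    "categories": "cs.LG stat.ML",
    "update_date": "2023-05-01",
}])
docs = df_to_documents(preprocess_dataframe(df))
print(docs[0].metadata["primary_category"])  # first token of the categories string, here "cs.LG"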
src/main.py ADDED
@@ -0,0 +1,69 @@
+ import os
+ import kagglehub
+ from .ingestion import load_data_subset, preprocess_dataframe, df_to_documents
+ from .vector_store import build_or_load_vectorstore
+ from .retriever import build_advanced_retriever
+ from .rag_pipeline import build_rag_chain
+ from .config import DATA_PATH
+ import shutil
+
+ def download_dataset():
+     """Download the ArXiv dataset via KaggleHub if not already present."""
+     os.makedirs(DATA_PATH, exist_ok=True)
+     dataset_file = os.path.join(DATA_PATH, "arxiv-metadata-oai-snapshot.json")
+
+     if not os.path.exists(dataset_file):
+         print("Downloading ArXiv dataset via KaggleHub...")
+         path = kagglehub.dataset_download("Cornell-University/arxiv")
+         extracted_file = os.path.join(path, "arxiv-metadata-oai-snapshot.json")
+         shutil.copy(extracted_file, dataset_file)  # ✅ copy works across drives
+         print(f"Dataset copied to {dataset_file}")
+     else:
+         print(f"Dataset already exists at {dataset_file}")
+
+     return dataset_file
+
+
+ def run_sample_queries(rag_chain):
+     """Run a few sample queries through the RAG pipeline."""
+     sample_questions = [
+         "What are the recent advancements in graph neural networks?",
+         "Explain the applications of transformers in natural language processing.",
+         "How is reinforcement learning applied in robotics?",
+     ]
+
+     for q in sample_questions:
+         print("\n---")
+         print(f"Question: {q}")
+         answer = rag_chain.invoke(q).content
+         print(f"Answer: {answer}")
+
+
+ def main():
+     dataset_file = download_dataset()
+     df = load_data_subset(dataset_file, num_records=50000)
+     df = preprocess_dataframe(df)
+     documents = df_to_documents(df, lowercase=False, remove_stopwords=False)
+     vectorstore = build_or_load_vectorstore(
+         documents,
+         force_rebuild=False,
+         chunk_method="semantic",  # fallback to recursive if semantic splitter unavailable
+         chunk_size=800,
+         chunk_overlap=120
+     )
+     retriever = build_advanced_retriever(
+         vectorstore,
+         base_k=16,
+         rerank_k=6,
+         primary_category=None,
+         year_min=None,
+         year_max=None,
+         dynamic=True,
+         use_rerank=True,
+     )
+     rag_chain = build_rag_chain(retriever)
+     run_sample_queries(rag_chain)
+
+
+ if __name__ == "__main__":
+     main()
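Because of the package-relative imports, the pipeline is meant to be run as a module; a hedged sketch of invoking just the download step (Kaggle credentials for kagglehub are assumed to be configured):

# Full pipeline run: python -m src.main
from src.main import download_dataset

dataset_file = download_dataset()  # reuses data/arxiv-metadata-oai-snapshot.json if already present
print(dataset_file)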
src/rag_pipeline.py ADDED
@@ -0,0 +1,30 @@
+ from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
+ from langchain.prompts import PromptTemplate
+ from langchain_groq import ChatGroq
+ from .config import GROQ_API_KEY
+ from .retriever import RerankRetriever
+ def build_rag_chain(retriever: RerankRetriever):
+     retriever_runnable = RunnableLambda(lambda question: retriever.get_relevant_documents(question))
+     format_docs_runnable = RunnableLambda(lambda docs: "\n\n".join([d.page_content for d in docs]))
+
+     prompt_template = """Answer the following question based on the provided context.
+
+ Context:
+ {context}
+
+ Question: {question}
+
+ Answer: """
+     prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
+
+     llm = ChatGroq(
+         model="meta-llama/llama-4-maverick-17b-128e-instruct",
+         temperature=0.7,
+         max_tokens=512,
+         groq_api_key=GROQ_API_KEY
+     )
+
+     return {
+         "context": retriever_runnable | format_docs_runnable,
+         "question": RunnablePassthrough()
+     } | prompt | llm
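An end-to-end sketch of the chain on a tiny in-memory index, mirroring the wiring in src/main.py; the single abstract is fabricated, reranking is disabled to skip the cross-encoder download, and GROQ_API_KEY must be set for ChatGroq to authenticate:

from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from src.embeddings import get_embedding_model
from src.retriever import build_advanced_retriever
from src.rag_pipeline import build_rag_chain

doc = Document(
    page_content="Title: Toy Paper\n\nAbstract: Diffusion models generate data by iterative denoising.",
    metadata={"year": 2023, "primary_category": "cs.LG"},
)
vs = FAISS.from_documents([doc], get_embedding_model())
chain = build_rag_chain(build_advanced_retriever(vs, base_k=2, rerank_k=1, use_rerank=False))
print(chain.invoke("What are diffusion models?").content)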
src/retriever.py ADDED
@@ -0,0 +1,103 @@
+ from typing import Optional, List
+ from dataclasses import dataclass
+ from .config import CROSS_ENCODER_MODEL
+
+ try:
+     from sentence_transformers import CrossEncoder
+     _HAS_CE = True
+ except ImportError:
+     _HAS_CE = False
+
+
+ def get_retriever(vectorstore, k=3):
+     return vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
+
+
+ @dataclass
+ class RetrievalParams:
+     base_k: int = 8        # initial fetch size for reranking
+     rerank_k: int = 4      # final number after rerank
+     max_k: int = 20        # max docs to fetch for long/ambiguous queries
+     min_k: int = 3         # minimum docs
+     dynamic: bool = True   # enable dynamic k logic
+     year_min: Optional[int] = None
+     year_max: Optional[int] = None
+     primary_category: Optional[str] = None
+     use_rerank: bool = True
+
+
+ class RerankRetriever:
+     def __init__(self, vectorstore, params: RetrievalParams):
+         self.vs = vectorstore
+         self.params = params
+         self.cross_encoder = None
+         if params.use_rerank and _HAS_CE:
+             self.cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL)
+         elif params.use_rerank:
+             print("CrossEncoder not available; install sentence-transformers to enable reranking.")
+
+     def _compute_dynamic_k(self, query: str) -> int:
+         if not self.params.dynamic:
+             return self.params.base_k
+         length = len(query.split())
+         if length <= 4:  # very short, broaden
+             return min(self.params.base_k + 6, self.params.max_k)
+         if length <= 12:
+             return self.params.base_k
+         return min(self.params.base_k + 4, self.params.max_k)
+
+     def _metadata_filter(self, docs):
+         p = self.params
+         filtered = []
+         for d in docs:
+             y = d.metadata.get("year")
+             if p.year_min is not None and (y is None or y < p.year_min):
+                 continue
+             if p.year_max is not None and (y is None or y > p.year_max):
+                 continue
+             if p.primary_category and d.metadata.get("primary_category") != p.primary_category:
+                 continue
+             filtered.append(d)
+         return filtered
+
+     def get_relevant_documents(self, query: str):
+         fetch_k = self._compute_dynamic_k(query)
+         base_retriever = self.vs.as_retriever(search_type="similarity", search_kwargs={"k": fetch_k})
+         docs = base_retriever.get_relevant_documents(query)
+         docs = self._metadata_filter(docs)
+         if self.cross_encoder and docs:
+             pairs = [(query, d.page_content) for d in docs]
+             scores = self.cross_encoder.predict(pairs)
+             ranked = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
+             final_k = min(self.params.rerank_k, len(ranked))
+             docs = [d for d, _ in ranked[:final_k]]
+         else:
+             # fallback: truncate
+             docs = docs[: self.params.rerank_k]
+         return docs
+
+     # For LangChain compatibility
+     def invoke(self, query: str):
+         return self.get_relevant_documents(query)
+
+
+ def build_advanced_retriever(
+     vectorstore,
+     base_k: int = 12,
+     rerank_k: int = 5,
+     primary_category: Optional[str] = None,
+     year_min: Optional[int] = None,
+     year_max: Optional[int] = None,
+     dynamic: bool = True,
+     use_rerank: bool = True,
+ ):
+     params = RetrievalParams(
+         base_k=base_k,
+         rerank_k=rerank_k,
+         primary_category=primary_category,
+         year_min=year_min,
+         year_max=year_max,
+         dynamic=dynamic,
+         use_rerank=use_rerank,
+     )
+     return RerankRetriever(vectorstore, params)
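To isolate the metadata filtering on its own, a small sketch (both documents are fabricated, and use_rerank=False avoids the cross-encoder download):

from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from src.embeddings import get_embedding_model
from src.retriever import build_advanced_retriever

docs = [
    Document(page_content="Transformers for language modelling.", metadata={"year": 2022, "primary_category": "cs.CL"}),
    Document(page_content="Graph neural networks for chemistry.", metadata={"year": 2020, "primary_category": "cs.LG"}),
]
vs = FAISS.from_documents(docs, get_embedding_model())
retriever = build_advanced_retriever(vs, base_k=2, rerank_k=2, year_min=2021, use_rerank=False)
print([d.metadata["year"] for d in retriever.get_relevant_documents("neural networks")])  # only the 2022 paper passes the year filter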
src/text_processing.py ADDED
@@ -0,0 +1,45 @@
+ import re
+ from typing import Iterable
+
+ # Basic English stopwords (small set to avoid extra dependency); extend if needed
+ BASIC_STOPWORDS = {
+     'the','and','a','an','of','in','to','is','are','for','on','with','that','this','by','from','at','as','it','be','or','we','can','our','their','these','those','using','used'
+ }
+
+ LATEX_EQ_RE = re.compile(r'\$\$.*?\$\$|\$[^$]*\$', re.DOTALL)
+ URL_RE = re.compile(r'https?://\S+|www\.\S+')
+ MULTI_WS_RE = re.compile(r'\s+')
+ INLINE_LATEX_CMD_RE = re.compile(r'\\(?:cite|ref|label|eqref|begin|end|textbf|emph|mathrm|mathbb)\{[^}]*\}')
+
+
+ def remove_latex(text: str) -> str:
+     text = LATEX_EQ_RE.sub(' ', text)
+     text = INLINE_LATEX_CMD_RE.sub(' ', text)
+     return text
+
+
+ def remove_urls(text: str) -> str:
+     return URL_RE.sub(' ', text)
+
+
+ def normalize_whitespace(text: str) -> str:
+     return MULTI_WS_RE.sub(' ', text).strip()
+
+
+ def strip_stopwords(tokens: Iterable[str]) -> str:
+     return ' '.join(t for t in tokens if t not in BASIC_STOPWORDS)
+
+
+ def clean_text(text: str, lowercase: bool = False, remove_stopwords: bool = False) -> str:
+     if not text:
+         return ''
+     t = remove_urls(text)
+     t = remove_latex(t)
+     if lowercase:
+         t = t.lower()
+     # Tokenize very simply on whitespace after basic cleanup
+     t = normalize_whitespace(t)
+     if remove_stopwords:
+         tokens = t.split()
+         t = strip_stopwords(tokens)
+     return t
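A quick look at what clean_text strips, on an invented sentence:

from src.text_processing import clean_text

raw = "We prove $E = mc^2$ holds; see https://example.org and \\cite{einstein1905}."
print(clean_text(raw, lowercase=True, remove_stopwords=True))
# The inline math, the \cite command, the URL and the basic stopwords are all dropped.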
src/vector_store.py ADDED
@@ -0,0 +1,95 @@
+ import os
+ from typing import List, Optional
+ from langchain_community.vectorstores import FAISS
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ try:
+     from semantic_text_splitter import TextSplitter as SemanticTextSplitter  # type: ignore
+     _HAS_SEMANTIC = True
+ except ImportError:  # graceful fallback if package missing
+     _HAS_SEMANTIC = False
+ from langchain_core.documents import Document
+ from .embeddings import get_embedding_model
+ from .config import FAISS_INDEX_PATH
+
+ def _chunk_documents(
+     documents: List[Document],
+     method: str = "recursive",
+     chunk_size: int = 1000,
+     chunk_overlap: int = 120
+ ):
+     if method == "semantic" and _HAS_SEMANTIC:
+         try:
+             # Newer versions expose factory; fallback to direct init
+             if hasattr(SemanticTextSplitter, "from_tiktoken_encoder"):
+                 splitter = SemanticTextSplitter.from_tiktoken_encoder(
+                     chunk_size=chunk_size,
+                     chunk_overlap=chunk_overlap,
+                 )
+             else:  # try simple init signature
+                 splitter = SemanticTextSplitter(
+                     chunk_size=chunk_size,
+                     chunk_overlap=chunk_overlap,
+                 )
+             semantic_chunks: List[Document] = []
+             for d in documents:
+                 try:
+                     parts = splitter.chunks(d.page_content)
+                 except AttributeError:
+                     # Fallback: naive sentence-ish split
+                     parts = d.page_content.split('. ')
+                 for part in parts:
+                     cleaned = part.strip()
+                     if not cleaned:
+                         continue
+                     semantic_chunks.append(
+                         Document(page_content=cleaned, metadata=d.metadata)
+                     )
+             return semantic_chunks
+         except Exception as e:
+             print(f"[semantic chunking fallback] {e}; reverting to recursive splitter.")
+     # fallback / default
+     splitter = RecursiveCharacterTextSplitter(
+         chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=True
+     )
+     return splitter.split_documents(documents)
+
+ def build_or_load_vectorstore(
+     documents: List[Document],
+     force_rebuild: bool = False,
+     chunk_method: str = "recursive",  # or "semantic"
+     chunk_size: int = 1000,
+     chunk_overlap: int = 120
+ ):
+     if os.path.exists(FAISS_INDEX_PATH) and not force_rebuild:
+         print(f"Loading existing FAISS index from {FAISS_INDEX_PATH}...")
+         vectorstore = FAISS.load_local(
+             FAISS_INDEX_PATH,
+             get_embedding_model(),
+             allow_dangerous_deserialization=True
+         )
+         print("Vector store loaded successfully.")
+         return vectorstore
+
+     print("Building FAISS index (force_rebuild=%s, method=%s)..." % (force_rebuild, chunk_method))
+     splits = _chunk_documents(
+         documents,
+         method=chunk_method,
+         chunk_size=chunk_size,
+         chunk_overlap=chunk_overlap
+     )
+     print(f"Split {len(documents)} docs into {len(splits)} chunks (method={chunk_method}).")
+     vectorstore = FAISS.from_documents(splits, get_embedding_model())
+     vectorstore.save_local(FAISS_INDEX_PATH)
+     print(f"Vector store created and saved to {FAISS_INDEX_PATH}")
+     return vectorstore
+
+ def build_filtered_retriever(vectorstore, primary_category: Optional[str] = None, k: int = 3):
+     base = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
+     if not primary_category:
+         return base
+     # Simple wrapper applying post-filtering by metadata; could be replaced by a VectorStore-specific filter if supported
+     def _get_relevant_documents(query):
+         docs = base.get_relevant_documents(query)
+         return [d for d in docs if d.metadata.get("primary_category") == primary_category]
+     base.get_relevant_documents = _get_relevant_documents  # monkey patch
+     return base
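Loading the committed index without rebuilding it might look like the sketch below; it assumes the LFS files under faiss_index/ have been pulled locally and that the saved index matches the embedding model named in src/config.py:

from src.vector_store import build_or_load_vectorstore

vs = build_or_load_vectorstore([], force_rebuild=False)  # documents are ignored when the saved index already exists
print(vs.similarity_search("graph neural networks", k=2)[0].metadata)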
src/streamlit_app.py → streamlit_app.py RENAMED
@@ -1,21 +1,10 @@
  import os
- import torch
  import streamlit as st
  from dotenv import load_dotenv, find_dotenv
-
- from langchain_text_splitters import RecursiveCharacterTextSplitter
- from langchain_community.embeddings import HuggingFaceEmbeddings
- from langchain_community.vectorstores import FAISS
  from langchain.prompts import PromptTemplate
  from langchain.schema.runnable import RunnablePassthrough
  from langchain_core.runnables import RunnableLambda
  from langchain_groq import ChatGroq
- import os
- import streamlit as st
- from dotenv import load_dotenv, find_dotenv
- from langchain.prompts import PromptTemplate
- from langchain.schema.runnable import RunnablePassthrough
- from langchain_groq import ChatGroq

  from src.vector_store import build_or_load_vectorstore
  from src.ingestion import load_data_subset, preprocess_dataframe, df_to_documents