GitHub Actions committed on
Commit eabfc15 · 1 Parent(s): 55300a1

Sync from GitHub 8e4442fbfa496966b830fcde5a3f4fd862922de9

.gitattributes CHANGED
@@ -32,4 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.github/workflows/main.yml ADDED
@@ -0,0 +1,44 @@
+ name: Sync to HF Space
+
+ on:
+   push:
+     branches:
+       - main
+
+ jobs:
+   sync:
+     runs-on: ubuntu-latest
+
+     steps:
+       - name: Checkout repository
+         uses: actions/checkout@v4
+
+       - name: Set up Python
+         uses: actions/setup-python@v5
+         with:
+           python-version: "3.9"
+
+       - name: Install huggingface_hub
+         run: |
+           python -m pip install --upgrade pip
+           pip install huggingface_hub
+
+       - name: Mirror sync to HF Space
+         env:
+           HF_TOKEN: ${{ secrets.HF_TOKEN }}
+         run: |
+           git config --global user.email "[email protected]"
+           git config --global user.name "GitHub Actions"
+
+           # Clone HF Space repo
+           git clone https://huggingface.co/spaces/Wasifjafri/ml-research-assistant hf_space
+
+           # Sync from GitHub repo to HF Space, excluding itself + .git folder
+           rsync -av --delete --exclude '.git/' --exclude 'hf_space/' . hf_space/
+
+           cd hf_space
+           git add .
+           git commit -m "Sync from GitHub $GITHUB_SHA" || echo "No changes to commit"
+           git push https://user:$HF_TOKEN@huggingface.co/spaces/Wasifjafri/ml-research-assistant.git main
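
The job above installs huggingface_hub but then mirrors with plain git and rsync. For reference, a minimal sketch of the same mirror step using the library's HfApi.upload_folder instead (assumes the same HF_TOKEN secret; the ignore patterns mirror the rsync excludes):

# sync_space.py - sketch of an upload_folder-based mirror, assuming HF_TOKEN is set
import os
from huggingface_hub import HfApi

api = HfApi(token=os.environ["HF_TOKEN"])
api.upload_folder(
    folder_path=".",                                   # the GitHub checkout
    repo_id="Wasifjafri/ml-research-assistant",
    repo_type="space",
    commit_message=f"Sync from GitHub {os.environ.get('GITHUB_SHA', '')}",
    ignore_patterns=[".git/*", "hf_space/*"],          # same excludes as the rsync call
)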
.gitignore ADDED
Binary file (1.44 kB)
 
Dockerfile CHANGED
@@ -17,4 +17,5 @@ EXPOSE 8501
 
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
 
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
+ ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
+
README.md CHANGED
@@ -1,4 +1,3 @@
- ---
  title: Ml Research Assistant
  emoji: 🚀
  colorFrom: red
@@ -6,14 +5,9 @@ colorTo: red
  sdk: docker
  app_port: 8501
  tags:
  - streamlit
  pinned: false
  short_description: Chatbot to help in research
- ---
 
- # Welcome to Streamlit!
-
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
-
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
+ # research_paper_assistant_rag_chatbot
+ A RAG pipeline and chatbot to assist in ML research
app.py ADDED
@@ -0,0 +1,3 @@
+ # Entrypoint renamed for Hugging Face Spaces Streamlit detection.
+ # Currently imports the original app content.
+ from streamlit_app import *  # noqa
requirements.txt CHANGED
@@ -1,3 +1,53 @@
- altair
- pandas
- streamlit
+ absl-py==2.3.1
+ accelerate==1.10.1
+ aiohappyeyeballs==2.6.1
+ aiohttp==3.12.15
+ aiosignal==1.4.0
+ altair==5.5.0
+ annotated-types==0.7.0
+ anyio==4.10.0
+ argon2-cffi==25.1.0
+ argon2-cffi-bindings==25.1.0
+ arrow==1.3.0
+ asttokens==3.0.0
+ astunparse==1.6.3
+ async-lru==2.0.5
+ async-timeout==4.0.3
+ attrs==25.3.0
+ # Minimal requirements for the RAG chatbot
+ # Pin only where needed; compatible with Python 3.11 on Windows
+
+ # Core app
+ streamlit==1.49.1
+ python-dotenv==1.1.1
+
+ # LangChain stack (aligned versions)
+ langchain==0.3.27
+ langchain-core==0.3.75
+ langchain-community==0.3.29
+ langchain-text-splitters==0.3.11
+ langchain-groq==0.3.8
+ langchain-huggingface==0.3.1
+ langchain-google-genai==2.1.12
+ langchain-anthropic==0.3.6
+
+ # Vector store and NLP
+ faiss-cpu==1.12.0
+ sentence-transformers==5.1.0
+ transformers==4.56.1
+
+ # Data + utils
+ pandas==2.3.2
+ numpy==1.26.4
+ requests==2.32.5
+ datasets==3.2.0
+ # langchain-huggingface>=0.3.1 requires huggingface-hub>=0.33.4
+ huggingface_hub>=0.33.4
+
+ # Optional semantic splitter (app gracefully falls back if missing)
+ semantic-text-splitter==0.27.0
+
+ # Dataset fetcher (legacy - now using Hugging Face datasets)
+ # kagglehub==0.3.13
+
+
runtime.txt ADDED
@@ -0,0 +1 @@
+ python-3.9.11
src/__init__.py ADDED
File without changes
src/config.py ADDED
@@ -0,0 +1,33 @@
+ import os
+ from dotenv import load_dotenv
+ load_dotenv()
+
+ DATA_PATH = "data"
+ FAISS_INDEX_PATH = "faiss_index"
+
+ EMBEDDING_MODEL = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
+
+ DEVICE = "cuda" if os.environ.get("CUDA_AVAILABLE", "0") == "1" else "cpu"
+
+ GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
+ GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
+ GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", GOOGLE_API_KEY)
+ ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", "")
+
+ # Default chat model identifiers
+ GROQ_MODEL = os.environ.get("GROQ_MODEL", "meta-llama/llama-4-maverick-17b-128e-instruct")
+ GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.5-flash")
+ ANTHROPIC_MODEL = os.environ.get("ANTHROPIC_MODEL", "claude-sonnet-4-20250514")
+
+ # Cross-encoder model for reranking
+ CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-12-v2"
+
+ # Remote FAISS index (Hugging Face dataset repo)
+ # Override via env if needed
+ HF_DATASET_REPO_ID = os.environ.get("HF_DATASET_REPO_ID", "Wasifjafri/research-paper-vdb")
+ HF_DATASET_REPO_TYPE = os.environ.get("HF_DATASET_REPO_TYPE", "dataset")
+ FAISS_INDEX_REMOTE_DIR = os.environ.get("FAISS_INDEX_REMOTE_DIR", "remote_faiss_index")
+ FAISS_INDEX_FILES = (
+     os.environ.get("FAISS_INDEX_FAISS_FILENAME", "index.faiss"),
+     os.environ.get("FAISS_INDEX_META_FILENAME", "index.pkl"),
+ )
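
Every remote-index setting above is read through os.environ.get at import time, so overrides have to be in the environment before src.config is first imported. A minimal sketch (the repo id is a placeholder, not a real repo):

# Sketch: override the index location before the first import of src.config.
import os
os.environ["HF_DATASET_REPO_ID"] = "your-user/your-vdb"        # placeholder repo id
os.environ["FAISS_INDEX_REMOTE_DIR"] = "remote_faiss_index"    # default shown above

from src import config
print(config.HF_DATASET_REPO_ID)  # -> your-user/your-vdb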
src/embeddings.py ADDED
@@ -0,0 +1,8 @@
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from .config import EMBEDDING_MODEL, DEVICE
+
+ def get_embedding_model():
+     return HuggingFaceEmbeddings(
+         model_name=EMBEDDING_MODEL,
+         model_kwargs={"device": DEVICE}
+     )
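
A minimal usage sketch; the first call downloads multi-qa-MiniLM-L6-cos-v1 into the local cache:

# Sketch: embed a single query with the configured model.
from src.embeddings import get_embedding_model

emb = get_embedding_model()
vec = emb.embed_query("attention mechanisms in transformers")
print(len(vec))  # multi-qa-MiniLM-L6-cos-v1 produces 384-dimensional vectors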
src/ingestion.py ADDED
@@ -0,0 +1,179 @@
+ """Data loading, cleaning and preprocessing for ArXiv dataset."""
+
+ import os
+ import json
+ import gzip
+ import pandas as pd
+ from langchain_core.documents import Document
+ from .config import DATA_PATH
+ from .text_processing import clean_text
+
+ def load_hf_dataset(num_records=50000, dataset_name="CShorten/ML-ArXiv-Papers"):
+     """Load ArXiv papers from Hugging Face dataset.
+
+     Args:
+         num_records: Number of records to load
+         dataset_name: Hugging Face dataset identifier
+
+     Returns:
+         pandas DataFrame with the papers
+     """
+     try:
+         from datasets import load_dataset
+
+         print(f"Loading {num_records} records from {dataset_name}...")
+
+         # Load dataset from Hugging Face
+         dataset = load_dataset(dataset_name, split="train", streaming=False)
+
+         # Convert to pandas DataFrame
+         if num_records and num_records < len(dataset):
+             df = dataset.select(range(num_records)).to_pandas()
+         else:
+             df = dataset.to_pandas()
+
+         print(f"Loaded {len(df)} records from Hugging Face dataset")
+         return df
+
+     except ImportError:
+         raise ImportError("Please install the datasets library: pip install datasets")
+     except Exception as e:
+         raise ValueError(f"Failed to load Hugging Face dataset: {e}")
+
+ def _open_file(file_path):
+     """Open file with appropriate mode and encoding."""
+     if file_path.endswith('.gz'):
+         return gzip.open(file_path, 'rt', encoding='utf-8-sig')
+     return open(file_path, 'r', encoding='utf-8-sig')
+
+ def _parse_json_line(line):
+     """Parse a single JSON line, return None if invalid."""
+     s = line.strip()
+     if not s:
+         return None
+     try:
+         return json.loads(s)
+     except json.JSONDecodeError:
+         return None
+
+ def _try_full_json_array(file_path, num_records):
+     """Try to load the file as a full JSON array."""
+     try:
+         with _open_file(file_path) as f:
+             data = json.load(f)
+         if not isinstance(data, list):
+             raise ValueError("Top-level JSON is not a list.")
+         return pd.DataFrame(data[:num_records])
+     except Exception as e:
+         raise ValueError(
+             "Failed to parse dataset. Expected JSON Lines or a JSON array."
+         ) from e
+
+ def _parse_lines(file_path, num_records):
+     """Parse lines from file as JSONL, fallback to JSON array if needed."""
+     records = []
+     with _open_file(file_path) as f:
+         for line in f:
+             if len(records) >= num_records:
+                 break
+             record = _parse_json_line(line)
+             if record is not None:
+                 records.append(record)
+             elif not records:
+                 # First non-empty line failed, try full-file JSON array
+                 return _try_full_json_array(file_path, num_records)
+     return records
+
+ def load_data_subset(file_path, num_records=50000):
+     """Load up to num_records from a JSON Lines file.
+     - Skips empty/BOM-prefixed lines.
+     - Uses UTF-8 with BOM tolerance.
+     - Raises a clear error if file is empty or unreadable.
+     """
+     if not os.path.exists(file_path) or os.path.getsize(file_path) == 0:
+         raise FileNotFoundError(f"Dataset not found or empty: {file_path}")
+
+     try:
+         records = _parse_lines(file_path, num_records)
+     except UnicodeDecodeError:
+         # Retry with default encoding if needed
+         records = []
+         with open(file_path, 'r') as f:
+             for line in f:
+                 if len(records) >= num_records:
+                     break
+                 record = _parse_json_line(line)
+                 if record is not None:
+                     records.append(record)
+
+     if isinstance(records, pd.DataFrame):
+         return records
+
+     if not records:
+         raise ValueError(
+             "No valid records were parsed from the dataset. Ensure the file is JSONL or a JSON array."
+         )
+     return pd.DataFrame(records)
+
+ def preprocess_dataframe(df: pd.DataFrame) -> pd.DataFrame:
+     """Preprocess the dataframe from Hugging Face or local file."""
+     # Handle different date column names
+     date_col = None
+     if 'update_date' in df.columns:
+         date_col = 'update_date'
+     elif 'updated' in df.columns:
+         date_col = 'updated'
+     elif 'published' in df.columns:
+         date_col = 'published'
+
+     if date_col:
+         df[date_col] = pd.to_datetime(df[date_col], errors='coerce')
+         df['year'] = df[date_col].dt.year
+     elif 'year' not in df.columns:
+         # If no date column exists, set year to None
+         df['year'] = None
+
+     # Ensure required columns exist
+     if 'abstract' in df.columns:
+         df = df.dropna(subset=['abstract'])
+         df = df[df['abstract'].str.strip() != '']
+
+     return df
+
+ def df_to_documents(
+     df: pd.DataFrame,
+     lowercase: bool = False,
+     remove_stopwords: bool = False
+ ):
+     """Convert dataframe to LangChain documents."""
+     documents = []
+     for _, row in df.iterrows():
+         # Get title and abstract
+         title = str(row.get('title', ''))
+         abstract = str(row.get('abstract', ''))
+
+         title_clean = clean_text(title, lowercase=lowercase, remove_stopwords=remove_stopwords)
+         abstract_clean = clean_text(abstract, lowercase=lowercase, remove_stopwords=remove_stopwords)
+         page_content = f"Title: {title_clean}\n\nAbstract: {abstract_clean}"
+
+         # Handle categories - can be string or list
+         categories_raw = row.get('categories', 'N/A') or 'N/A'
+         if isinstance(categories_raw, list):
+             categories_str = ' '.join(categories_raw) if categories_raw else 'N/A'
+             primary_category = categories_raw[0] if categories_raw else 'N/A'
+         else:
+             categories_str = str(categories_raw)
+             primary_category = categories_str.split()[0] if categories_str != 'N/A' else 'N/A'
+
+         # Build metadata
+         metadata = {
+             "id": row.get('id', 'N/A'),
+             "title": title,  # Keep original title in metadata
+             "authors": row.get('authors', 'N/A'),
+             "year": int(row.get('year')) if not pd.isna(row.get('year')) else None,
+             "categories": categories_str,
+             "primary_category": primary_category
+         }
+
+         documents.append(Document(page_content=page_content, metadata=metadata))
+     return documents
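
Taken together the module is a load → preprocess → Document pipeline; a minimal sketch on a small subset (the record count is an example value):

# Sketch: build LangChain documents from the default HF dataset.
from src.ingestion import load_hf_dataset, preprocess_dataframe, df_to_documents

df = load_hf_dataset(num_records=1000)   # CShorten/ML-ArXiv-Papers by default
df = preprocess_dataframe(df)            # derives 'year', drops empty abstracts
docs = df_to_documents(df)
print(docs[0].metadata["title"])
print(docs[0].page_content[:120])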
src/retriever.py ADDED
@@ -0,0 +1,108 @@
+ from typing import Optional, List
+ from dataclasses import dataclass
+ from .config import CROSS_ENCODER_MODEL
+
+ try:
+     from sentence_transformers import CrossEncoder
+     _HAS_CE = True
+ except ImportError:
+     _HAS_CE = False
+
+
+ def get_retriever(vectorstore, k=3):
+     return vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
+
+
+ @dataclass
+ class RetrievalParams:
+     base_k: int = 8        # initial fetch size for reranking
+     rerank_k: int = 4      # final number after rerank
+     max_k: int = 20        # max docs to fetch for long/ambiguous queries
+     min_k: int = 3         # minimum docs
+     dynamic: bool = True   # enable dynamic k logic
+     year_min: Optional[int] = None
+     year_max: Optional[int] = None
+     primary_category: Optional[str] = None
+     use_rerank: bool = True
+
+
+ class RerankRetriever:
+     def __init__(self, vectorstore, params: RetrievalParams):
+         self.vs = vectorstore
+         self.params = params
+         self.cross_encoder = None
+         if params.use_rerank and _HAS_CE:
+             self.cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL)
+         elif params.use_rerank:
+             print("CrossEncoder not available; install sentence-transformers to enable reranking.")
+
+     def _compute_dynamic_k(self, query: str) -> int:
+         if not self.params.dynamic:
+             return self.params.base_k
+         length = len(query.split())
+         if length <= 4:  # very short, broaden
+             return min(self.params.base_k + 6, self.params.max_k)
+         if length <= 12:
+             return self.params.base_k
+         return min(self.params.base_k + 4, self.params.max_k)
+
+     def _metadata_filter(self, docs):
+         p = self.params
+         filtered = []
+         for d in docs:
+             y = d.metadata.get("year")
+             if p.year_min is not None and (y is None or y < p.year_min):
+                 continue
+             if p.year_max is not None and (y is None or y > p.year_max):
+                 continue
+             if p.primary_category and d.metadata.get("primary_category") != p.primary_category:
+                 continue
+             filtered.append(d)
+         return filtered
+
+     def get_relevant_documents(self, query: str):
+         fetch_k = self._compute_dynamic_k(query)
+         base_retriever = self.vs.as_retriever(search_type="similarity", search_kwargs={"k": fetch_k})
+         try:
+             # New API
+             docs = base_retriever.invoke(query)
+         except Exception:
+             # Backward compatibility
+             docs = base_retriever.get_relevant_documents(query)
+         docs = self._metadata_filter(docs)
+         if self.cross_encoder and docs:
+             pairs = [(query, d.page_content) for d in docs]
+             scores = self.cross_encoder.predict(pairs)
+             ranked = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
+             final_k = min(self.params.rerank_k, len(ranked))
+             docs = [d for d, _ in ranked[:final_k]]
+         else:
+             # fallback: truncate
+             docs = docs[: self.params.rerank_k]
+         return docs
+
+     # For LangChain compatibility
+     def invoke(self, query: str):
+         return self.get_relevant_documents(query)
+
+
+ def build_advanced_retriever(
+     vectorstore,
+     base_k: int = 12,
+     rerank_k: int = 5,
+     primary_category: Optional[str] = None,
+     year_min: Optional[int] = None,
+     year_max: Optional[int] = None,
+     dynamic: bool = True,
+     use_rerank: bool = True,
+ ):
+     params = RetrievalParams(
+         base_k=base_k,
+         rerank_k=rerank_k,
+         primary_category=primary_category,
+         year_min=year_min,
+         year_max=year_max,
+         dynamic=dynamic,
+         use_rerank=use_rerank,
+     )
+     return RerankRetriever(vectorstore, params)
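
A minimal sketch wiring the reranking retriever to a loaded store (the query and year filter are example values; reranking requires sentence-transformers):

# Sketch: dynamic-k retrieval with cross-encoder reranking and a year filter.
from src.vector_store import build_or_load_vectorstore
from src.retriever import build_advanced_retriever

vectorstore = build_or_load_vectorstore([])      # load-only path, no rebuild
retriever = build_advanced_retriever(
    vectorstore,
    base_k=12,
    rerank_k=5,
    year_min=2020,                               # example metadata filter
)
for d in retriever.invoke("contrastive learning for vision"):
    print(d.metadata.get("year"), d.metadata.get("title"))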
src/streamlit_app.py DELETED
@@ -1,40 +0,0 @@
- import altair as alt
- import numpy as np
- import pandas as pd
- import streamlit as st
-
- """
- # Welcome to Streamlit!
-
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
-
- In the meantime, below is an example of what you can do with just a few lines of code:
- """
-
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
- indices = np.linspace(0, 1, num_points)
- theta = 2 * np.pi * num_turns * indices
- radius = indices
-
- x = radius * np.cos(theta)
- y = radius * np.sin(theta)
-
- df = pd.DataFrame({
-     "x": x,
-     "y": y,
-     "idx": indices,
-     "rand": np.random.randn(num_points),
- })
-
- st.altair_chart(alt.Chart(df, height=700, width=700)
-     .mark_point(filled=True)
-     .encode(
-         x=alt.X("x", axis=None),
-         y=alt.Y("y", axis=None),
-         color=alt.Color("idx", legend=None, scale=alt.Scale()),
-         size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-     ))
src/text_processing.py ADDED
@@ -0,0 +1,45 @@
+ import re
+ from typing import Iterable
+
+ # Basic English stopwords (small set to avoid extra dependency); extend if needed
+ BASIC_STOPWORDS = {
+     'the','and','a','an','of','in','to','is','are','for','on','with','that','this','by','from','at','as','it','be','or','we','can','our','their','these','those','using','used'
+ }
+
+ LATEX_EQ_RE = re.compile(r'\$\$.*?\$\$|\$[^$]*\$', re.DOTALL)
+ URL_RE = re.compile(r'https?://\S+|www\.\S+')
+ MULTI_WS_RE = re.compile(r'\s+')
+ INLINE_LATEX_CMD_RE = re.compile(r'\\(?:cite|ref|label|eqref|begin|end|textbf|emph|mathrm|mathbb)\{[^}]*\}')
+
+
+ def remove_latex(text: str) -> str:
+     text = LATEX_EQ_RE.sub(' ', text)
+     text = INLINE_LATEX_CMD_RE.sub(' ', text)
+     return text
+
+
+ def remove_urls(text: str) -> str:
+     return URL_RE.sub(' ', text)
+
+
+ def normalize_whitespace(text: str) -> str:
+     return MULTI_WS_RE.sub(' ', text).strip()
+
+
+ def strip_stopwords(tokens: Iterable[str]) -> str:
+     return ' '.join(t for t in tokens if t not in BASIC_STOPWORDS)
+
+
+ def clean_text(text: str, lowercase: bool = False, remove_stopwords: bool = False) -> str:
+     if not text:
+         return ''
+     t = remove_urls(text)
+     t = remove_latex(t)
+     if lowercase:
+         t = t.lower()
+     # Tokenize very simply on whitespace after basic cleanup
+     t = normalize_whitespace(t)
+     if remove_stopwords:
+         tokens = t.split()
+         t = strip_stopwords(tokens)
+     return t
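
A minimal sketch of clean_text on a made-up string, showing the URL/LaTeX removal and the optional lowercase + stopword pass:

# Sketch: trace clean_text through its two option flags.
from src.text_processing import clean_text

raw = "We study $f(x)=x^2$ models; see https://example.com for code."
print(clean_text(raw))
# -> "We study models; see for code."
print(clean_text(raw, lowercase=True, remove_stopwords=True))
# -> "study models; see code."   ('we' and 'for' are in BASIC_STOPWORDS)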
src/vector_store.py ADDED
@@ -0,0 +1,220 @@
+ import os
+ from typing import List, Optional
+ from langchain_community.vectorstores import FAISS
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ try:
+     from semantic_text_splitter import TextSplitter as SemanticTextSplitter  # type: ignore
+     _HAS_SEMANTIC = True
+ except ImportError:  # graceful fallback if package missing
+     _HAS_SEMANTIC = False
+ from langchain_core.documents import Document
+ from .embeddings import get_embedding_model
+ from .config import (
+     FAISS_INDEX_PATH,
+     HF_DATASET_REPO_ID,
+     HF_DATASET_REPO_TYPE,
+     FAISS_INDEX_REMOTE_DIR,
+     FAISS_INDEX_FILES,
+ )
+ from pathlib import Path
+ from typing import Tuple
+ import shutil
+
+ def _ensure_local_faiss_from_hub(index_dir: str) -> bool:
+     """Download FAISS index files from Hugging Face Hub dataset repo if missing.
+
+     Returns True if files are present (downloaded or already existed), False otherwise.
+     """
+     target = Path(index_dir)
+     target.mkdir(parents=True, exist_ok=True)
+     faiss_name, pkl_name = FAISS_INDEX_FILES
+     faiss_path = target / faiss_name
+     pkl_path = target / pkl_name
+     if faiss_path.exists() and pkl_path.exists():
+         return True
+     try:
+         from huggingface_hub import hf_hub_download, list_repo_files
+
+         def _download_pair(faiss_fname: str, meta_fname: str, remote_subfolder: Optional[str] = None) -> bool:
+             try:
+                 # Download FAISS file
+                 local_faiss = hf_hub_download(
+                     repo_id=HF_DATASET_REPO_ID,
+                     repo_type=HF_DATASET_REPO_TYPE,
+                     filename=faiss_fname,
+                     subfolder=remote_subfolder or FAISS_INDEX_REMOTE_DIR or None,
+                     local_dir=str(target),
+                     local_dir_use_symlinks=False,
+                 )
+                 # Download metadata file
+                 local_meta = hf_hub_download(
+                     repo_id=HF_DATASET_REPO_ID,
+                     repo_type=HF_DATASET_REPO_TYPE,
+                     filename=meta_fname,
+                     subfolder=remote_subfolder or FAISS_INDEX_REMOTE_DIR or None,
+                     local_dir=str(target),
+                     local_dir_use_symlinks=False,
+                 )
+                 # Normalize file names in target so FAISS.load_local can find them
+                 try:
+                     dst_faiss = target / faiss_name
+                     dst_meta = target / pkl_name
+                     if Path(local_faiss) != dst_faiss:
+                         shutil.copy2(local_faiss, dst_faiss)
+                     if Path(local_meta) != dst_meta:
+                         shutil.copy2(local_meta, dst_meta)
+                 except Exception as copy_err:
+                     print(f"[FAISS download] Copy to expected names failed: {copy_err}")
+                 return (target / faiss_name).exists() and (target / pkl_name).exists()
+             except Exception:
+                 return False
+
+         # First try configured names
+         if _download_pair(faiss_name, pkl_name, FAISS_INDEX_REMOTE_DIR):
+             return True
+
+         # Fallback: auto-discover by listing repository files
+         try:
+             files = list_repo_files(repo_id=HF_DATASET_REPO_ID, repo_type=HF_DATASET_REPO_TYPE)
+         except Exception as e:
+             print(f"[FAISS download] list_repo_files failed for {HF_DATASET_REPO_ID}: {e}")
+             files = []
+
+         def _in_remote_dir(path: str) -> bool:
+             if not FAISS_INDEX_REMOTE_DIR:
+                 return True
+             return path.startswith(f"{FAISS_INDEX_REMOTE_DIR}/") or path == FAISS_INDEX_REMOTE_DIR
+
+         faiss_candidates = [f for f in files if f.lower().endswith('.faiss') and _in_remote_dir(f)]
+         meta_candidates = [
+             f for f in files if (f.lower().endswith('.pkl') or f.lower().endswith('.pickle')) and _in_remote_dir(f)
+         ]
+         if faiss_candidates and meta_candidates:
+             # Take the first candidates
+             cand_faiss_path = faiss_candidates[0]
+             cand_meta_path = meta_candidates[0]
+             # Split into subfolder + filename
+             def _split_path(p: str) -> Tuple[Optional[str], str]:
+                 if '/' in p:
+                     d, b = p.rsplit('/', 1)
+                     return d, b
+                 return None, p
+             sub_faiss, base_faiss = _split_path(cand_faiss_path)
+             sub_meta, base_meta = _split_path(cand_meta_path)
+             # Prefer the shared subfolder if both live under the same dir
+             shared_sub = sub_faiss if sub_faiss == sub_meta else sub_faiss or sub_meta
+             if _download_pair(base_faiss, base_meta, shared_sub):
+                 return True
+
+         print(
+             f"[FAISS download] Could not find/download FAISS pair in {HF_DATASET_REPO_ID}. "
+             f"Looked for {faiss_name} and {pkl_name}, candidates: {faiss_candidates} / {meta_candidates}"
+         )
+         return False
+     except Exception as e:
+         print(f"[FAISS download] Could not fetch from Hub ({HF_DATASET_REPO_ID}): {e}")
+         return False
+
+ def _semantic_chunk_documents(
+     documents: List[Document],
+     chunk_size: int,
+     chunk_overlap: int
+ ) -> List[Document]:
+     # Newer versions expose factory; fallback to direct init
+     if hasattr(SemanticTextSplitter, "from_tiktoken_encoder"):
+         splitter = SemanticTextSplitter.from_tiktoken_encoder(
+             chunk_size=chunk_size,
+             chunk_overlap=chunk_overlap,
+         )
+     else:  # try simple init signature
+         splitter = SemanticTextSplitter(
+             chunk_size=chunk_size,
+             chunk_overlap=chunk_overlap,
+         )
+     semantic_chunks: List[Document] = []
+     for d in documents:
+         try:
+             parts = splitter.chunks(d.page_content)
+         except AttributeError:
+             # Fallback: naive sentence-ish split
+             parts = d.page_content.split('. ')
+         for part in parts:
+             cleaned = part.strip()
+             if cleaned:
+                 semantic_chunks.append(
+                     Document(page_content=cleaned, metadata=d.metadata)
+                 )
+     return semantic_chunks
+
+ def _chunk_documents(
+     documents: List[Document],
+     method: str = "recursive",
+     chunk_size: int = 1000,
+     chunk_overlap: int = 120
+ ):
+     if method == "semantic" and _HAS_SEMANTIC:
+         try:
+             return _semantic_chunk_documents(documents, chunk_size, chunk_overlap)
+         except Exception as e:
+             print(f"[semantic chunking fallback] {e}; reverting to recursive splitter.")
+     # fallback / default
+     splitter = RecursiveCharacterTextSplitter(
+         chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=True
+     )
+     return splitter.split_documents(documents)
+
+ def build_or_load_vectorstore(
+     documents: List[Document],
+     force_rebuild: bool = False,
+     chunk_method: str = "recursive",  # or "semantic"
+     chunk_size: int = 1000,
+     chunk_overlap: int = 120
+ ):
+     # Ensure local index exists (download from Hub if needed)
+     if not os.path.exists(FAISS_INDEX_PATH):
+         fetched = _ensure_local_faiss_from_hub(FAISS_INDEX_PATH)
+         if fetched:
+             print(f"Downloaded FAISS index from Hub into {FAISS_INDEX_PATH}")
+
+     if os.path.exists(FAISS_INDEX_PATH) and not force_rebuild:
+         print(f"Loading existing FAISS index from {FAISS_INDEX_PATH}...")
+         try:
+             vectorstore = FAISS.load_local(
+                 FAISS_INDEX_PATH,
+                 get_embedding_model(),
+                 allow_dangerous_deserialization=True
+             )
+             print("Vector store loaded successfully.")
+             return vectorstore
+         except Exception as e:
+             print(f"Failed to load FAISS index due to: {e}")
+             if not documents:
+                 raise RuntimeError(
+                     "Existing FAISS index is incompatible with current libraries and no documents were "
+                     "provided to rebuild it. Delete 'faiss_index' and rebuild, or pass documents to rebuild."
+                 ) from e
+             print("Rebuilding FAISS index from provided documents...")
+
+     print("Building FAISS index (force_rebuild=%s, method=%s)..." % (force_rebuild, chunk_method))
+     splits = _chunk_documents(
+         documents,
+         method=chunk_method,
+         chunk_size=chunk_size,
+         chunk_overlap=chunk_overlap
+     )
+     print(f"Split {len(documents)} docs into {len(splits)} chunks (method={chunk_method}).")
+     vectorstore = FAISS.from_documents(splits, get_embedding_model())
+     vectorstore.save_local(FAISS_INDEX_PATH)
+     print(f"Vector store created and saved to {FAISS_INDEX_PATH}")
+     return vectorstore
+
+ def build_filtered_retriever(vectorstore, primary_category: Optional[str] = None, k: int = 3):
+     base = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
+     if not primary_category:
+         return base
+     # Simple wrapper applying post-filtering by metadata; could be replaced by a VectorStore-specific filter if supported
+     def _get_relevant_documents(query):
+         docs = base.get_relevant_documents(query)
+         return [d for d in docs if d.metadata.get("primary_category") == primary_category]
+     base.get_relevant_documents = _get_relevant_documents  # monkey patch
+     return base
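
A minimal usage sketch of the load path (the query is an example; similarity_search is the standard LangChain FAISS call):

# Sketch: fetch the packaged index from the Hub if absent, then query it.
from src.vector_store import build_or_load_vectorstore

vs = build_or_load_vectorstore([], force_rebuild=False)
hits = vs.similarity_search("graph neural networks", k=3)
for d in hits:
    print(d.metadata.get("year"), d.metadata.get("title"))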
streamlit_app.py ADDED
@@ -0,0 +1,922 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ # Ensure Streamlit and ML caches write to a writable location (e.g., on HF Spaces)
4
+ os.environ["HOME"] = "/tmp"
5
+ os.environ["STREAMLIT_BROWSER_GATHER_USAGE_STATS"] = "false"
6
+ os.environ["STREAMLIT_GLOBAL_DATA_DIR"] = "/tmp/.streamlit"
7
+ os.environ["XDG_CACHE_HOME"] = "/tmp/.cache"
8
+ os.environ["HF_HOME"] = "/tmp/hf"
9
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"
10
+ os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/hf/sentence-transformers"
11
+ os.environ["TORCH_HOME"] = "/tmp/torch"
12
+
13
+ # Create the cache directories
14
+ for _d in [
15
+ os.environ["XDG_CACHE_HOME"],
16
+ os.environ["HF_HOME"],
17
+ os.environ["TRANSFORMERS_CACHE"],
18
+ os.environ["SENTENCE_TRANSFORMERS_HOME"],
19
+ os.environ["TORCH_HOME"],
20
+ os.environ.get("STREAMLIT_GLOBAL_DATA_DIR", "/tmp/.streamlit"),
21
+ ]:
22
+ try:
23
+ os.makedirs(_d, exist_ok=True)
24
+ except Exception:
25
+ pass
26
+
27
+ import streamlit as st
28
+ from dotenv import load_dotenv, find_dotenv
29
+ from langchain.prompts import PromptTemplate
30
+ from langchain.schema.runnable import RunnablePassthrough
31
+ from langchain_core.runnables import RunnableLambda
32
+ from langchain_groq import ChatGroq
33
+ import time
34
+
35
+ from src.vector_store import build_or_load_vectorstore
36
+ from src.ingestion import load_data_subset, preprocess_dataframe, df_to_documents, load_hf_dataset
37
+ from src.retriever import build_advanced_retriever
38
+ from src.config import DATA_PATH, FAISS_INDEX_PATH, GROQ_API_KEY, GEMINI_API_KEY, ANTHROPIC_API_KEY, GROQ_MODEL, GEMINI_MODEL, ANTHROPIC_MODEL
39
+
40
+ load_dotenv(find_dotenv())
41
+
42
+ # Initialize global vectorstore reference to avoid NameError before it is set
43
+ vectorstore = None
44
+
45
+ # PAGE CONFIG - Must be first Streamlit command
46
+ st.set_page_config(
47
+ page_title="Research Assistant",
48
+ page_icon="πŸ€–",
49
+ layout="wide",
50
+ initial_sidebar_state="expanded" # Start with sidebar expanded
51
+ )
52
+
53
+ # ENHANCED CUSTOM CSS - ChatGPT-like styling
54
+ st.markdown("""
55
+ <style>
56
+ /* Hide Streamlit branding */
57
+ #MainMenu {visibility: hidden;}
58
+ footer {visibility: hidden;}
59
+
60
+ /* Make sure header is visible for sidebar toggle */
61
+ header {visibility: visible !important;}
62
+
63
+ /* Style the sidebar toggle button to be more visible */
64
+ [data-testid="collapsedControl"] {
65
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
66
+ border-radius: 0 8px 8px 0 !important;
67
+ padding: 8px !important;
68
+ margin-top: 60px !important;
69
+ box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4) !important;
70
+ }
71
+
72
+ [data-testid="collapsedControl"]:hover {
73
+ box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
74
+ transform: translateX(2px);
75
+ }
76
+
77
+ /* Overall app styling */
78
+ .stApp {
79
+ background: linear-gradient(180deg, #0f1419 0%, #1a1f2e 100%);
80
+ }
81
+
82
+ /* Main chat container */
83
+ .main .block-container {
84
+ padding-top: 2rem;
85
+ padding-bottom: 2rem;
86
+ max-width: 900px;
87
+ margin: 0 auto;
88
+ }
89
+
90
+ /* Chat input styling - Fixed at bottom like ChatGPT */
91
+ .stChatInputContainer {
92
+ background: transparent;
93
+ border: none;
94
+ padding: 1rem 0;
95
+ }
96
+
97
+ .stChatInput > div {
98
+ background: rgba(255, 255, 255, 0.05);
99
+ border: 1px solid rgba(255, 255, 255, 0.1);
100
+ border-radius: 24px;
101
+ padding: 12px 20px;
102
+ backdrop-filter: blur(10px);
103
+ transition: all 0.3s ease;
104
+ }
105
+
106
+ .stChatInput > div:hover {
107
+ background: rgba(255, 255, 255, 0.08);
108
+ border-color: rgba(255, 255, 255, 0.2);
109
+ box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
110
+ }
111
+
112
+ .stChatInput > div:focus-within {
113
+ background: rgba(255, 255, 255, 0.1);
114
+ border-color: #10a37f;
115
+ box-shadow: 0 0 0 3px rgba(16, 163, 127, 0.1);
116
+ }
117
+
118
+ /* User messages - Right aligned with gradient */
119
+ [data-testid="stChatMessage"]:has([data-testid*="user"]) {
120
+ background: transparent;
121
+ justify-content: flex-end;
122
+ }
123
+
124
+ [data-testid="stChatMessage"]:has([data-testid*="user"]) [data-testid="stChatMessageContent"] {
125
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
126
+ border-radius: 18px;
127
+ padding: 14px 18px;
128
+ margin-left: auto;
129
+ max-width: 75%;
130
+ color: white;
131
+ box-shadow: 0 4px 12px rgba(102, 126, 234, 0.3);
132
+ }
133
+
134
+ /* Bot messages - Left aligned with subtle styling */
135
+ [data-testid="stChatMessage"]:not(:has([data-testid*="user"])) {
136
+ background: transparent;
137
+ justify-content: flex-start;
138
+ }
139
+
140
+ [data-testid="stChatMessage"]:not(:has([data-testid*="user"])) [data-testid="stChatMessageContent"] {
141
+ background: rgba(255, 255, 255, 0.03);
142
+ border: 1px solid rgba(255, 255, 255, 0.08);
143
+ border-radius: 18px;
144
+ padding: 14px 18px;
145
+ margin-right: auto;
146
+ max-width: 85%;
147
+ color: #e8e8e8;
148
+ box-shadow: 0 2px 8px rgba(0, 0, 0, 0.2);
149
+ backdrop-filter: blur(10px);
150
+ }
151
+
152
+ /* Avatar styling */
153
+ [data-testid="stChatMessage"] [data-testid="stAvatar"] {
154
+ width: 36px;
155
+ height: 36px;
156
+ border-radius: 50%;
157
+ border: 2px solid rgba(255, 255, 255, 0.1);
158
+ }
159
+
160
+ /* User avatar - gradient border */
161
+ [data-testid="stChatMessage"]:has([data-testid*="user"]) [data-testid="stAvatar"] {
162
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
163
+ border: 2px solid transparent;
164
+ box-shadow: 0 2px 8px rgba(102, 126, 234, 0.4);
165
+ }
166
+
167
+ /* Bot avatar - themed */
168
+ [data-testid="stChatMessage"]:not(:has([data-testid*="user"])) [data-testid="stAvatar"] {
169
+ background: linear-gradient(135deg, #10a37f 0%, #0d8a6a 100%);
170
+ border: 2px solid rgba(16, 163, 127, 0.3);
171
+ box-shadow: 0 2px 8px rgba(16, 163, 127, 0.3);
172
+ }
173
+
174
+ /* Sidebar styling */
175
+ [data-testid="stSidebar"] {
176
+ background: rgba(15, 20, 25, 0.95);
177
+ border-right: 1px solid rgba(255, 255, 255, 0.08);
178
+ backdrop-filter: blur(20px);
179
+ }
180
+
181
+ [data-testid="stSidebar"] .stButton button {
182
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
183
+ border: none;
184
+ border-radius: 12px;
185
+ color: white;
186
+ padding: 10px 20px;
187
+ font-weight: 600;
188
+ transition: all 0.3s ease;
189
+ }
190
+
191
+ [data-testid="stSidebar"] .stButton button:hover {
192
+ transform: translateY(-2px);
193
+ box-shadow: 0 6px 20px rgba(102, 126, 234, 0.4);
194
+ }
195
+
196
+ /* Expander styling */
197
+ .streamlit-expanderHeader {
198
+ background: rgba(255, 255, 255, 0.03);
199
+ border-radius: 12px;
200
+ border: 1px solid rgba(255, 255, 255, 0.08);
201
+ color: #b4b4b4;
202
+ padding: 12px 16px;
203
+ transition: all 0.3s ease;
204
+ }
205
+
206
+ .streamlit-expanderHeader:hover {
207
+ background: rgba(255, 255, 255, 0.06);
208
+ border-color: rgba(255, 255, 255, 0.15);
209
+ }
210
+
211
+ .streamlit-expanderContent {
212
+ background: rgba(255, 255, 255, 0.02);
213
+ border: 1px solid rgba(255, 255, 255, 0.05);
214
+ border-top: none;
215
+ border-radius: 0 0 12px 12px;
216
+ }
217
+
218
+ /* Divider styling */
219
+ hr {
220
+ border: none;
221
+ height: 1px;
222
+ background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.1), transparent);
223
+ margin: 2rem 0;
224
+ }
225
+
226
+ /* Info boxes */
227
+ .stAlert {
228
+ background: rgba(16, 163, 127, 0.1);
229
+ border: 1px solid rgba(16, 163, 127, 0.3);
230
+ border-radius: 12px;
231
+ color: #a8e6cf;
232
+ }
233
+
234
+ /* Scrollbar styling */
235
+ ::-webkit-scrollbar {
236
+ width: 8px;
237
+ height: 8px;
238
+ }
239
+
240
+ ::-webkit-scrollbar-track {
241
+ background: rgba(255, 255, 255, 0.02);
242
+ }
243
+
244
+ ::-webkit-scrollbar-thumb {
245
+ background: rgba(255, 255, 255, 0.15);
246
+ border-radius: 10px;
247
+ }
248
+
249
+ ::-webkit-scrollbar-thumb:hover {
250
+ background: rgba(255, 255, 255, 0.25);
251
+ }
252
+
253
+ /* Typography improvements */
254
+ h1, h2, h3 {
255
+ color: #f0f0f0;
256
+ font-weight: 600;
257
+ }
258
+
259
+ p {
260
+ line-height: 1.7;
261
+ color: #d4d4d4;
262
+ }
263
+
264
+ /* Slider styling */
265
+ .stSlider > div > div > div > div {
266
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
267
+ }
268
+
269
+ /* Checkbox styling */
270
+ .stCheckbox > label > div[data-testid="stMarkdownContainer"] > p {
271
+ color: #d4d4d4;
272
+ }
273
+
274
+ /* Thinking animation */
275
+ @keyframes pulse {
276
+ 0%, 100% { opacity: 0.6; }
277
+ 50% { opacity: 1; }
278
+ }
279
+
280
+ .thinking {
281
+ animation: pulse 1.5s ease-in-out infinite;
282
+ color: #10a37f;
283
+ font-style: italic;
284
+ }
285
+
286
+ /* Welcome message styling */
287
+ .welcome-message {
288
+ background: linear-gradient(135deg, rgba(16, 163, 127, 0.1) 0%, rgba(102, 126, 234, 0.1) 100%);
289
+ border: 1px solid rgba(16, 163, 127, 0.3);
290
+ border-radius: 16px;
291
+ padding: 24px;
292
+ margin: 20px 0;
293
+ text-align: center;
294
+ box-shadow: 0 4px 16px rgba(16, 163, 127, 0.1);
295
+ }
296
+
297
+ .welcome-message h2 {
298
+ background: linear-gradient(135deg, #10a37f 0%, #667eea 100%);
299
+ -webkit-background-clip: text;
300
+ -webkit-text-fill-color: transparent;
301
+ margin-bottom: 12px;
302
+ }
303
+
304
+ /* Suggestion chips */
305
+ .suggestion-chip {
306
+ display: inline-block;
307
+ background: rgba(255, 255, 255, 0.05);
308
+ border: 1px solid rgba(255, 255, 255, 0.1);
309
+ border-radius: 20px;
310
+ padding: 8px 16px;
311
+ margin: 6px;
312
+ color: #b4b4b4;
313
+ cursor: pointer;
314
+ transition: all 0.3s ease;
315
+ }
316
+
317
+ .suggestion-chip:hover {
318
+ background: rgba(16, 163, 127, 0.15);
319
+ border-color: rgba(16, 163, 127, 0.4);
320
+ color: #10a37f;
321
+ transform: translateY(-2px);
322
+ }
323
+ </style>
324
+ """, unsafe_allow_html=True)
325
+
326
+ # Title with emoji and clean design
327
+ col1, col2, col3 = st.columns([1, 6, 1])
328
+ with col2:
329
+ st.markdown("<h1 style='text-align: center; margin-bottom: 0;'>πŸ€– Research Assistant</h1>", unsafe_allow_html=True)
330
+ st.markdown("<p style='text-align: center; color: #888; margin-top: 0;'>Powered by Multi-LLM RAG + FAISS</p>", unsafe_allow_html=True)
331
+
332
+ # Sidebar controls with improved organization
333
+ with st.sidebar:
334
+ st.markdown("### βš™οΈ Configuration")
335
+
336
+ with st.expander("πŸ“Š Dataset Info", expanded=False):
337
+ index_repo = os.environ.get("HF_DATASET_REPO_ID", "Wasifjafri/research-paper-vdb")
338
+ index_dir = os.environ.get("FAISS_INDEX_REMOTE_DIR", "faiss_index")
339
+ source_ds = os.environ.get("HF_SOURCE_DATASET", "")
340
+ st.markdown(f"""
341
+ **Vector index:** downloaded from `{index_repo}/{index_dir}` (HF dataset)
342
+
343
+ Rebuild (optional) requires a papers dataset set via env:
344
+ - `HF_SOURCE_DATASET` = `<owner>/<dataset>` (e.g., `CShorten/ML-ArXiv-Papers`)
345
+
346
+ If not set, the app will skip rebuilding and just use the packaged FAISS index.
347
+ Current HF_SOURCE_DATASET: `{source_ds or 'not set'}`
348
+ """)
349
+
350
+ st.markdown("---")
351
+
352
+ with st.expander("πŸ” Retrieval Settings", expanded=False):
353
+ base_k = st.slider("Initial fetch", 4, 30, 20, 1, help="Number of documents to initially retrieve")
354
+ rerank_k = st.slider("Final docs", 1, 12, 8, 1, help="Number of documents after reranking")
355
+ dynamic = st.checkbox("Dynamic k", True, help="Adjust retrieval size dynamically")
356
+ use_rerank = st.checkbox("Use reranking", True, help="Apply reranking for better relevance")
357
+
358
+ with st.expander("πŸ”§ Advanced Filters"):
359
+ primary_category = st.text_input("Category filter", "", help="Filter by arXiv category") or None
360
+ col1, col2 = st.columns(2)
361
+ with col1:
362
+ year_min = st.number_input("Min year", value=0, step=1)
363
+ with col2:
364
+ year_max = st.number_input("Max year", value=0, step=1)
365
+ if year_min == 0:
366
+ year_min = None
367
+ if year_max == 0:
368
+ year_max = None
369
+
370
+ st.markdown("---")
371
+
372
+ with st.expander("πŸ”„ Index Management", expanded=False):
373
+ subset_size = st.number_input("Dataset size", 1000, 100000, 10000, 1000)
374
+ rebuild = st.button("πŸ”¨ Rebuild Index", use_container_width=True)
375
+
376
+ st.markdown("---")
377
+
378
+ with st.expander("πŸ€– LLM Provider", expanded=False):
379
+ # Determine default provider based on available API keys
380
+ if ANTHROPIC_API_KEY:
381
+ default_provider = "Anthropic (Claude)"
382
+ elif GEMINI_API_KEY:
383
+ default_provider = "Gemini"
384
+ elif GROQ_API_KEY:
385
+ default_provider = "Groq"
386
+ else:
387
+ default_provider = "Gemini"
388
+
389
+ available_providers = ["Anthropic (Claude)", "Gemini", "Groq"]
390
+ try:
391
+ default_index = available_providers.index(default_provider)
392
+ except ValueError:
393
+ default_index = 0
394
+
395
+ provider = st.selectbox("Provider", available_providers, index=default_index)
396
+
397
+ if provider == "Anthropic (Claude)":
398
+ ui_anthropic_model = st.selectbox(
399
+ "Model",
400
+ [
401
+ "claude-sonnet-4-5-20250929",
402
+ "claude-opus-4-1-20250805",
403
+ "claude-opus-4-20250514",
404
+ "claude-sonnet-4-20250514",
405
+ "claude-3-7-sonnet-20250219",
406
+ "claude-3-5-haiku-20241022",
407
+ "claude-3-haiku-20240307"
408
+ ],
409
+ index=3
410
+ )
411
+ ui_gemini_model = None
412
+ ui_groq_model = None
413
+ elif provider == "Gemini":
414
+ ui_gemini_model = st.text_input("Model", GEMINI_MODEL)
415
+ ui_groq_model = None
416
+ ui_anthropic_model = None
417
+ else:
418
+ ui_groq_model = st.text_input("Model", GROQ_MODEL)
419
+ ui_gemini_model = None
420
+ ui_anthropic_model = None
421
+
422
+ # Stats at bottom
423
+ st.markdown("---")
424
+ try:
425
+ if 'vectorstore' in locals():
426
+ index_stats = vectorstore.index.ntotal if hasattr(vectorstore, 'index') else "Unknown"
427
+ st.metric("πŸ“š Embeddings", f"{index_stats:,}" if isinstance(index_stats, int) else index_stats)
428
+ except:
429
+ pass
430
+
431
+ # Build or load vectorstore
432
+ from typing import Optional
433
+
434
+ def _load_df_from_hf(num_records: int, dataset_name: Optional[str] = None):
435
+ """Load dataset from Hugging Face when rebuilding is explicitly requested.
436
+
437
+ Only used for index rebuilds; normal path downloads the ready-made FAISS index.
438
+ """
439
+ ds_name = dataset_name or os.environ.get("HF_SOURCE_DATASET")
440
+ if not ds_name:
441
+ st.error("❌ Rebuild requested but HF_SOURCE_DATASET is not set. Set it to a dataset like 'CShorten/ML-ArXiv-Papers'.")
442
+ st.stop()
443
+ try:
444
+ with st.spinner(f"πŸ”„ Loading papers from Hugging Face dataset: {ds_name}..."):
445
+ df = load_hf_dataset(num_records=num_records, dataset_name=ds_name)
446
+ return preprocess_dataframe(df)
447
+ except Exception as e:
448
+ st.error(f"❌ Failed to load dataset '{ds_name}': {e}")
449
+ st.info("πŸ’‘ If the dataset is private, add your HF token as a secret and set HF_SOURCE_DATASET.")
450
+ st.stop()
451
+
452
+ # Default path: try to download+load the FAISS index from HF dataset repo
453
+ if not rebuild:
454
+ try:
455
+ vectorstore = build_or_load_vectorstore([], force_rebuild=False)
456
+ except Exception as e:
457
+ st.error("❌ Could not load the FAISS index from the configured dataset repo.")
458
+ st.info("πŸ’‘ Check HF_DATASET_REPO_ID/FAISS_INDEX_REMOTE_DIR env vars and that the dataset has index.faiss/index.pkl.")
459
+ st.stop()
460
+ else:
461
+ # Rebuild only when explicitly requested and a source dataset is configured
462
+ with st.spinner("πŸ”¨ Rebuilding vector index from source dataset..."):
463
+ df = _load_df_from_hf(num_records=int(subset_size))
464
+ docs = df_to_documents(df)
465
+ vectorstore = build_or_load_vectorstore(
466
+ docs,
467
+ force_rebuild=True,
468
+ chunk_method="semantic",
469
+ chunk_size=1000,
470
+ chunk_overlap=125
471
+ )
472
+
473
+ def make_llm(provider_name: str):
474
+ if provider_name == "Anthropic (Claude)":
475
+ if not ANTHROPIC_API_KEY:
476
+ st.error("❌ ANTHROPIC_API_KEY not set")
477
+ st.stop()
478
+ try:
479
+ from langchain_anthropic import ChatAnthropic
480
+ return ChatAnthropic(
481
+ model=ui_anthropic_model or ANTHROPIC_MODEL,
482
+ temperature=0.7,
483
+ max_tokens=2048,
484
+ api_key=ANTHROPIC_API_KEY,
485
+ )
486
+ except Exception as e:
487
+ st.error(f"❌ Claude initialization failed: {e}")
488
+ st.stop()
489
+
490
+ if provider_name == "Gemini":
491
+ if not GEMINI_API_KEY:
492
+ st.error("❌ GEMINI_API_KEY not set")
493
+ st.stop()
494
+ try:
495
+ from langchain_google_genai import ChatGoogleGenerativeAI
496
+ return ChatGoogleGenerativeAI(
497
+ model=ui_gemini_model or GEMINI_MODEL,
498
+ temperature=0.7,
499
+ max_output_tokens=1024,
500
+ api_key=GEMINI_API_KEY,
501
+ )
502
+ except Exception as e:
503
+ st.error(f"❌ Gemini initialization failed: {e}")
504
+ st.stop()
505
+
506
+ if not GROQ_API_KEY:
507
+ st.error("❌ No valid LLM provider configured")
508
+ st.stop()
509
+ return ChatGroq(
510
+ model=ui_groq_model or GROQ_MODEL,
511
+ temperature=0.7,
512
+ max_tokens=1024,
513
+ groq_api_key=GROQ_API_KEY,
514
+ )
515
+
516
+ llm = make_llm(provider)
517
+
518
+ # Relevance checking prompt
519
+ relevance_check_prompt = """You are a research paper relevance checker. Your task is to determine if the retrieved documents are relevant to the user's question.
520
+
521
+ Retrieved Documents:
522
+ {context}
523
+
524
+ User Question: {question}
525
+
526
+ Instructions:
527
+ - Carefully analyze whether the retrieved documents contain information that can answer the user's question
528
+ - Consider if the documents discuss the topic, concepts, or methods mentioned in the question
529
+ - Respond with ONLY one word: "RELEVANT" or "IRRELEVANT"
530
+ - Be strict: if the documents are only tangentially related or don't actually address the question, respond "IRRELEVANT"
531
+
532
+ Response:"""
533
+
534
+ relevance_prompt = PromptTemplate(template=relevance_check_prompt, input_variables=["context", "question"])
535
+
536
+ # IMPROVED PROMPT
537
+ prompt_template = """You are a knowledgeable and helpful research assistant specializing in arXiv papers. You MUST ONLY answer questions based on the provided research papers context.
538
+
539
+ Context from Research Papers:
540
+ {context}
541
+
542
+ User Question: {question}
543
+
544
+ CRITICAL RULES:
545
+ - ONLY use information from the provided research papers context above
546
+ - DO NOT use your general knowledge or training data
547
+ - If the context doesn't contain relevant information, you MUST respond with: "I couldn't find relevant information about this topic in the available research papers. The retrieved documents don't address your question. Please try different search terms or the database may not contain papers on this specific topic."
548
+
549
+ Instructions:
550
+ - Analyze the user's question and provide a thorough, well-structured response BASED ONLY ON THE CONTEXT
551
+ - Be conversational and descriptive - explain concepts clearly with sufficient detail
552
+ - Use multiple paragraphs when needed to fully address the question
553
+
554
+ **For paper listing requests** (e.g., "find papers", "list papers", "show papers"):
555
+ Format as a structured list with detailed summaries:
556
+
557
+ **Paper #[Number]: [Title]**
558
+ - **Authors:** [Author names]
559
+ - **Year:** [Publication year]
560
+ - **ArXiv ID:** [ID if available]
561
+ - **Category:** [Research category]
562
+ - **Summary:** [3-4 sentences explaining the paper's objectives, methodology, key contributions, and findings based on the context]
563
+
564
+ **For specific questions** (e.g., "What is...", "Explain...", "How does...", "What is the purpose of..."):
565
+ - Provide a comprehensive, multi-paragraph answer that fully addresses the question USING ONLY THE CONTEXT
566
+ - Start with a clear overview or direct answer from the papers
567
+ - Elaborate with details, context, and explanations from the research papers
568
+ - Discuss relevant methodologies, findings, implications, or technical details found in the papers
569
+ - Cite sources naturally throughout (e.g., "According to the research by [Authors] (Year)...")
570
+ - Use clear transitions between ideas
571
+ - Conclude with key takeaways or significance when appropriate
572
+
573
+ **General Guidelines:**
574
+ - Write in a natural, conversational tone similar to ChatGPT
575
+ - Aim for depth and clarity - don't give one-liner responses
576
+ - Break complex information into digestible paragraphs
577
+ - Use examples and analogies when helpful from the context
578
+ - NEVER invent or hallucinate information not in the context
579
+ - Always prioritize being helpful, informative, and thorough - but ONLY based on the provided context
580
+
581
+ Answer:"""
582
+
583
+ prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
584
+
585
+ def _format_metadata(metadata):
586
+ """Format metadata in a clean, readable way."""
587
+ if not metadata:
588
+ return ""
589
+ meta_lines = []
590
+ if metadata.get("title"):
591
+ meta_lines.append(f"πŸ“„ {metadata['title']}")
592
+ if metadata.get("id"):
593
+ meta_lines.append(f"πŸ”— {metadata['id']}")
594
+ if metadata.get("authors") and metadata["authors"] != "N/A":
595
+ authors = metadata['authors']
596
+ if len(authors) > 100:
597
+ authors = authors[:100] + "..."
598
+ meta_lines.append(f"πŸ‘₯ {authors}")
599
+ if metadata.get("year"):
600
+ meta_lines.append(f"πŸ“… {metadata['year']}")
601
+ if metadata.get("primary_category") and metadata["primary_category"] != "N/A":
602
+ meta_lines.append(f"🏷️ {metadata['primary_category']}")
603
+ return " β€’ ".join(meta_lines)
604
+
605
+ def format_docs(docs):
606
+ """Format documents with clear structure and metadata."""
607
+ if not docs:
608
+ return "No relevant documents found in the database."
609
+
610
+ formatted_chunks = []
611
+ for idx, doc in enumerate(docs, start=1):
612
+ meta_str = _format_metadata(doc.metadata)
613
+ content = doc.page_content.strip()
614
+
615
+ if len(content) > 1000:
616
+ content = content[:1000] + "..."
617
+
618
+ formatted_chunk = f"[Document {idx}]\n{meta_str}\n\n{content}"
619
+ formatted_chunks.append(formatted_chunk)
620
+
621
+ return "\n\n" + "="*80 + "\n\n".join(formatted_chunks)
622
+
623
+ def build_chain():
624
+ """Build the RAG chain with improved retrieval."""
625
+ retriever = build_advanced_retriever(
626
+ vectorstore,
627
+ base_k=base_k,
628
+ rerank_k=rerank_k,
629
+ primary_category=primary_category,
630
+ year_min=year_min,
631
+ year_max=year_max,
632
+ dynamic=dynamic,
633
+ use_rerank=use_rerank,
634
+ )
635
+
636
+ def retrieval_with_logging(q):
637
+ try:
638
+ docs = retriever.get_relevant_documents(q)
639
+ return format_docs(docs)
640
+ except Exception as e:
641
+ return f"Error retrieving documents: {e}"
642
+
643
+ retrieval_runnable = RunnableLambda(retrieval_with_logging)
644
+ chain = {"context": retrieval_runnable, "question": RunnablePassthrough()} | prompt | llm
645
+ return chain, retriever
646
+
647
+ # Initialize session state
648
+ if "messages" not in st.session_state:
649
+ st.session_state["messages"] = []
650
+ st.session_state["show_welcome"] = True
651
+
652
+ # Welcome message with suggestions
653
+ if st.session_state.get("show_welcome", False):
654
+ st.markdown("""
655
+ <div class="welcome-message">
656
+ <h2>πŸ‘‹ Welcome to Research Assistant!</h2>
657
+ <p>I'm your AI-powered research companion. Ask me anything about Machine Learning papers!</p>
658
+ <div style="margin-top: 20px;">
659
+ <span class="suggestion-chip">πŸ” Find papers on transformers</span>
660
+ <span class="suggestion-chip">πŸ’‘ Explain attention mechanism</span>
661
+ <span class="suggestion-chip">πŸ“Š Compare CNN vs RNN</span>
662
+ <span class="suggestion-chip">🎯 Latest in reinforcement learning</span>
663
+ </div>
664
+ </div>
665
+ """, unsafe_allow_html=True)
666
+ st.session_state["show_welcome"] = False
667
+
668
+ # Helper functions
669
+ def is_casual_conversation(query_text):
670
+ """Check if the query is a greeting or casual conversation."""
671
+ query_lower = query_text.lower().strip()
672
+ greetings = ["hi", "hello", "hey", "good morning", "good afternoon", "good evening",
673
+ "hola", "greetings", "howdy", "yo", "sup", "what's up", "whats up"]
674
+ casual_patterns = [
675
+ "how are you", "how r u", "how do you do", "what's up", "whats up",
676
+ "who are you", "what are you", "what is your name", "your name",
677
+ "what can you do", "help me", "can you help", "thank you", "thanks",
678
+ "bye", "goodbye", "see you", "nice to meet you", "pleasure"
679
+ ]
680
+
681
+ if query_lower in greetings:
682
+ return True
683
+ for pattern in casual_patterns:
684
+ if pattern in query_lower:
685
+ return True
686
+ return False
+
+ def get_casual_response(query_text):
+     """Generate an appropriate response for casual conversation."""
+     query_lower = query_text.lower().strip()
+     # Match greetings on whole words: a substring test would fire on e.g.
+     # "yo" inside "you", swallowing the "who are you" branch below.
+     query_words = {word.strip("!?.,") for word in query_lower.split()}
+
+     if query_words & {"hi", "hello", "hey", "hola", "howdy", "yo"}:
+         return "Hello! πŸ‘‹ I'm your AI Research Assistant for Machine Learning papers. How can I help you today?"
+     if "good morning" in query_lower:
+         return "Good morning! β˜€οΈ Ready to explore some ML research? What interests you today?"
+     if "good afternoon" in query_lower:
+         return "Good afternoon! 🌀️ Let's dive into some research! What would you like to learn about?"
+     if "good evening" in query_lower:
+         return "Good evening! πŸŒ™ I'm here to help with ML research. What topic interests you?"
+     if any(phrase in query_lower for phrase in ["how are you", "how r u", "how do you do"]):
+         return "I'm doing great, thanks! 😊 Ready to help you explore ML research. What's on your mind?"
+     if any(phrase in query_lower for phrase in ["who are you", "what are you", "your name"]):
+         return "I'm an AI Research Assistant specialized in Machine Learning! πŸ€– I help you find papers, explain concepts, and answer research questions. What would you like to know?"
+     if any(phrase in query_lower for phrase in ["what can you do", "help me", "can you help"]):
+         return """I can help you with:
+
+ πŸ” **Finding research papers** on specific ML topics
+ πŸ“š **Explaining ML concepts** from published research
+ πŸ’‘ **Answering questions** about techniques and methods
+ πŸŽ“ **Exploring** the latest ML research developments
+
+ Try asking:
+ - "Find papers on deep learning"
+ - "What is transfer learning?"
+ - "Explain adversarial training"
+
+ What interests you?"""
+     if any(word in query_lower for word in ["thank you", "thanks", "thx"]):
+         return "You're welcome! 😊 Happy to help! Let me know if you have other questions."
+     if any(word in query_lower for word in ["bye", "goodbye", "see you"]):
+         return "Goodbye! πŸ‘‹ Come back anytime for ML research help. Happy learning!"
+
+     return "I'm here to help with Machine Learning research! 😊 Ask me about any ML topics or papers."
+
+ # Chat input
+ query = st.chat_input("πŸ’¬ Ask me anything about ML research...")
+
+ # Display chat history
+ for i, msg in enumerate(st.session_state["messages"]):
+     # Show user message
+     st.chat_message("user", avatar="πŸ‘€").write(msg["query"])
+
+     # Show assistant response if available
+     if msg.get("answer") is not None:
+         with st.chat_message("assistant", avatar="πŸ€–"):
+             st.write(msg["answer"])
+             if msg.get("context"):
+                 with st.expander(f"πŸ“„ View {len(msg['context'])} Retrieved Documents", expanded=False):
+                     for idx, doc in enumerate(msg["context"], 1):
+                         st.markdown(f"**πŸ“Ž Document {idx}**")
+                         st.caption(_format_metadata(doc.metadata))
+                         st.text_area(
+                             f"Content {idx}",
+                             doc.page_content[:800] + ("..." if len(doc.page_content) > 800 else ""),
+                             height=150,
+                             key=f"doc_{i}_{idx}",
+                             disabled=True
+                         )
+                         if idx < len(msg["context"]):
+                             st.markdown("---")
+     else:
+         # Answer is being generated - show thinking indicator
+         with st.chat_message("assistant", avatar="πŸ€–"):
+             thinking_placeholder = st.empty()
+             thinking_placeholder.markdown('<p class="thinking">πŸ” Searching research papers...</p>', unsafe_allow_html=True)
+
+             # Check if casual conversation
+             if is_casual_conversation(msg["query"]):
+                 casual_response = get_casual_response(msg["query"])
+
+                 # Replace the thinking indicator with a smooth word-by-word
+                 # streaming effect.
+                 thinking_placeholder.empty()
+                 response_placeholder = st.empty()
+                 full_response = ""
+                 words = casual_response.split()
+
+                 for word in words:
+                     full_response += word + " "
+                     response_placeholder.markdown(full_response)
+                     time.sleep(0.02)
+
+                 st.session_state["messages"][i]["answer"] = casual_response
+                 st.rerun()
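+
+                 # The rerun restarts the script top to bottom; this message now
+                 # has a non-None "answer", so the next pass renders it through
+                 # the static history branch instead of regenerating it.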
+
+             else:
+                 # Research question - full RAG pipeline
+                 rag_chain, adv_retriever = build_chain()
+
+                 docs = []
+                 answer_text = ""
+                 error_occurred = False
+
+                 try:
+                     docs = adv_retriever.get_relevant_documents(msg["query"])
+
+                     if not docs:
+                         answer_text = """I couldn't find any relevant research papers in the database that match your query.
+
+ **πŸ’‘ Suggestions:**
+ - Try using broader or different search terms
+ - Check the spelling of technical terms
+ - The database may not contain papers on this specific topic
+ - Consider rebuilding the index with more data
+
+ The current database focuses on ArXiv ML papers, but may not cover all research areas comprehensively."""
+                     else:
+                         thinking_placeholder.markdown('<p class="thinking">🧠 Analyzing documents...</p>', unsafe_allow_html=True)
+
+                         # Check relevance. The prompt template consumes the input
+                         # dict directly; mapping both keys through
+                         # RunnablePassthrough() would hand the entire input dict
+                         # to each template variable.
+                         formatted_context = format_docs(docs)
+                         relevance_check_chain = relevance_prompt | llm
+                         relevance_result = relevance_check_chain.invoke({"context": formatted_context, "question": msg["query"]})
+                         relevance_text = relevance_result.content if hasattr(relevance_result, "content") else str(relevance_result)
+
+                         if "IRRELEVANT" in relevance_text.strip().upper():
+                             answer_text = f"""I found {len(docs)} documents in the database, but they don't contain relevant information about your question.
+
+ **πŸ“‹ Retrieved topics:**
+ - {docs[0].metadata.get('title', 'Various topics') if docs else 'N/A'}
+
+ **πŸ’‘ Suggestions:**
+ - Try rephrasing with different keywords
+ - Use more specific technical terms
+ - Search for related concepts or broader topics
+ - The database may not have papers specifically on this topic
+
+ I can only provide answers based on the ArXiv papers in the database."""
+                         else:
+                             # Generate the answer (streaming of the display is simulated below)
+                             thinking_placeholder.markdown('<p class="thinking">✍️ Generating response...</p>', unsafe_allow_html=True)
+                             answer = rag_chain.invoke(msg["query"])
+                             answer_text = answer.content if hasattr(answer, "content") else str(answer)
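+
+                             # The relevance gate above assumes relevance_prompt
+                             # asks the model for a one-word verdict such as
+                             # RELEVANT or IRRELEVANT; checking the uppercased
+                             # reply as a substring tolerates extra prose around it.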
+
+                 except Exception as e:
+                     error_occurred = True
+                     msg_err = str(e)
+                     if "models/" in msg_err and "not found" in msg_err.lower():
+                         answer_text = "⚠️ Selected model not found. Try a different model in the sidebar."
+                     else:
+                         answer_text = f"⚠️ An error occurred: {e}\n\nPlease try again or rebuild the index."
+
+                 # Clear the thinking indicator and display the response
+                 thinking_placeholder.empty()
+
+                 # Simulated streaming: split on paragraph breaks and on
+                 # sentence-ending punctuation, keeping the separators.
+                 import re
+                 response_placeholder = st.empty()
+                 parts = re.split(r'(\n\n|(?<=[.!?])\s+)', answer_text)
+
+                 full_response = ""
+                 for part in parts:
+                     full_response += part
+                     response_placeholder.markdown(full_response)
+                     time.sleep(0.03)
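+
+                 # Because the pattern is wrapped in a capturing group, re.split
+                 # keeps the separators, e.g. (illustrative):
+                 #   re.split(r'(\n\n|(?<=[.!?])\s+)', "One. Two")  ->  ['One.', ' ', 'Two']
+                 # so concatenating the parts reproduces the text exactly.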
+
+                 # Update session state
+                 st.session_state["messages"][i]["answer"] = answer_text
+                 st.session_state["messages"][i]["context"] = docs
+
+                 # Show retrieved documents
+                 if docs:
+                     with st.expander(f"πŸ“„ View {len(docs)} Retrieved Documents", expanded=False):
+                         for idx, doc in enumerate(docs, 1):
+                             st.markdown(f"**πŸ“Ž Document {idx}**")
+                             st.caption(_format_metadata(doc.metadata))
+                             st.text_area(
+                                 f"Content {idx}",
+                                 doc.page_content[:800] + ("..." if len(doc.page_content) > 800 else ""),
+                                 height=150,
+                                 key=f"new_doc_{i}_{idx}",
+                                 disabled=True
+                             )
+                             if idx < len(docs):
+                                 st.markdown("---")
+
+                 st.rerun()
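+
+                 # As in the casual branch, the rerun re-renders this message from
+                 # history; the "new_doc_" key prefix keeps these text areas from
+                 # ever sharing a widget key with the "doc_" keys used there.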
+
+ # Process new query
+ if query:
+     # Add the message to session state immediately
+     st.session_state["messages"].append({
+         "query": query,
+         "answer": None,
+         "context": []
+     })
+
+     # Force a rerun so the user message shows up right away
+     st.rerun()
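+
+ # Two-pass pattern: this rerun paints the user's bubble immediately; the
+ # history loop above then picks up the answer-less message on the next pass
+ # and runs the slow retrieval + generation step there.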
+
+ # Footer with tips - only show if there are messages
+ if st.session_state["messages"]:
+     st.markdown("---")
+     with st.expander("πŸ’‘ Tips for Better Results", expanded=False):
+         col1, col2 = st.columns(2)
+
+         with col1:
+             st.markdown("""
+ **🎯 Asking Better Questions**
+
+ βœ… Use specific ML terminology
+ βœ… Mention techniques or methods
+ βœ… Ask for comparisons
+ βœ… Reference specific problems
+
+ **Examples:**
+ - "Papers on transformer architecture"
+ - "Compare CNNs vs Vision Transformers"
+ - "Explain BERT training methodology"
+ """)
+
+         with col2:
+             st.markdown("""
+ **πŸ“š Understanding Responses**
+
+ βœ… All answers from actual papers
+ βœ… View source documents anytime
+ βœ… Check relevance of results
+ βœ… Adjust settings if needed
+
+ **⚑ Advanced Tips:**
+ - Use sidebar filters (year, category)
+ - Adjust retrieval settings
+ - Try different LLM providers
+ - Rebuild index for fresh data
+ """)
+
+ # Add a "Clear Chat" button at the bottom of the sidebar
+ with st.sidebar:
+     st.markdown("---")
+     if st.button("πŸ—‘οΈ Clear Chat History", use_container_width=True):
+         st.session_state["messages"] = []
+         st.session_state["show_welcome"] = True
+         st.rerun()