Wasifjafri committed on
Commit
d50ac21
·
1 Parent(s): fc49018

fix ui issues

README.md CHANGED
@@ -1,19 +1,147 @@
  ---
- title: Research Rag Chatbot
- emoji: 🚀
  colorFrom: blue
  colorTo: purple
  sdk: docker
  app_port: 8501
  tags:
  - streamlit
  pinned: false
- short_description: Rag chatbot for research papers
  ---

- # Welcome to Streamlit!

- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:

- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).

  ---
+ title: ML Research Paper RAG Chatbot
+ emoji: 🤖
  colorFrom: blue
  colorTo: purple
  sdk: docker
  app_port: 8501
  tags:
  - streamlit
+ - machine-learning
+ - research
+ - rag
+ - chatbot
  pinned: false
+ short_description: AI-powered chatbot for ML research papers
  ---

+ # 📄 ML Research Paper RAG Chatbot
+
+ An intelligent research assistant that helps you discover, understand, and explore Machine Learning research papers from ArXiv using Retrieval-Augmented Generation (RAG).
+
+ ## 🎯 What is this?
+
+ This chatbot uses advanced AI to help you:
+ - 🔍 **Find relevant research papers** on any ML topic
+ - 📚 **Get detailed explanations** from published research
+ - 💡 **Understand complex concepts** with cited sources
+ - 🎓 **Stay updated** with ML research trends
+
+ ## ✨ Features
+
+ - **Multi-LLM Support**: Choose between Anthropic Claude, Google Gemini, or Groq
+ - **Smart Retrieval**: FAISS vector store with semantic search and reranking
+ - **Research-Focused**: Only provides answers based on actual papers (no hallucinations)
+ - **Citation-Backed**: All responses cite source papers with metadata
+ - **Interactive UI**: Clean Streamlit interface with helpful guides
+
+ ## 🚀 Quick Start Guide
+
+ ### For First-Time Users
+
+ 1. **Start a conversation** by typing a question in the chat box
+ 2. **Try example queries** using the quick action buttons
+ 3. **Explore results** by expanding the "View Retrieved Documents" section
+ 4. **Adjust settings** in the sidebar for fine-tuned results
+
+ ### Example Queries
+
+ ```
+ ✅ Find papers on handling imbalanced datasets
+ ✅ What methods are used for fraud detection in ML?
+ ✅ Explain the attention mechanism in transformers
+ ✅ List recent papers about reinforcement learning
+ ✅ How does batch normalization improve training?
+ ```
+
+ ## 💡 Tips for Best Results
+
+ ### Ask Better Questions
+ - ✅ **Be specific**: "fraud detection in credit cards" > "fraud"
+ - ✅ **Use ML terminology**: "convolutional neural networks" > "image AI"
+ - ✅ **Ask for comparisons**: "Compare CNN vs RNN for sequences"
+
+ ### Understand the Responses
+ - 📚 All answers are based on research papers in the database
+ - 🔍 Check "View Retrieved Documents" to see sources
+ - ⚠️ If documents seem irrelevant, try rephrasing
+
+ ### Advanced Usage
+ - ⚙️ Adjust retrieval settings (base_k, rerank_k) for more/fewer papers
+ - 🎨 Switch LLM providers for different response styles
+ - 📅 Filter by year or category for focused results
+
+ ## 🗂️ Dataset
+
+ Uses **CShorten/ML-ArXiv-Papers** from Hugging Face:
+ - Curated Machine Learning research papers from ArXiv
+ - Includes titles, abstracts, metadata, and citations
+ - Regularly updated with new publications
+
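+ To poke at the raw data outside the app, here is a minimal sketch mirroring `load_hf_dataset` in `src/ingestion.py` (column names are whatever the dataset ships, so inspect them first):
+
+ ```python
+ from datasets import load_dataset
+
+ # Pull the train split and take a small slice for a quick look
+ ds = load_dataset("CShorten/ML-ArXiv-Papers", split="train")
+ df = ds.select(range(1000)).to_pandas()
+ print(len(df), df.columns.tolist())
+ ```
+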
+ ## ⚙️ Configuration
+
+ ### LLM Providers
+ 1. **Anthropic Claude** (Recommended for quality)
+    - claude-3-5-sonnet-20241022 (Best balance)
+    - claude-3-5-haiku-20241022 (Fast)
+
+ 2. **Google Gemini** (Good for free tier)
+    - gemini-2.5-flash (Fast and efficient)
+
+ 3. **Groq** (Fastest inference)
+    - llama-4-maverick-17b (Open source)
+
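+ Each provider is wrapped as a LangChain chat model. A hedged sketch of the Claude path from `make_llm` in `streamlit_app.py` (swap in `ChatGoogleGenerativeAI` or `ChatGroq` for the other providers):
+
+ ```python
+ import os
+ from langchain_anthropic import ChatAnthropic
+
+ # Same settings the app uses; override the model via the ANTHROPIC_MODEL env var
+ llm = ChatAnthropic(
+     model=os.environ.get("ANTHROPIC_MODEL", "claude-sonnet-4-20250514"),
+     temperature=0.7,
+     max_tokens=2048,
+     api_key=os.environ["ANTHROPIC_API_KEY"],
+ )
+ print(llm.invoke("Summarize the attention mechanism in one sentence.").content)
+ ```
+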
+ ### Retrieval Settings
+ - **base_k**: Initial papers fetched (4-30, default: 20)
+ - **rerank_k**: Final papers kept after reranking (1-12, default: 8)
+ - **Dynamic k**: Auto-adjust retrieval size based on the query
+ - **Reranking**: Improve relevance with a cross-encoder
+
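+ These settings feed `build_advanced_retriever` from `src/retriever.py`. A sketch, assuming the keyword names match the sidebar controls above (check the actual signature before relying on it):
+
+ ```python
+ from src.retriever import build_advanced_retriever
+
+ retriever = build_advanced_retriever(
+     vectorstore,       # FAISS store from build_or_load_vectorstore
+     base_k=20,         # candidates fetched by vector search
+     rerank_k=8,        # documents kept after cross-encoder scoring
+     dynamic=True,      # let the retriever adjust k per query
+     use_rerank=True,
+ )
+ docs = retriever.get_relevant_documents("handling imbalanced datasets")
+ ```
+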
+ ## 🔧 Setup (For Developers)
+
+ ### Prerequisites
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ### API Keys
+ Create a `.env` file:
+ ```env
+ ANTHROPIC_API_KEY=your-key-here
+ GEMINI_API_KEY=your-key-here
+ GROQ_API_KEY=your-key-here
+ ```
+
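+ The app loads this file at startup and reads each key with an empty-string fallback (see `src/config.py`); the equivalent snippet:
+
+ ```python
+ import os
+ from dotenv import load_dotenv, find_dotenv
+
+ load_dotenv(find_dotenv())  # pulls .env into the process environment
+ anthropic_key = os.environ.get("ANTHROPIC_API_KEY", "")
+ ```
+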
+ ### Run Locally
+ ```bash
+ streamlit run streamlit_app.py
+ ```
+
+ ## 📊 How It Works
+
+ 1. **User Query** → Semantic embedding created
+ 2. **Vector Search** → FAISS retrieves similar papers
+ 3. **Reranking** → Cross-encoder scores relevance
+ 4. **LLM Generation** → AI generates answer from papers
+ 5. **Response** → Cited answer with source papers
+
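+ A condensed sketch of that loop (assumes the sentence-transformers package plus a vector store and LLM built as above; the real app adds dynamic k, metadata filters, and prompt templates):
+
+ ```python
+ from sentence_transformers import CrossEncoder
+
+ reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")
+
+ def answer(query, vectorstore, llm, base_k=20, rerank_k=8):
+     # Steps 1-2: embed the query and fetch base_k candidates from FAISS
+     candidates = vectorstore.similarity_search(query, k=base_k)
+     # Step 3: score (query, paper) pairs and keep the top rerank_k
+     scores = reranker.predict([(query, d.page_content) for d in candidates])
+     ranked = sorted(zip(scores, candidates), key=lambda p: p[0], reverse=True)
+     top = [d for _, d in ranked[:rerank_k]]
+     # Steps 4-5: answer strictly from the retrieved context
+     context = "\n\n".join(d.page_content for d in top)
+     return llm.invoke(f"Context:\n{context}\n\nQuestion: {query}").content
+ ```
+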
+ ## 🔒 Important Notes
+
+ - ✅ Answers are **based only on research papers** in the database
+ - ✅ The system won't make up information from general knowledge
+ - ✅ If no relevant papers are found, it will tell you
+ - ❌ Not a replacement for reading the full papers
+ - ⚠️ Always verify critical information with the original sources
+
+ ## 🤝 Contributing
+
+ Feel free to submit issues, fork the repository, and create pull requests for any improvements.
+
+ ## 📄 License
+
+ This project is open source and available under the MIT License.
+
+ ---
+
+ **Ready to explore ML research?** Start by asking a question! 🚀
faiss_index/index.faiss DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:5990c6b4f15d524dae50f9d256cd3e16feef0863d5bb3f467629b4b7534cdca5
- size 151790637
faiss_index/index.pkl DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:d5fd0b4d6d2b61c5f67dc2447721091df6eee5a6a47b8d7018a7775b880aabd0
- size 63795991
requirements.txt CHANGED
@@ -29,6 +29,7 @@ langchain-text-splitters==0.3.11
  langchain-groq==0.3.8
  langchain-huggingface==0.3.1
  langchain-google-genai==2.1.12
+ langchain-anthropic==0.3.6

  # Vector store and NLP
  faiss-cpu==1.12.0
@@ -39,11 +40,12 @@ transformers==4.56.1
  pandas==2.3.2
  numpy==1.26.4
  requests==2.32.5
+ datasets==3.2.0

  # Optional semantic splitter (app gracefully falls back if missing)
  semantic-text-splitter==0.27.0

- # Dataset fetcher
- kagglehub==0.3.13
+ # Dataset fetcher (legacy - now using Hugging Face datasets)
+ # kagglehub==0.3.13

src/config.py CHANGED
@@ -12,10 +12,12 @@ DEVICE = "cuda" if os.environ.get("CUDA_AVAILABLE", "0") == "1" else "cpu"
  GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
  GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
  GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", GOOGLE_API_KEY)
+ ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", "")

  # Default chat model identifiers
  GROQ_MODEL = os.environ.get("GROQ_MODEL", "meta-llama/llama-4-maverick-17b-128e-instruct")
  GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.5-flash")
+ ANTHROPIC_MODEL = os.environ.get("ANTHROPIC_MODEL", "claude-sonnet-4-20250514")

  # Cross-encoder model for reranking
  CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-12-v2"
src/ingestion.py CHANGED
@@ -8,6 +8,38 @@ from langchain_core.documents import Document
  from .config import DATA_PATH
  from .text_processing import clean_text

+ def load_hf_dataset(num_records=50000, dataset_name="CShorten/ML-ArXiv-Papers"):
+     """Load ArXiv papers from Hugging Face dataset.
+
+     Args:
+         num_records: Number of records to load
+         dataset_name: Hugging Face dataset identifier
+
+     Returns:
+         pandas DataFrame with the papers
+     """
+     try:
+         from datasets import load_dataset
+
+         print(f"Loading {num_records} records from {dataset_name}...")
+
+         # Load dataset from Hugging Face
+         dataset = load_dataset(dataset_name, split="train", streaming=False)
+
+         # Convert to pandas DataFrame
+         if num_records and num_records < len(dataset):
+             df = dataset.select(range(num_records)).to_pandas()
+         else:
+             df = dataset.to_pandas()
+
+         print(f"Loaded {len(df)} records from Hugging Face dataset")
+         return df
+
+     except ImportError:
+         raise ImportError("Please install the datasets library: pip install datasets")
+     except Exception as e:
+         raise ValueError(f"Failed to load Hugging Face dataset: {e}")
+
  def _open_file(file_path):
      """Open file with appropriate mode and encoding."""
      if file_path.endswith('.gz'):
@@ -84,10 +116,28 @@ def load_data_subset(file_path, num_records=50000):
      return pd.DataFrame(records)

  def preprocess_dataframe(df: pd.DataFrame) -> pd.DataFrame:
-     df['update_date'] = pd.to_datetime(df['update_date'])
-     df['year'] = df['update_date'].dt.year
-     df = df.dropna(subset=['abstract'])
-     df = df[df['abstract'].str.strip() != '']
+     """Preprocess the dataframe from Hugging Face or local file."""
+     # Handle different date column names
+     date_col = None
+     if 'update_date' in df.columns:
+         date_col = 'update_date'
+     elif 'updated' in df.columns:
+         date_col = 'updated'
+     elif 'published' in df.columns:
+         date_col = 'published'
+
+     if date_col:
+         df[date_col] = pd.to_datetime(df[date_col], errors='coerce')
+         df['year'] = df[date_col].dt.year
+     elif 'year' not in df.columns:
+         # If no date column exists, set year to None
+         df['year'] = None
+
+     # Ensure required columns exist
+     if 'abstract' in df.columns:
+         df = df.dropna(subset=['abstract'])
+         df = df[df['abstract'].str.strip() != '']
+
      return df

  def df_to_documents(
@@ -95,19 +145,35 @@ def df_to_documents(
      lowercase: bool = False,
      remove_stopwords: bool = False
  ):
+     """Convert dataframe to LangChain documents."""
      documents = []
      for _, row in df.iterrows():
-         title_clean = clean_text(str(row['title']), lowercase=lowercase, remove_stopwords=remove_stopwords)
-         abstract_clean = clean_text(str(row['abstract']), lowercase=lowercase, remove_stopwords=remove_stopwords)
+         # Get title and abstract
+         title = str(row.get('title', ''))
+         abstract = str(row.get('abstract', ''))
+
+         title_clean = clean_text(title, lowercase=lowercase, remove_stopwords=remove_stopwords)
+         abstract_clean = clean_text(abstract, lowercase=lowercase, remove_stopwords=remove_stopwords)
          page_content = f"Title: {title_clean}\n\nAbstract: {abstract_clean}"
+
+         # Handle categories - can be string or list
          categories_raw = row.get('categories', 'N/A') or 'N/A'
-         primary_category = categories_raw.split()[0] if isinstance(categories_raw, str) else 'N/A'
+         if isinstance(categories_raw, list):
+             categories_str = ' '.join(categories_raw) if categories_raw else 'N/A'
+             primary_category = categories_raw[0] if categories_raw else 'N/A'
+         else:
+             categories_str = str(categories_raw)
+             primary_category = categories_str.split()[0] if categories_str != 'N/A' else 'N/A'
+
+         # Build metadata
          metadata = {
              "id": row.get('id', 'N/A'),
+             "title": title,  # Keep original title in metadata
              "authors": row.get('authors', 'N/A'),
              "year": int(row.get('year')) if not pd.isna(row.get('year')) else None,
-             "categories": categories_raw,
+             "categories": categories_str,
              "primary_category": primary_category
          }
+
          documents.append(Document(page_content=page_content, metadata=metadata))
      return documents
streamlit_app.py CHANGED
@@ -30,159 +30,540 @@ from langchain.prompts import PromptTemplate
  from langchain.schema.runnable import RunnablePassthrough
  from langchain_core.runnables import RunnableLambda
  from langchain_groq import ChatGroq

  from src.vector_store import build_or_load_vectorstore
- from src.ingestion import load_data_subset, preprocess_dataframe, df_to_documents
  from src.retriever import build_advanced_retriever
- from src.config import DATA_PATH, FAISS_INDEX_PATH, GROQ_API_KEY, GEMINI_API_KEY, GROQ_MODEL, GEMINI_MODEL

  load_dotenv(find_dotenv())

- st.set_page_config(page_title="📄 Research Paper RAG Chatbot", page_icon="💬", layout="wide")
- st.title("📄 Research Paper RAG Chatbot (Groq + FAISS + Rerank)")

- # Sidebar controls
  with st.sidebar:
-     st.header("Retrieval Settings")
-     base_k = st.slider("Initial fetch (base_k)", 4, 30, 20, 1)  # Increased default
-     rerank_k = st.slider("Final docs (rerank_k)", 1, 12, 8, 1)  # Increased default
-     dynamic = st.checkbox("Dynamic k", True)
-     use_rerank = st.checkbox("Use reranking", True)
-     primary_category = st.text_input("Primary category filter", "") or None
-     year_min = st.number_input("Min year", value=0, step=1)
-     year_max = st.number_input("Max year", value=0, step=1)
-     if year_min == 0:
-         year_min = None
-     if year_max == 0:
-         year_max = None
-     rebuild = st.button("Rebuild index (semantic)")
-     subset_size = st.number_input("Subset records (rebuild)", 1000, 100000, 50000, 1000)
-     st.divider()
-     st.header("LLM Provider")
-     if GEMINI_API_KEY:
-         default_provider = "Gemini"
-     elif GROQ_API_KEY:
-         default_provider = "Groq"
-     else:
-         default_provider = "Gemini"
-     provider = st.selectbox("Choose provider", ["Gemini", "Groq"], index=["Gemini", "Groq"].index(default_provider))
-     if provider == "Gemini":
-         ui_gemini_model = st.text_input("Gemini model", GEMINI_MODEL)
-         ui_groq_model = None
-     else:
-         ui_groq_model = st.text_input("Groq model", GROQ_MODEL)
-         ui_gemini_model = None

  # Build or load vectorstore
- def _load_df_with_fallback(data_file: str, num_records: int):
-     """Try to load the dataset; if it fails, download via KaggleHub and retry once."""
      try:
-         return preprocess_dataframe(load_data_subset(data_file, num_records=num_records))
      except Exception as e:
-         st.warning(f"Dataset read failed: {e}. Attempting fresh download via KaggleHub...")
-         try:
-             import kagglehub, shutil
-             os.makedirs(DATA_PATH, exist_ok=True)
-             with st.spinner("Downloading ArXiv dataset..."):
-                 path = kagglehub.dataset_download("Cornell-University/arxiv")
-                 src = os.path.join(path, "arxiv-metadata-oai-snapshot.json")
-                 shutil.copy(src, data_file)
-             st.success("Dataset downloaded; retrying load...")
-             return preprocess_dataframe(load_data_subset(data_file, num_records=num_records))
-         except Exception as e2:
-             st.error(f"Dataset download or reload failed: {e2}")
-             st.stop()

  if rebuild or not os.path.exists(FAISS_INDEX_PATH):
-     data_file = os.path.join(DATA_PATH, "arxiv-metadata-oai-snapshot.json")
-     if not os.path.exists(data_file):
-         st.warning("Dataset missing. Attempting to download via KaggleHub...")
-         try:
-             import kagglehub, shutil
-             os.makedirs(DATA_PATH, exist_ok=True)
-             with st.spinner("Downloading ArXiv dataset..."):
-                 path = kagglehub.dataset_download("Cornell-University/arxiv")
-                 src = os.path.join(path, "arxiv-metadata-oai-snapshot.json")
-                 shutil.copy(src, data_file)
-             st.success("Dataset downloaded.")
-         except Exception as e:
-             st.error(f"Dataset download failed: {e}. Please run main pipeline first.")
-             st.stop()
-     with st.spinner("Building vector index..."):
-         df = _load_df_with_fallback(data_file, num_records=int(subset_size))
          docs = df_to_documents(df)
          vectorstore = build_or_load_vectorstore(
              docs,
              force_rebuild=True,
              chunk_method="semantic",
-             chunk_size=800,
-             chunk_overlap=120
          )
  else:
      try:
          vectorstore = build_or_load_vectorstore([], force_rebuild=False)
      except Exception as e:
-         st.warning(f"Failed to load existing index: {e}. Attempting to rebuild from dataset...")
-         data_file = os.path.join(DATA_PATH, "arxiv-metadata-oai-snapshot.json")
-         if not os.path.exists(data_file):
-             st.error("Dataset missing. Run main pipeline first or click 'Rebuild index'.")
-             st.stop()
-         with st.spinner("Rebuilding vector index after load failure..."):
-             df = _load_df_with_fallback(data_file, num_records=50000)
              docs = df_to_documents(df)
              vectorstore = build_or_load_vectorstore(
                  docs,
                  force_rebuild=True,
                  chunk_method="semantic",
-                 chunk_size=800,
-                 chunk_overlap=120
              )

  def make_llm(provider_name: str):
      if provider_name == "Gemini":
          if not GEMINI_API_KEY:
-             st.error("GEMINI_API_KEY not set. Please add it to your .env or environment variables to use Gemini.")
              st.stop()
          try:
              from langchain_google_genai import ChatGoogleGenerativeAI
              return ChatGoogleGenerativeAI(
                  model=ui_gemini_model or GEMINI_MODEL,
                  temperature=0.7,
-                 max_output_tokens=1024,  # Increased for better responses
                  api_key=GEMINI_API_KEY,
              )
-         except ModuleNotFoundError:
-             st.error("Missing dependency 'langchain-google-genai'. Please install it (pip install langchain-google-genai).")
-             st.stop()
          except Exception as e:
-             st.error(f"Failed to initialize Gemini: {e}")
              st.stop()
      if not GROQ_API_KEY:
-         st.error("No valid LLM provider configured. Please set GEMINI_API_KEY or GROQ_API_KEY in environment.")
          st.stop()
      return ChatGroq(
          model=ui_groq_model or GROQ_MODEL,
          temperature=0.7,
-         max_tokens=1024,  # Increased for better responses
          groq_api_key=GROQ_API_KEY,
      )

  llm = make_llm(provider)

- # IMPROVED PROMPT - More flexible and informative
- prompt_template = """You are a helpful research assistant specializing in arXiv papers. Your goal is to provide informative, accurate answers based on the retrieved research papers.

  Context from Research Papers:
  {context}

  User Question: {question}

  Instructions:
- 1. If the context contains relevant information, provide a comprehensive answer citing the papers (e.g., "According to Doc 1...")
- 2. Mention key details like authors, year, methods, and findings when available
- 3. If the context is partially relevant, provide what you can and note what's missing
- 4. Structure your response clearly with bullet points or sections when appropriate
- 5. Always be helpful and informative

  Answer:"""

@@ -194,35 +575,34 @@ def _format_metadata(metadata):
          return ""
      meta_lines = []
      if metadata.get("title"):
-         meta_lines.append(f"Title: {metadata['title']}")
      if metadata.get("id"):
-         meta_lines.append(f"ArXiv ID: {metadata['id']}")
      if metadata.get("authors") and metadata["authors"] != "N/A":
          authors = metadata['authors']
-         if len(authors) > 100:  # Truncate long author lists
              authors = authors[:100] + "..."
-         meta_lines.append(f"Authors: {authors}")
      if metadata.get("year"):
-         meta_lines.append(f"Year: {metadata['year']}")
      if metadata.get("primary_category") and metadata["primary_category"] != "N/A":
-         meta_lines.append(f"Category: {metadata['primary_category']}")
-     return " | ".join(meta_lines)

  def format_docs(docs):
      """Format documents with clear structure and metadata."""
      if not docs:
-         return "No relevant documents found in the database. The system may need more data or a rebuilt index."

      formatted_chunks = []
      for idx, doc in enumerate(docs, start=1):
          meta_str = _format_metadata(doc.metadata)
          content = doc.page_content.strip()

-         # Truncate very long content
          if len(content) > 1000:
              content = content[:1000] + "..."

-         formatted_chunk = f"[Document {idx}]\n{meta_str}\n\nContent: {content}"
          formatted_chunks.append(formatted_chunk)

      return "\n\n" + "="*80 + "\n\n".join(formatted_chunks)
@@ -240,18 +620,12 @@ def build_chain():
          use_rerank=use_rerank,
      )

-     # Improved retrieval with error handling
      def retrieval_with_logging(q):
          try:
              docs = retriever.get_relevant_documents(q)
-             formatted = format_docs(docs)
-             # Debug info in sidebar
-             with st.sidebar:
-                 st.write(f"📊 Retrieved {len(docs)} documents")
-             return formatted
          except Exception as e:
-             st.error(f"Retrieval error: {e}")
-             return "Error retrieving documents. Please try again or rebuild the index."

      retrieval_runnable = RunnableLambda(retrieval_with_logging)
      chain = {"context": retrieval_runnable, "question": RunnablePassthrough()} | prompt | llm
@@ -260,58 +634,276 @@ def build_chain():
  # Initialize session state
  if "messages" not in st.session_state:
      st.session_state["messages"] = []

- # Display index stats
- try:
-     index_stats = vectorstore.index.ntotal if hasattr(vectorstore, 'index') else "Unknown"
-     st.sidebar.info(f"📚 Vector store contains {index_stats} embeddings")
- except:
-     pass

- query = st.chat_input("Ask me something...")

- if query:
-     with st.spinner("🔍 Searching papers and generating response..."):
-         rag_chain, adv_retriever = build_chain()

-         # Get documents first for display
-         try:
-             docs = adv_retriever.get_relevant_documents(query)
-
-             # Show retrieval success/failure
-             if not docs:
-                 st.warning("⚠️ No documents found matching your query. Try broader terms or rebuild the index.")

-             # Generate answer
-             answer = rag_chain.invoke(query)
-             answer_text = answer.content if hasattr(answer, "content") else str(answer)

-         except Exception as e:
-             msg = str(e)
-             if "models/" in msg and "not found" in msg.lower():
-                 st.error("Selected Gemini model not found or unsupported. Try 'gemini-1.5-pro-latest' or check your model name in the sidebar.")
              else:
-                 st.error(f"LLM error: {e}")
-             st.stop()

-         # Save to session state
-         st.session_state["messages"].append({
-             "query": query,
-             "answer": answer_text,
-             "context": docs
-         })
-
-         # Display conversation history
-         for msg in st.session_state["messages"]:
-             st.chat_message("user").write(msg["query"])
-             with st.chat_message("assistant"):
-                 st.write(msg["answer"])
-                 with st.expander(f"📄 View {len(msg['context'])} Retrieved Documents"):
-                     if not msg["context"]:
-                         st.info("No documents were retrieved for this query.")
-                     for i, doc in enumerate(msg["context"], 1):
-                         st.markdown(f"### Document {i}")
-                         st.write(doc.page_content[:500] + ("..." if len(doc.page_content) > 500 else ""))
-                         if doc.metadata:
-                             st.caption(_format_metadata(doc.metadata))
-                         st.divider()

  from langchain.schema.runnable import RunnablePassthrough
  from langchain_core.runnables import RunnableLambda
  from langchain_groq import ChatGroq
+ import time

  from src.vector_store import build_or_load_vectorstore
+ from src.ingestion import load_data_subset, preprocess_dataframe, df_to_documents, load_hf_dataset
  from src.retriever import build_advanced_retriever
+ from src.config import DATA_PATH, FAISS_INDEX_PATH, GROQ_API_KEY, GEMINI_API_KEY, ANTHROPIC_API_KEY, GROQ_MODEL, GEMINI_MODEL, ANTHROPIC_MODEL

  load_dotenv(find_dotenv())

+ # PAGE CONFIG - Must be first Streamlit command
+ st.set_page_config(
+     page_title="Research Assistant",
+     page_icon="🤖",
+     layout="wide",
+     initial_sidebar_state="expanded"  # Start with sidebar expanded
+ )
+
+ # ENHANCED CUSTOM CSS - ChatGPT-like styling
+ st.markdown("""
+ <style>
+     /* Hide Streamlit branding */
+     #MainMenu {visibility: hidden;}
+     footer {visibility: hidden;}
+
+     /* Make sure header is visible for sidebar toggle */
+     header {visibility: visible !important;}
+
+     /* Style the sidebar toggle button to be more visible */
+     [data-testid="collapsedControl"] {
+         background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
+         border-radius: 0 8px 8px 0 !important;
+         padding: 8px !important;
+         margin-top: 60px !important;
+         box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4) !important;
+     }
+
+     [data-testid="collapsedControl"]:hover {
+         box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
+         transform: translateX(2px);
+     }
+
+     /* Overall app styling */
+     .stApp {
+         background: linear-gradient(180deg, #0f1419 0%, #1a1f2e 100%);
+     }
+
+     /* Main chat container */
+     .main .block-container {
+         padding-top: 2rem;
+         padding-bottom: 2rem;
+         max-width: 900px;
+         margin: 0 auto;
+     }
+
+     /* Chat input styling - Fixed at bottom like ChatGPT */
+     .stChatInputContainer {
+         background: transparent;
+         border: none;
+         padding: 1rem 0;
+     }
+
+     .stChatInput > div {
+         background: rgba(255, 255, 255, 0.05);
+         border: 1px solid rgba(255, 255, 255, 0.1);
+         border-radius: 24px;
+         padding: 12px 20px;
+         backdrop-filter: blur(10px);
+         transition: all 0.3s ease;
+     }
+
+     .stChatInput > div:hover {
+         background: rgba(255, 255, 255, 0.08);
+         border-color: rgba(255, 255, 255, 0.2);
+         box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
+     }
+
+     .stChatInput > div:focus-within {
+         background: rgba(255, 255, 255, 0.1);
+         border-color: #10a37f;
+         box-shadow: 0 0 0 3px rgba(16, 163, 127, 0.1);
+     }
+
+     /* User messages - Right aligned with gradient */
+     [data-testid="stChatMessage"]:has([data-testid*="user"]) {
+         background: transparent;
+         justify-content: flex-end;
+     }
+
+     [data-testid="stChatMessage"]:has([data-testid*="user"]) [data-testid="stChatMessageContent"] {
+         background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+         border-radius: 18px;
+         padding: 14px 18px;
+         margin-left: auto;
+         max-width: 75%;
+         color: white;
+         box-shadow: 0 4px 12px rgba(102, 126, 234, 0.3);
+     }
+
+     /* Bot messages - Left aligned with subtle styling */
+     [data-testid="stChatMessage"]:not(:has([data-testid*="user"])) {
+         background: transparent;
+         justify-content: flex-start;
+     }
+
+     [data-testid="stChatMessage"]:not(:has([data-testid*="user"])) [data-testid="stChatMessageContent"] {
+         background: rgba(255, 255, 255, 0.03);
+         border: 1px solid rgba(255, 255, 255, 0.08);
+         border-radius: 18px;
+         padding: 14px 18px;
+         margin-right: auto;
+         max-width: 85%;
+         color: #e8e8e8;
+         box-shadow: 0 2px 8px rgba(0, 0, 0, 0.2);
+         backdrop-filter: blur(10px);
+     }
+
+     /* Avatar styling */
+     [data-testid="stChatMessage"] [data-testid="stAvatar"] {
+         width: 36px;
+         height: 36px;
+         border-radius: 50%;
+         border: 2px solid rgba(255, 255, 255, 0.1);
+     }
+
+     /* User avatar - gradient border */
+     [data-testid="stChatMessage"]:has([data-testid*="user"]) [data-testid="stAvatar"] {
+         background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+         border: 2px solid transparent;
+         box-shadow: 0 2px 8px rgba(102, 126, 234, 0.4);
+     }
+
+     /* Bot avatar - themed */
+     [data-testid="stChatMessage"]:not(:has([data-testid*="user"])) [data-testid="stAvatar"] {
+         background: linear-gradient(135deg, #10a37f 0%, #0d8a6a 100%);
+         border: 2px solid rgba(16, 163, 127, 0.3);
+         box-shadow: 0 2px 8px rgba(16, 163, 127, 0.3);
+     }
+
+     /* Sidebar styling */
+     [data-testid="stSidebar"] {
+         background: rgba(15, 20, 25, 0.95);
+         border-right: 1px solid rgba(255, 255, 255, 0.08);
+         backdrop-filter: blur(20px);
+     }
+
+     [data-testid="stSidebar"] .stButton button {
+         background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+         border: none;
+         border-radius: 12px;
+         color: white;
+         padding: 10px 20px;
+         font-weight: 600;
+         transition: all 0.3s ease;
+     }
+
+     [data-testid="stSidebar"] .stButton button:hover {
+         transform: translateY(-2px);
+         box-shadow: 0 6px 20px rgba(102, 126, 234, 0.4);
+     }
+
+     /* Expander styling */
+     .streamlit-expanderHeader {
+         background: rgba(255, 255, 255, 0.03);
+         border-radius: 12px;
+         border: 1px solid rgba(255, 255, 255, 0.08);
+         color: #b4b4b4;
+         padding: 12px 16px;
+         transition: all 0.3s ease;
+     }
+
+     .streamlit-expanderHeader:hover {
+         background: rgba(255, 255, 255, 0.06);
+         border-color: rgba(255, 255, 255, 0.15);
+     }
+
+     .streamlit-expanderContent {
+         background: rgba(255, 255, 255, 0.02);
+         border: 1px solid rgba(255, 255, 255, 0.05);
+         border-top: none;
+         border-radius: 0 0 12px 12px;
+     }
+
+     /* Divider styling */
+     hr {
+         border: none;
+         height: 1px;
+         background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.1), transparent);
+         margin: 2rem 0;
+     }
+
+     /* Info boxes */
+     .stAlert {
+         background: rgba(16, 163, 127, 0.1);
+         border: 1px solid rgba(16, 163, 127, 0.3);
+         border-radius: 12px;
+         color: #a8e6cf;
+     }
+
+     /* Scrollbar styling */
+     ::-webkit-scrollbar {
+         width: 8px;
+         height: 8px;
+     }
+
+     ::-webkit-scrollbar-track {
+         background: rgba(255, 255, 255, 0.02);
+     }
+
+     ::-webkit-scrollbar-thumb {
+         background: rgba(255, 255, 255, 0.15);
+         border-radius: 10px;
+     }
+
+     ::-webkit-scrollbar-thumb:hover {
+         background: rgba(255, 255, 255, 0.25);
+     }
+
+     /* Typography improvements */
+     h1, h2, h3 {
+         color: #f0f0f0;
+         font-weight: 600;
+     }
+
+     p {
+         line-height: 1.7;
+         color: #d4d4d4;
+     }
+
+     /* Slider styling */
+     .stSlider > div > div > div > div {
+         background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+     }
+
+     /* Checkbox styling */
+     .stCheckbox > label > div[data-testid="stMarkdownContainer"] > p {
+         color: #d4d4d4;
+     }
+
+     /* Thinking animation */
+     @keyframes pulse {
+         0%, 100% { opacity: 0.6; }
+         50% { opacity: 1; }
+     }
+
+     .thinking {
+         animation: pulse 1.5s ease-in-out infinite;
+         color: #10a37f;
+         font-style: italic;
+     }
+
+     /* Welcome message styling */
+     .welcome-message {
+         background: linear-gradient(135deg, rgba(16, 163, 127, 0.1) 0%, rgba(102, 126, 234, 0.1) 100%);
+         border: 1px solid rgba(16, 163, 127, 0.3);
+         border-radius: 16px;
+         padding: 24px;
+         margin: 20px 0;
+         text-align: center;
+         box-shadow: 0 4px 16px rgba(16, 163, 127, 0.1);
+     }
+
+     .welcome-message h2 {
+         background: linear-gradient(135deg, #10a37f 0%, #667eea 100%);
+         -webkit-background-clip: text;
+         -webkit-text-fill-color: transparent;
+         margin-bottom: 12px;
+     }
+
+     /* Suggestion chips */
+     .suggestion-chip {
+         display: inline-block;
+         background: rgba(255, 255, 255, 0.05);
+         border: 1px solid rgba(255, 255, 255, 0.1);
+         border-radius: 20px;
+         padding: 8px 16px;
+         margin: 6px;
+         color: #b4b4b4;
+         cursor: pointer;
+         transition: all 0.3s ease;
+     }
+
+     .suggestion-chip:hover {
+         background: rgba(16, 163, 127, 0.15);
+         border-color: rgba(16, 163, 127, 0.4);
+         color: #10a37f;
+         transform: translateY(-2px);
+     }
+ </style>
+ """, unsafe_allow_html=True)
+
+ # Title with emoji and clean design
+ col1, col2, col3 = st.columns([1, 6, 1])
+ with col2:
+     st.markdown("<h1 style='text-align: center; margin-bottom: 0;'>🤖 Research Assistant</h1>", unsafe_allow_html=True)
+     st.markdown("<p style='text-align: center; color: #888; margin-top: 0;'>Powered by Multi-LLM RAG + FAISS</p>", unsafe_allow_html=True)
+
+ # Sidebar controls with improved organization
  with st.sidebar:
+     st.markdown("### ⚙️ Configuration")
+
+     with st.expander("📊 Dataset Info", expanded=False):
+         st.markdown("""
+ **Source:** CShorten/ML-ArXiv-Papers
+ **Focus:** Machine Learning Research
+ **Platform:** Hugging Face
+ """)
+
+     st.markdown("---")
+
+     with st.expander("🔍 Retrieval Settings", expanded=False):
+         base_k = st.slider("Initial fetch", 4, 30, 20, 1, help="Number of documents to initially retrieve")
+         rerank_k = st.slider("Final docs", 1, 12, 8, 1, help="Number of documents after reranking")
+         dynamic = st.checkbox("Dynamic k", True, help="Adjust retrieval size dynamically")
+         use_rerank = st.checkbox("Use reranking", True, help="Apply reranking for better relevance")
+
+     with st.expander("🔧 Advanced Filters"):
+         primary_category = st.text_input("Category filter", "", help="Filter by arXiv category") or None
+         col1, col2 = st.columns(2)
+         with col1:
+             year_min = st.number_input("Min year", value=0, step=1)
+         with col2:
+             year_max = st.number_input("Max year", value=0, step=1)
+         if year_min == 0:
+             year_min = None
+         if year_max == 0:
+             year_max = None
+
+     st.markdown("---")
+
+     with st.expander("🔄 Index Management", expanded=False):
+         subset_size = st.number_input("Dataset size", 1000, 100000, 10000, 1000)
+         rebuild = st.button("🔨 Rebuild Index", use_container_width=True)
+
+     st.markdown("---")
+
+     with st.expander("🤖 LLM Provider", expanded=False):
+         # Determine default provider based on available API keys
+         if ANTHROPIC_API_KEY:
+             default_provider = "Anthropic (Claude)"
+         elif GEMINI_API_KEY:
+             default_provider = "Gemini"
+         elif GROQ_API_KEY:
+             default_provider = "Groq"
+         else:
+             default_provider = "Gemini"
+
+         available_providers = ["Anthropic (Claude)", "Gemini", "Groq"]
+         try:
+             default_index = available_providers.index(default_provider)
+         except ValueError:
+             default_index = 0
+
+         provider = st.selectbox("Provider", available_providers, index=default_index)
+
+         if provider == "Anthropic (Claude)":
+             ui_anthropic_model = st.selectbox(
+                 "Model",
+                 [
+                     "claude-sonnet-4-5-20250929",
+                     "claude-opus-4-1-20250805",
+                     "claude-opus-4-20250514",
+                     "claude-sonnet-4-20250514",
+                     "claude-3-7-sonnet-20250219",
+                     "claude-3-5-haiku-20241022",
+                     "claude-3-haiku-20240307"
+                 ],
+                 index=3
+             )
+             ui_gemini_model = None
+             ui_groq_model = None
+         elif provider == "Gemini":
+             ui_gemini_model = st.text_input("Model", GEMINI_MODEL)
+             ui_groq_model = None
+             ui_anthropic_model = None
+         else:
+             ui_groq_model = st.text_input("Model", GROQ_MODEL)
+             ui_gemini_model = None
+             ui_anthropic_model = None
+
+     # Stats at bottom
+     st.markdown("---")
+     try:
+         if 'vectorstore' in locals():
+             index_stats = vectorstore.index.ntotal if hasattr(vectorstore, 'index') else "Unknown"
+             st.metric("📚 Embeddings", f"{index_stats:,}" if isinstance(index_stats, int) else index_stats)
+     except:
+         pass

  # Build or load vectorstore
+ def _load_df_from_hf(num_records: int):
+     """Load dataset from Hugging Face."""
      try:
+         with st.spinner("🔄 Loading ML papers from Hugging Face..."):
+             df = load_hf_dataset(num_records=num_records, dataset_name="CShorten/ML-ArXiv-Papers")
+             return preprocess_dataframe(df)
      except Exception as e:
+         st.error(f"❌ Failed to load dataset: {e}")
+         st.info("💡 Make sure 'datasets' is installed: `pip install datasets`")
+         st.stop()

  if rebuild or not os.path.exists(FAISS_INDEX_PATH):
+     with st.spinner("🔨 Building vector index..."):
+         df = _load_df_from_hf(num_records=int(subset_size))
          docs = df_to_documents(df)
          vectorstore = build_or_load_vectorstore(
              docs,
              force_rebuild=True,
              chunk_method="semantic",
+             chunk_size=1000,
+             chunk_overlap=125
          )
  else:
      try:
          vectorstore = build_or_load_vectorstore([], force_rebuild=False)
      except Exception as e:
+         st.warning(f"⚠️ Index load failed. Rebuilding...")
+         with st.spinner("🔨 Rebuilding vector index..."):
+             df = _load_df_from_hf(num_records=50000)
              docs = df_to_documents(df)
              vectorstore = build_or_load_vectorstore(
                  docs,
                  force_rebuild=True,
                  chunk_method="semantic",
+                 chunk_size=1000,
+                 chunk_overlap=125
              )

  def make_llm(provider_name: str):
+     if provider_name == "Anthropic (Claude)":
+         if not ANTHROPIC_API_KEY:
+             st.error("❌ ANTHROPIC_API_KEY not set")
+             st.stop()
+         try:
+             from langchain_anthropic import ChatAnthropic
+             return ChatAnthropic(
+                 model=ui_anthropic_model or ANTHROPIC_MODEL,
+                 temperature=0.7,
+                 max_tokens=2048,
+                 api_key=ANTHROPIC_API_KEY,
+             )
+         except Exception as e:
+             st.error(f"❌ Claude initialization failed: {e}")
+             st.stop()
+
      if provider_name == "Gemini":
          if not GEMINI_API_KEY:
+             st.error("❌ GEMINI_API_KEY not set")
              st.stop()
          try:
              from langchain_google_genai import ChatGoogleGenerativeAI
              return ChatGoogleGenerativeAI(
                  model=ui_gemini_model or GEMINI_MODEL,
                  temperature=0.7,
+                 max_output_tokens=1024,
                  api_key=GEMINI_API_KEY,
              )
          except Exception as e:
+             st.error(f"❌ Gemini initialization failed: {e}")
              st.stop()
+
      if not GROQ_API_KEY:
+         st.error("❌ No valid LLM provider configured")
          st.stop()
      return ChatGroq(
          model=ui_groq_model or GROQ_MODEL,
          temperature=0.7,
+         max_tokens=1024,
          groq_api_key=GROQ_API_KEY,
      )

  llm = make_llm(provider)

+ # Relevance checking prompt
+ relevance_check_prompt = """You are a research paper relevance checker. Your task is to determine if the retrieved documents are relevant to the user's question.
+
+ Retrieved Documents:
+ {context}
+
+ User Question: {question}
+
+ Instructions:
+ - Carefully analyze whether the retrieved documents contain information that can answer the user's question
+ - Consider if the documents discuss the topic, concepts, or methods mentioned in the question
+ - Respond with ONLY one word: "RELEVANT" or "IRRELEVANT"
+ - Be strict: if the documents are only tangentially related or don't actually address the question, respond "IRRELEVANT"
+
+ Response:"""
+
+ relevance_prompt = PromptTemplate(template=relevance_check_prompt, input_variables=["context", "question"])
+
+ # IMPROVED PROMPT
+ prompt_template = """You are a knowledgeable and helpful research assistant specializing in arXiv papers. You MUST ONLY answer questions based on the provided research papers context.

  Context from Research Papers:
  {context}

  User Question: {question}

+ CRITICAL RULES:
+ - ONLY use information from the provided research papers context above
+ - DO NOT use your general knowledge or training data
+ - If the context doesn't contain relevant information, you MUST respond with: "I couldn't find relevant information about this topic in the available research papers. The retrieved documents don't address your question. Please try different search terms or the database may not contain papers on this specific topic."
+
  Instructions:
+ - Analyze the user's question and provide a thorough, well-structured response BASED ONLY ON THE CONTEXT
+ - Be conversational and descriptive - explain concepts clearly with sufficient detail
+ - Use multiple paragraphs when needed to fully address the question
+
+ **For paper listing requests** (e.g., "find papers", "list papers", "show papers"):
+ Format as a structured list with detailed summaries:
+
+ **Paper #[Number]: [Title]**
+ - **Authors:** [Author names]
+ - **Year:** [Publication year]
+ - **ArXiv ID:** [ID if available]
+ - **Category:** [Research category]
+ - **Summary:** [3-4 sentences explaining the paper's objectives, methodology, key contributions, and findings based on the context]
+
+ **For specific questions** (e.g., "What is...", "Explain...", "How does...", "What is the purpose of..."):
+ - Provide a comprehensive, multi-paragraph answer that fully addresses the question USING ONLY THE CONTEXT
+ - Start with a clear overview or direct answer from the papers
+ - Elaborate with details, context, and explanations from the research papers
+ - Discuss relevant methodologies, findings, implications, or technical details found in the papers
+ - Cite sources naturally throughout (e.g., "According to the research by [Authors] (Year)...")
+ - Use clear transitions between ideas
+ - Conclude with key takeaways or significance when appropriate
+
+ **General Guidelines:**
+ - Write in a natural, conversational tone similar to ChatGPT
+ - Aim for depth and clarity - don't give one-liner responses
+ - Break complex information into digestible paragraphs
+ - Use examples and analogies when helpful from the context
+ - NEVER invent or hallucinate information not in the context
+ - Always prioritize being helpful, informative, and thorough - but ONLY based on the provided context

  Answer:"""

          return ""
      meta_lines = []
      if metadata.get("title"):
+         meta_lines.append(f"📄 {metadata['title']}")
      if metadata.get("id"):
+         meta_lines.append(f"🔗 {metadata['id']}")
      if metadata.get("authors") and metadata["authors"] != "N/A":
          authors = metadata['authors']
+         if len(authors) > 100:
              authors = authors[:100] + "..."
+         meta_lines.append(f"👥 {authors}")
      if metadata.get("year"):
+         meta_lines.append(f"📅 {metadata['year']}")
      if metadata.get("primary_category") and metadata["primary_category"] != "N/A":
+         meta_lines.append(f"🏷️ {metadata['primary_category']}")
+     return " • ".join(meta_lines)

  def format_docs(docs):
      """Format documents with clear structure and metadata."""
      if not docs:
+         return "No relevant documents found in the database."

      formatted_chunks = []
      for idx, doc in enumerate(docs, start=1):
          meta_str = _format_metadata(doc.metadata)
          content = doc.page_content.strip()

          if len(content) > 1000:
              content = content[:1000] + "..."

+         formatted_chunk = f"[Document {idx}]\n{meta_str}\n\n{content}"
          formatted_chunks.append(formatted_chunk)

      return "\n\n" + "="*80 + "\n\n".join(formatted_chunks)

          use_rerank=use_rerank,
      )

      def retrieval_with_logging(q):
          try:
              docs = retriever.get_relevant_documents(q)
+             return format_docs(docs)
          except Exception as e:
+             return f"Error retrieving documents: {e}"

      retrieval_runnable = RunnableLambda(retrieval_with_logging)
      chain = {"context": retrieval_runnable, "question": RunnablePassthrough()} | prompt | llm

  # Initialize session state
  if "messages" not in st.session_state:
      st.session_state["messages"] = []
+     st.session_state["show_welcome"] = True

+ # Welcome message with suggestions
+ if st.session_state.get("show_welcome", False):
+     st.markdown("""
+ <div class="welcome-message">
+     <h2>👋 Welcome to Research Assistant!</h2>
+     <p>I'm your AI-powered research companion. Ask me anything about Machine Learning papers!</p>
+     <div style="margin-top: 20px;">
+         <span class="suggestion-chip">🔍 Find papers on transformers</span>
+         <span class="suggestion-chip">💡 Explain attention mechanism</span>
+         <span class="suggestion-chip">📊 Compare CNN vs RNN</span>
+         <span class="suggestion-chip">🎯 Latest in reinforcement learning</span>
+     </div>
+ </div>
+ """, unsafe_allow_html=True)
+     st.session_state["show_welcome"] = False

+ # Helper functions
+ def is_casual_conversation(query_text):
+     """Check if the query is a greeting or casual conversation."""
+     query_lower = query_text.lower().strip()
+     greetings = ["hi", "hello", "hey", "good morning", "good afternoon", "good evening",
+                  "hola", "greetings", "howdy", "yo", "sup", "what's up", "whats up"]
+     casual_patterns = [
+         "how are you", "how r u", "how do you do", "what's up", "whats up",
+         "who are you", "what are you", "what is your name", "your name",
+         "what can you do", "help me", "can you help", "thank you", "thanks",
+         "bye", "goodbye", "see you", "nice to meet you", "pleasure"
+     ]
+
+     if query_lower in greetings:
+         return True
+     for pattern in casual_patterns:
+         if pattern in query_lower:
+             return True
+     return False

+ def get_casual_response(query_text):
+     """Generate appropriate response for casual conversation."""
+     query_lower = query_text.lower().strip()
+
+     if any(word in query_lower for word in ["hi", "hello", "hey", "hola", "howdy", "yo"]):
+         return "Hello! 👋 I'm your AI Research Assistant for Machine Learning papers. How can I help you today?"
+     if "good morning" in query_lower:
+         return "Good morning! ☀️ Ready to explore some ML research? What interests you today?"
+     if "good afternoon" in query_lower:
+         return "Good afternoon! 🌤️ Let's dive into some research! What would you like to learn about?"
+     if "good evening" in query_lower:
+         return "Good evening! 🌙 I'm here to help with ML research. What topic interests you?"
+     if any(phrase in query_lower for phrase in ["how are you", "how r u", "how do you do"]):
+         return "I'm doing great, thanks! 😊 Ready to help you explore ML research. What's on your mind?"
+     if any(phrase in query_lower for phrase in ["who are you", "what are you", "your name"]):
+         return "I'm an AI Research Assistant specialized in Machine Learning! 🤖 I help you find papers, explain concepts, and answer research questions. What would you like to know?"
+     if any(phrase in query_lower for phrase in ["what can you do", "help me", "can you help"]):
+         return """I can help you with:
+
+ 🔍 **Finding research papers** on specific ML topics
+ 📚 **Explaining ML concepts** from published research
+ 💡 **Answering questions** about techniques and methods
+ 🎓 **Exploring** the latest ML research developments
+
+ Try asking:
+ - "Find papers on deep learning"
+ - "What is transfer learning?"
+ - "Explain adversarial training"
+
+ What interests you?"""
+     if any(word in query_lower for word in ["thank you", "thanks", "thx"]):
+         return "You're welcome! 😊 Happy to help! Let me know if you have other questions."
+     if any(word in query_lower for word in ["bye", "goodbye", "see you"]):
+         return "Goodbye! 👋 Come back anytime for ML research help. Happy learning!"
+
+     return "I'm here to help with Machine Learning research! 😊 Ask me about any ML topics or papers."
+
+ # Chat input
+ query = st.chat_input("💬 Ask me anything about ML research...")
+
+ # Display chat history
+ for i, msg in enumerate(st.session_state["messages"]):
+     # Show user message
+     st.chat_message("user", avatar="👤").write(msg["query"])
+
+     # Show assistant response if available
+     if msg.get("answer") is not None:
+         with st.chat_message("assistant", avatar="🤖"):
+             st.write(msg["answer"])
+             if msg.get("context") and len(msg["context"]) > 0:
+                 with st.expander(f"📄 View {len(msg['context'])} Retrieved Documents", expanded=False):
+                     for idx, doc in enumerate(msg["context"], 1):
+                         st.markdown(f"**📎 Document {idx}**")
+                         st.caption(_format_metadata(doc.metadata))
+                         st.text_area(
+                             f"Content {idx}",
+                             doc.page_content[:800] + ("..." if len(doc.page_content) > 800 else ""),
+                             height=150,
+                             key=f"doc_{i}_{idx}",
+                             disabled=True
+                         )
+                         if idx < len(msg["context"]):
+                             st.markdown("---")
+     else:
+         # Answer is being generated - show thinking indicator
+         with st.chat_message("assistant", avatar="🤖"):
+             thinking_placeholder = st.empty()
+             thinking_placeholder.markdown('<p class="thinking">🔍 Searching research papers...</p>', unsafe_allow_html=True)

+             # Check if casual conversation
+             if is_casual_conversation(msg["query"]):
+                 casual_response = get_casual_response(msg["query"])
+
+                 # Smooth streaming effect
+                 response_placeholder = st.empty()
+                 full_response = ""
+                 words = casual_response.split()
+
+                 for word in words:
+                     full_response += word + " "
+                     response_placeholder.markdown(full_response)
+                     time.sleep(0.02)
+
+                 st.session_state["messages"][i]["answer"] = casual_response
+                 st.rerun()

              else:
+                 # Research question - full RAG pipeline
+                 rag_chain, adv_retriever = build_chain()
+
+                 docs = []
+                 answer_text = ""
+                 error_occurred = False
+
+                 try:
+                     docs = adv_retriever.get_relevant_documents(msg["query"])
+
+                     if not docs:
+                         answer_text = """I couldn't find any relevant research papers in the database that match your query.
+
+ **💡 Suggestions:**
+ - Try using broader or different search terms
+ - Check the spelling of technical terms
+ - The database may not contain papers on this specific topic
+ - Consider rebuilding the index with more data
+
+ The current database focuses on ArXiv ML papers, but may not cover all research areas comprehensively."""
+                     else:
+                         thinking_placeholder.markdown('<p class="thinking">🧠 Analyzing documents...</p>', unsafe_allow_html=True)
+
+                         # Check relevance
+                         formatted_context = format_docs(docs)
+                         relevance_check_chain = {"context": RunnablePassthrough(), "question": RunnablePassthrough()} | relevance_prompt | llm
+                         relevance_result = relevance_check_chain.invoke({"context": formatted_context, "question": msg["query"]})
+                         relevance_text = relevance_result.content if hasattr(relevance_result, "content") else str(relevance_result)
+
+                         if "IRRELEVANT" in relevance_text.strip().upper():
+                             answer_text = f"""I found {len(docs)} documents in the database, but they don't contain relevant information about your question.
+
+ **📋 Retrieved topics:**
+ - {docs[0].metadata.get('title', 'Various topics') if docs else 'N/A'}
+
+ **💡 Suggestions:**
+ - Try rephrasing with different keywords
+ - Use more specific technical terms
+ - Search for related concepts or broader topics
+ - The database may not have papers specifically on this topic
+
+ I can only provide answers based on the ArXiv papers in the database."""
+                         else:
+                             # Generate answer with streaming
+                             thinking_placeholder.markdown('<p class="thinking">✍️ Generating response...</p>', unsafe_allow_html=True)
+                             answer = rag_chain.invoke(msg["query"])
+                             answer_text = answer.content if hasattr(answer, "content") else str(answer)
+
+                 except Exception as e:
+                     error_occurred = True
+                     msg_err = str(e)
+                     if "models/" in msg_err and "not found" in msg_err.lower():
+                         answer_text = "⚠️ Selected model not found. Try a different model in the sidebar."
+                     else:
+                         answer_text = f"⚠️ An error occurred: {e}\n\nPlease try again or rebuild the index."
+
+                 # Clear thinking and display response with streaming
+                 thinking_placeholder.empty()
+
+                 # Stream response
+                 import re
+                 response_placeholder = st.empty()
+                 parts = re.split(r'(\n\n|(?<=[.!?])\s+)', answer_text)
+
+                 full_response = ""
+                 for part in parts:
+                     full_response += part
+                     response_placeholder.markdown(full_response)
+                     time.sleep(0.03)
+
+                 # Update session state
+                 st.session_state["messages"][i]["answer"] = answer_text
+                 st.session_state["messages"][i]["context"] = docs
+
+                 # Show retrieved documents
+                 if docs:
+                     with st.expander(f"📄 View {len(docs)} Retrieved Documents", expanded=False):
+                         for idx, doc in enumerate(docs, 1):
+                             st.markdown(f"**📎 Document {idx}**")
+                             st.caption(_format_metadata(doc.metadata))
+                             st.text_area(
+                                 f"Content {idx}",
+                                 doc.page_content[:800] + ("..." if len(doc.page_content) > 800 else ""),
+                                 height=150,
+                                 key=f"new_doc_{i}_{idx}",
+                                 disabled=True
+                             )
+                             if idx < len(docs):
+                                 st.markdown("---")
+
+                 st.rerun()
+
+ # Process new query
+ if query:
+     # Add message to session state immediately
+     st.session_state["messages"].append({
+         "query": query,
+         "answer": None,
+         "context": []
+     })
+
+     # Force rerun to show the user message immediately
+     st.rerun()
+
+ # Footer with tips - only show if there are messages
+ if len(st.session_state["messages"]) > 0:
+     st.markdown("---")
+     with st.expander("💡 Tips for Better Results", expanded=False):
+         col1, col2 = st.columns(2)

+         with col1:
+             st.markdown("""
+ **🎯 Asking Better Questions**
+
+ ✅ Use specific ML terminology
+ ✅ Mention techniques or methods
+ ✅ Ask for comparisons
+ ✅ Reference specific problems
+
+ **Examples:**
+ - "Papers on transformer architecture"
+ - "Compare CNNs vs Vision Transformers"
+ - "Explain BERT training methodology"
+ """)
+
+         with col2:
+             st.markdown("""
+ **📚 Understanding Responses**
+
+ ✅ All answers from actual papers
+ ✅ View source documents anytime
+ ✅ Check relevance of results
+ ✅ Adjust settings if needed
+
+ **⚡ Advanced Tips:**
+ - Use sidebar filters (year, category)
+ - Adjust retrieval settings
+ - Try different LLM providers
+ - Rebuild index for fresh data
+ """)
+
+ # Add a "Clear Chat" button at the bottom of sidebar
+ with st.sidebar:
+     st.markdown("---")
+     if st.button("🗑️ Clear Chat History", use_container_width=True):
+         st.session_state["messages"] = []
+         st.session_state["show_welcome"] = True
+         st.rerun()