# streamlit_app.py
import os

import streamlit as st
from dotenv import load_dotenv, find_dotenv

# Load .env before importing src.config, in case it reads GROQ_API_KEY
# from the environment at import time.
load_dotenv(find_dotenv())

from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_groq import ChatGroq

from src.vector_store import build_or_load_vectorstore
from src.ingestion import load_data_subset, preprocess_dataframe, df_to_documents
from src.retriever import build_advanced_retriever
from src.config import DATA_PATH, FAISS_INDEX_PATH, GROQ_API_KEY
st.set_page_config(page_title="πŸ“„ Research Paper RAG Chatbot", page_icon="πŸ’¬", layout="wide")
st.title("πŸ“„ Research Paper RAG Chatbot (Groq + FAISS + Rerank)")
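
# Fail fast if no Groq API key is available. A minimal guard sketch; it
# assumes GROQ_API_KEY in src.config falls back to the environment variable
# of the same name -- adjust to however your config resolves credentials.
if not (GROQ_API_KEY or os.getenv("GROQ_API_KEY")):
    st.error("GROQ_API_KEY is not set. Add it to your .env file and restart.")
    st.stop()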
# Sidebar controls
with st.sidebar:
    st.header("Retrieval Settings")
    base_k = st.slider("Initial fetch (base_k)", 4, 30, 16, 1)
    rerank_k = st.slider("Final docs (rerank_k)", 1, 12, 6, 1)
    dynamic = st.checkbox("Dynamic k", True)
    use_rerank = st.checkbox("Use reranking", True)
    primary_category = st.text_input("Primary category filter", "") or None
    year_min = st.number_input("Min year", value=0, step=1)
    year_max = st.number_input("Max year", value=0, step=1)
    # A value of 0 means "no bound" for the year filters.
    if year_min == 0:
        year_min = None
    if year_max == 0:
        year_max = None
    rebuild = st.button("Rebuild index (semantic)")
    subset_size = st.number_input("Subset records (rebuild)", 1000, 100000, 50000, 1000)
# Build or load the FAISS vectorstore
if rebuild or not os.path.exists(FAISS_INDEX_PATH):
    data_file = os.path.join(DATA_PATH, "arxiv-metadata-oai-snapshot.json")
    if not os.path.exists(data_file):
        st.error("Dataset missing. Run the main pipeline first.")
        st.stop()
    with st.spinner("Building vector index..."):
        df = load_data_subset(data_file, num_records=int(subset_size))
        df = preprocess_dataframe(df)
        docs = df_to_documents(df)
        vectorstore = build_or_load_vectorstore(
            docs,
            force_rebuild=True,
            chunk_method="semantic",
            chunk_size=800,
            chunk_overlap=120,
        )
else:
    # Index already exists on disk; load it without re-ingesting documents.
    vectorstore = build_or_load_vectorstore([], force_rebuild=False)
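
# Note: Streamlit reruns this script on every interaction, so the index is
# reloaded each time. A common optimization (sketch only, assuming the
# load-from-disk path above) is to cache the loaded store once per process:
#
#     @st.cache_resource
#     def get_vectorstore():
#         return build_or_load_vectorstore([], force_rebuild=False)
#
#     vectorstore = get_vectorstore()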
llm = ChatGroq(
model="meta-llama/llama-4-maverick-17b-128e-instruct",
temperature=0.7,
max_tokens=512,
groq_api_key=GROQ_API_KEY,
)
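
# Quick smoke test of the model binding (sketch; .invoke returns an
# AIMessage, so the generated text lives in .content):
#
#     llm.invoke("Say hello").content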
prompt_template = """Answer the following question based on the provided context.

Context:
{context}

Question: {question}

Answer:"""
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
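
# Sanity-check sketch: prompt.format(context="...", question="...") returns
# the fully rendered string that is sent to the LLM.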
def format_docs(docs):
    """Join retrieved documents into a single context string."""
    return "\n\n".join(d.page_content for d in docs)
def build_chain():
    """Build the retriever and RAG chain from the current sidebar settings."""
    retriever = build_advanced_retriever(
        vectorstore,
        base_k=base_k,
        rerank_k=rerank_k,
        primary_category=primary_category,
        year_min=year_min,
        year_max=year_max,
        dynamic=dynamic,
        use_rerank=use_rerank,
    )
    # Map the incoming question to a formatted context string.
    retrieval_runnable = RunnableLambda(lambda q: format_docs(retriever.get_relevant_documents(q)))
    chain = {"context": retrieval_runnable, "question": RunnablePassthrough()} | prompt | llm
    return chain, retriever
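
# Usage sketch: the chain takes the raw question string and returns an
# AIMessage, e.g.
#
#     chain, retriever = build_chain()
#     msg = chain.invoke("What is retrieval-augmented generation?")
#     print(msg.content)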
if "messages" not in st.session_state:
st.session_state["messages"] = []
query = st.chat_input("Ask me something...")
if query:
    rag_chain, adv_retriever = build_chain()
    # Retrieve once for the "Documents" expander; note the chain performs its
    # own retrieval internally, so each question currently retrieves twice.
    docs = adv_retriever.get_relevant_documents(query)
    answer = rag_chain.invoke(query)
    answer_text = answer.content if hasattr(answer, "content") else str(answer)
    st.session_state["messages"].append({
        "query": query,
        "answer": answer_text,
        "context": docs,
    })
for msg in st.session_state["messages"]:
    st.chat_message("user").write(msg["query"])
    with st.chat_message("assistant"):
        st.write(msg["answer"])
        with st.expander("Documents"):
            for i, doc in enumerate(msg["context"]):
                st.markdown(f"**Doc {i+1}**")
                st.write(doc.page_content)
                if doc.metadata:
                    st.caption(str(doc.metadata))
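
# Run locally with:
#     streamlit run streamlit_app.py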