Spaces:
Sleeping
Sleeping
# Stdlib, third-party, and project imports for the Streamlit RAG app.
import os

import streamlit as st
from dotenv import load_dotenv, find_dotenv
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain_core.runnables import RunnableLambda
from langchain_groq import ChatGroq

from src.vector_store import build_or_load_vectorstore
from src.ingestion import load_data_subset, preprocess_dataframe, df_to_documents
from src.retriever import build_advanced_retriever
from src.config import DATA_PATH, FAISS_INDEX_PATH, GROQ_API_KEY

# Pull in .env so credentials (e.g. GROQ_API_KEY) are present at startup.
load_dotenv(find_dotenv())
st.set_page_config(page_title="π Research Paper RAG Chatbot", page_icon="π¬", layout="wide")
st.title("π Research Paper RAG Chatbot (Groq + FAISS + Rerank)")

# Sidebar: retrieval tuning plus index-rebuild controls.
with st.sidebar:
    st.header("Retrieval Settings")
    base_k = st.slider("Initial fetch (base_k)", 4, 30, 16, 1)
    rerank_k = st.slider("Final docs (rerank_k)", 1, 12, 6, 1)
    dynamic = st.checkbox("Dynamic k", True)
    use_rerank = st.checkbox("Use reranking", True)
    # An empty text box means "no category filter".
    primary_category = st.text_input("Primary category filter", "") or None
    year_min = st.number_input("Min year", value=0, step=1)
    year_max = st.number_input("Max year", value=0, step=1)
    # 0 is the sentinel for "no year bound" on either side.
    year_min = None if year_min == 0 else year_min
    year_max = None if year_max == 0 else year_max
    rebuild = st.button("Rebuild index (semantic)")
    subset_size = st.number_input("Subset records (rebuild)", 1000, 100000, 50000, 1000)
# Rebuild the FAISS index when the user asks for it or when no index exists
# on disk; otherwise load the persisted one.
if rebuild or not os.path.exists(FAISS_INDEX_PATH):
    data_file = os.path.join(DATA_PATH, "arxiv-metadata-oai-snapshot.json")
    if not os.path.exists(data_file):
        st.error("Dataset missing. Run main pipeline first.")
        st.stop()
    with st.spinner("Building vector index..."):
        frame = load_data_subset(data_file, num_records=int(subset_size))
        frame = preprocess_dataframe(frame)
        corpus = df_to_documents(frame)
        vectorstore = build_or_load_vectorstore(
            corpus,
            force_rebuild=True,
            chunk_method="semantic",
            chunk_size=800,
            chunk_overlap=120,
        )
else:
    # Empty docs list: helper loads the existing index instead of building.
    vectorstore = build_or_load_vectorstore([], force_rebuild=False)
# Groq-hosted chat model that generates the final answers.
llm = ChatGroq(
    model="meta-llama/llama-4-maverick-17b-128e-instruct",
    temperature=0.7,
    max_tokens=512,
    groq_api_key=GROQ_API_KEY,
)

# Prompt contract: the chain supplies {context} (joined documents) and
# {question} (the raw user query).
prompt_template = """Answer the following question based on the provided context.
Context:
{context}
Question: {question}
Answer:"""
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
def format_docs(docs):
    """Join the retrieved documents' text into one prompt-ready context string."""
    pieces = [doc.page_content for doc in docs]
    return "\n\n".join(pieces)
def build_chain():
    """Assemble the RAG chain from the current sidebar settings.

    Returns a ``(chain, retriever)`` pair: the chain maps a question string
    to an LLM answer; the retriever is also returned so callers can fetch
    the source documents for display.
    """
    retriever = build_advanced_retriever(
        vectorstore,
        base_k=base_k,
        rerank_k=rerank_k,
        primary_category=primary_category,
        year_min=year_min,
        year_max=year_max,
        dynamic=dynamic,
        use_rerank=use_rerank,
    )

    def _fetch_context(question):
        # Retrieve matching documents and flatten them into the {context} slot.
        return format_docs(retriever.get_relevant_documents(question))

    chain = (
        {"context": RunnableLambda(_fetch_context), "question": RunnablePassthrough()}
        | prompt
        | llm
    )
    return chain, retriever
# Initialise the per-session chat history once.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

query = st.chat_input("Ask me something...")
if query:
    _, adv_retriever = build_chain()
    # Retrieve ONCE and reuse the documents for both the answer and the
    # "Documents" expander. The previous code retrieved twice per query
    # (explicitly here, then again inside the chain's context runnable),
    # doubling retrieval latency and rerank/API cost for identical results.
    docs = adv_retriever.get_relevant_documents(query)
    # prompt | llm with the pre-formatted context builds the exact same
    # prompt the full chain would have produced.
    answer = (prompt | llm).invoke({"context": format_docs(docs), "question": query})
    # ChatGroq returns an AIMessage; fall back to str() for plain outputs.
    answer_text = answer.content if hasattr(answer, "content") else str(answer)
    st.session_state["messages"].append({
        "query": query,
        "answer": answer_text,
        "context": docs,
    })
# Replay the whole conversation (oldest first) with the retrieved sources.
for entry in st.session_state["messages"]:
    st.chat_message("user").write(entry["query"])
    with st.chat_message("assistant"):
        st.write(entry["answer"])
        with st.expander("Documents"):
            for idx, doc in enumerate(entry["context"], start=1):
                st.markdown(f"**Doc {idx}**")
                st.write(doc.page_content)
                if doc.metadata:
                    st.caption(str(doc.metadata))