arxiv-rag-demo / app.py
aakash-malhan's picture
Update app.py
a2f326b verified
# app.py - FINAL, 100% WORKING
import os
import subprocess
import gradio as gr
from langchain_community.vectorstores import Weaviate
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from sentence_transformers import SentenceTransformer
import weaviate
import time
import requests
# === FIX: Correct Groq import ===
from langchain_groq import ChatGroq
# === SECRETS ===
# Required settings are read from environment variables at import time.
# Indexing os.environ directly raises KeyError when a variable is missing,
# so a misconfigured deployment fails here rather than later.
WEAVIATE_URL = os.environ["WEAVIATE_URL"]
WEAVIATE_KEY = os.environ["WEAVIATE_KEY"]
# Self-assignment: the value is left unchanged, but the lookup on the
# right-hand side raises KeyError if GROQ_API_KEY is not set.
# NOTE(review): presumably ChatGroq picks the key up from the environment - confirm.
os.environ["GROQ_API_KEY"] = os.environ["GROQ_API_KEY"]
# === WAIT FOR WEAVIATE TO BE READY ===
def wait_for_weaviate(url, key, timeout=120):
    """Poll Weaviate's readiness endpoint until it answers HTTP 200.

    Args:
        url: Base URL of the Weaviate instance (a trailing slash is tolerated).
        key: API key sent as a Bearer token in the Authorization header.
        timeout: Maximum number of polling attempts; one second is slept
            between attempts.

    Returns:
        True as soon as the readiness endpoint returns 200, False if every
        attempt failed.
    """
    headers = {"Authorization": f"Bearer {key}"}
    # Strip a trailing slash so the path does not become ".../​/v1/...".
    ready_url = f"{url.rstrip('/')}/v1/.well-known/ready"
    print("Waiting for Weaviate to be ready...")
    for i in range(timeout):
        try:
            # requests has no default timeout; without one a stalled
            # connection would block startup indefinitely.
            response = requests.get(ready_url, headers=headers, timeout=5)
            if response.status_code == 200:
                print("Weaviate is ready!")
                return True
        except requests.RequestException as exc:
            # Only network-level failures are expected while the service
            # boots; anything else (e.g. KeyboardInterrupt) propagates.
            print(f"Weaviate not reachable yet: {exc}")
        print(f"Attempt {i+1}/{timeout}... waiting 1s")
        time.sleep(1)
    print("Weaviate did not start in time.")
    return False
# === AUTO-INGEST ON START ===
def run_ingestion():
    """Run ``ingest.py`` in a subprocess once Weaviate is reachable.

    Waits for Weaviate via ``wait_for_weaviate``; if it never becomes
    ready, ingestion is skipped. The subprocess output is echoed to
    stdout, and stderr is printed when the script exits non-zero.

    Returns:
        None.
    """
    import sys

    print("Starting ingestion...")
    if not wait_for_weaviate(WEAVIATE_URL, WEAVIATE_KEY):
        print("Cannot connect to Weaviate. Skipping ingestion.")
        return
    # Use the interpreter running this app rather than whatever "python"
    # resolves to on PATH, so ingest.py sees the same installed packages.
    result = subprocess.run(
        [sys.executable, "ingest.py"], capture_output=True, text=True
    )
    print(result.stdout)
    if result.returncode != 0:
        print("Ingestion failed:", result.stderr)
    else:
        print("Ingestion complete!")
# Run ingestion once, at module import time (i.e. on every app start).
# The call is synchronous, so the code below it runs only after it returns.
run_ingestion()
# === RAG CHAIN (NO @gr.cache) ===
def get_rag_chain():
    """Build a RetrievalQA chain over the Weaviate "Paper" index.

    A fresh Weaviate client, embedder, retriever and Groq LLM are created
    on every call.

    Returns:
        A RetrievalQA chain whose output dict holds the answer under
        "result" and the retrieved documents under "source_documents".
    """
    client = weaviate.Client(
        url=WEAVIATE_URL,
        auth_client_secret=weaviate.AuthApiKey(WEAVIATE_KEY),
    )
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    # NOTE(review): LangChain vector stores usually expect an Embeddings
    # object (embed_query/embed_documents); confirm a raw SentenceTransformer
    # is acceptable for the search mode used here.
    vectorstore = Weaviate(client, "Paper", "text", embedding=embedder)
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
    # NOTE(review): confirm this model id is one Groq currently serves.
    llm = ChatGroq(model="llama-3.1-70b-instruct", temperature=0)
    prompt = PromptTemplate.from_template(
        "Answer using only this context:\n{context}\n\nQuestion: {question}\nAnswer:"
    )
    return RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type_kwargs={"prompt": prompt},
        # Callers read result["source_documents"]; the chain omits that key
        # unless it is requested explicitly.
        return_source_documents=True,
    )
def search(query):
    """Answer a question with the RAG chain and list the supporting sources.

    Args:
        query: The user's question as entered in the UI.

    Returns:
        A ``(answer, sources)`` tuple of Markdown strings. On failure the
        first element is an ``"Error: ..."`` message and the second is
        ``"Check logs."``; a blank query yields a prompt to enter a question.
    """
    import traceback

    # Do not spend a retrieval + LLM round trip on an empty question.
    if not query or not query.strip():
        return "Please enter a question.", ""
    try:
        qa = get_rag_chain()
        result = qa.invoke({"query": query})
        answer = result["result"]
        # Tolerate a chain that does not return its source documents.
        source_docs = result.get("source_documents") or []
        sources = "\n\n".join(
            f"**{doc.metadata.get('title', 'No Title')}**\n{doc.page_content[:300]}..."
            for doc in source_docs
        )
        return answer, sources or "No sources found."
    except Exception as e:
        # The UI tells the user to check the logs, so actually write the
        # traceback there before returning the short message.
        traceback.print_exc()
        return f"Error: {str(e)}", "Check logs."
# === GRADIO UI ===
with gr.Blocks(title="ArXiv RAG") as demo:
    # Page header.
    gr.Markdown("# ArXiv RAG Search")
    gr.Markdown("10K+ papers • Llama-3.1 • Weaviate")

    # Question input and submit button.
    txt = gr.Textbox(label="Ask", placeholder="What is attention?", lines=2)
    btn = gr.Button("Search", variant="primary")

    # Two Markdown panes: the answer, then the retrieved sources.
    out1 = gr.Markdown()
    out2 = gr.Markdown()

    # Clicking the button sends the textbox value to search() and writes
    # its two return values into the output panes.
    btn.click(fn=search, inputs=txt, outputs=[out1, out2])

# Start the Gradio server.
demo.launch()