arxiv-rag-demo / app.py
aakash-malhan's picture
Update app.py
13fe753 verified
raw
history blame
1.94 kB
import functools
import os
import subprocess
import sys

import gradio as gr
import weaviate
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import Weaviate
from langchain_groq import ChatGroq
from sentence_transformers import SentenceTransformer
# SECRETS
# All three settings are mandatory: a missing variable raises KeyError while
# the module is being imported, before any UI is built.
WEAVIATE_URL = os.environ["WEAVIATE_URL"]
WEAVIATE_KEY = os.environ["WEAVIATE_KEY"]
# The Groq key is never read into a variable here; it only has to be present
# in the environment (presumably picked up by ChatGroq itself -- confirm).
# An explicit presence check replaces the former self-assignment, which was a
# no-op apart from raising KeyError when the variable was unset.
if "GROQ_API_KEY" not in os.environ:
    raise KeyError("GROQ_API_KEY")
# AUTO-INGEST ON START
def run_ingestion():
    """Run ``ingest.py`` in a child process and block until it finishes.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero
            status (``check=True``).
    """
    print("Running ingestion...")
    # sys.executable is the interpreter running this app, so the child sees
    # the same installed packages; a bare "python" is resolved through PATH
    # and may be a different interpreter or missing altogether.
    subprocess.run([sys.executable, "ingest.py"], check=True)
    print("Ingestion complete!")


# Executed at import time: the app does not start until ingestion succeeds.
run_ingestion()
# RAG CHAIN
@functools.lru_cache(maxsize=1)
def get_rag_chain():
    """Build the retrieval-QA chain once and reuse it for later calls.

    The chain connects to Weaviate with the module-level URL and API key,
    retrieves the top 3 objects of the ``Paper`` class (text stored under the
    ``text`` property) and asks the Groq chat model to answer from that
    context only.

    ``lru_cache`` memoises the successfully built chain; if construction
    raises, nothing is cached and the next call tries again.

    Returns:
        A ``RetrievalQA`` chain configured to return the retrieved documents
        alongside the answer.
    """
    client = weaviate.Client(
        url=WEAVIATE_URL,
        auth_client_secret=weaviate.AuthApiKey(WEAVIATE_KEY),
    )
    # NOTE(review): a raw SentenceTransformer is passed as ``embedding``;
    # LangChain vector stores normally expect an Embeddings implementation.
    # Confirm whether this object is actually used for query vectorisation.
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    vectorstore = Weaviate(
        client,
        "Paper",
        "text",
        embedding=embedder,
        # Request "title" as well so it lands in each document's metadata;
        # assumes the "Paper" class defines a "title" property.
        attributes=["title"],
    )
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
    # NOTE(review): confirm this model id is one Groq currently serves.
    llm = ChatGroq(model="llama-3.1-70b-instruct", temperature=0)
    prompt = PromptTemplate.from_template(
        "Answer using only this context:\n{context}\n\nQuestion: {question}\nAnswer:"
    )
    return RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type_kwargs={"prompt": prompt},
        # Without this flag the chain output has no "source_documents" key.
        return_source_documents=True,
    )
def search(query):
    """Answer *query* with the RAG chain and list the retrieved passages.

    Args:
        query: Question text entered by the user.

    Returns:
        An ``(answer, sources)`` pair of Markdown strings. ``sources`` holds
        one entry per retrieved document: its title in bold followed by up to
        300 characters of its content. A blank query returns a short prompt
        and no sources; any failure returns ``("Error: <message>", "")``.
    """
    # Do not spend a retrieval + LLM round trip on an empty question.
    if not query or not query.strip():
        return "Please enter a question.", ""

    try:
        qa = get_rag_chain()
        result = qa.invoke({"query": query})
        answer = result["result"]

        entries = []
        # The key is only present when the chain is built to return its
        # sources; treat its absence as "no sources" instead of failing.
        for doc in result.get("source_documents", []):
            title = (doc.metadata or {}).get("title", "Untitled")
            body = doc.page_content or ""
            # Mark truncation only when something was actually cut off.
            snippet = body[:300] + ("..." if len(body) > 300 else "")
            entries.append(f"**{title}**\n{snippet}")

        return answer, "\n\n".join(entries)
    except Exception as e:
        # UI boundary: show the failure in the answer pane rather than
        # letting the exception escape into the event handler.
        return f"Error: {str(e)}", ""
# UI
# Single-page layout: a heading, one question box, a button, and two Markdown
# panes that receive the two values returned by ``search``.
with gr.Blocks() as demo:
    gr.Markdown("# ArXiv RAG Search")
    txt = gr.Textbox(label="Ask", placeholder="What is attention?")
    btn = gr.Button("Search")
    out1 = gr.Markdown()
    out2 = gr.Markdown()
    # Clicking the button passes the textbox value to ``search``; its first
    # return value goes to out1 and its second to out2.
    btn.click(fn=search, inputs=txt, outputs=[out1, out2])

demo.launch()