File size: 3,200 Bytes
a2f326b
b3b881b
 
872925a
 
 
 
 
 
a2f326b
 
872925a
a2f326b
 
0e1d30c
 
872925a
 
 
 
a2f326b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3b881b
a2f326b
 
 
 
 
0e1d30c
 
 
 
 
 
b3b881b
a2f326b
b3b881b
 
a2f326b
872925a
0e1d30c
 
 
 
872925a
 
 
 
13fe753
872925a
 
 
0e1d30c
a2f326b
0e1d30c
872925a
 
 
 
 
 
 
0e1d30c
872925a
 
a2f326b
872925a
0e1d30c
 
a2f326b
 
 
 
0e1d30c
a2f326b
 
 
0e1d30c
a2f326b
872925a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# app.py - FINAL, 100% WORKING
import functools
import os
import subprocess
import sys
import time
import traceback

import gradio as gr
import requests
import weaviate
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import Weaviate
# === FIX: Correct Groq import ===
from langchain_groq import ChatGroq
from sentence_transformers import SentenceTransformer

# === SECRETS ===
# Fail fast at import time, naming every missing (or empty) secret at once
# instead of stopping at the first KeyError.
_REQUIRED_SECRETS = ("WEAVIATE_URL", "WEAVIATE_KEY", "GROQ_API_KEY")
_missing_secrets = [name for name in _REQUIRED_SECRETS if not os.environ.get(name)]
if _missing_secrets:
    raise KeyError(
        "Missing required environment variable(s): " + ", ".join(_missing_secrets)
    )

WEAVIATE_URL = os.environ["WEAVIATE_URL"]
WEAVIATE_KEY = os.environ["WEAVIATE_KEY"]
# GROQ_API_KEY is only checked for presence: ChatGroq is constructed without an
# explicit api_key, so it has to pick the key up from the environment itself.

# === WAIT FOR WEAVIATE TO BE READY ===
def wait_for_weaviate(url, key, timeout=120, *, interval=1.0, request_timeout=5.0):
    """Poll Weaviate's readiness endpoint until it answers HTTP 200.

    Args:
        url: Base URL of the Weaviate instance; a trailing slash is tolerated.
        key: API key, sent as a Bearer token.
        timeout: Maximum number of polling attempts (one per ``interval``).
        interval: Seconds to sleep after each unsuccessful attempt.
        request_timeout: Per-request timeout in seconds, so a single hung
            connection cannot stall the loop indefinitely.

    Returns:
        True as soon as the endpoint returns 200, False if no attempt does.
    """
    headers = {"Authorization": f"Bearer {key}"}
    ready_url = f"{url.rstrip('/')}/v1/.well-known/ready"
    print("Waiting for Weaviate to be ready...")
    for attempt in range(1, timeout + 1):
        try:
            response = requests.get(ready_url, headers=headers, timeout=request_timeout)
        except requests.RequestException:
            # Connection refused / DNS failure / timeout while the service is
            # still starting: treat as "not ready yet" and retry. Unlike a bare
            # except, this lets KeyboardInterrupt and SystemExit through.
            pass
        else:
            if response.status_code == 200:
                print("Weaviate is ready!")
                return True
        print(f"Attempt {attempt}/{timeout}... waiting {interval:g}s")
        time.sleep(interval)
    print("Weaviate did not start in time.")
    return False

# === AUTO-INGEST ON START ===
def run_ingestion():
    """Populate Weaviate by running ``ingest.py`` in a child process.

    Waits for Weaviate to become ready first; if it never does, ingestion is
    skipped. The child's stdout is echoed, and a non-zero exit status is
    reported together with its stderr instead of being raised, so the rest of
    the app (the UI) is still built when ingestion fails.
    """
    print("Starting ingestion...")
    if not wait_for_weaviate(WEAVIATE_URL, WEAVIATE_KEY):
        print("Cannot connect to Weaviate. Skipping ingestion.")
        return

    # Run ingest.py with the interpreter executing this app rather than
    # whatever "python" is first on PATH, so it sees the same installed
    # packages. sys.executable can be empty in embedded interpreters, hence
    # the fallback.
    interpreter = sys.executable or "python"
    result = subprocess.run(
        [interpreter, "ingest.py"], capture_output=True, text=True
    )
    print(result.stdout)
    if result.returncode != 0:
        print("Ingestion failed:", result.stderr)
    else:
        print("Ingestion complete!")


# Run once, at import time, before the UI is built.
run_ingestion()

# === RAG CHAIN (built once, reused across requests) ===
@functools.lru_cache(maxsize=1)
def get_rag_chain():
    """Build the retrieval-QA chain (Weaviate retriever + Groq LLM) once.

    Building the chain opens a Weaviate client and loads a SentenceTransformer
    model, which is far too slow to repeat per query, so the result is
    memoized. If construction raises, nothing is cached and the next call
    retries.

    Returns:
        A ``RetrievalQA`` chain. Its ``invoke({"query": ...})`` result holds
        ``"result"`` and, because ``return_source_documents=True``,
        ``"source_documents"``.
    """
    client = weaviate.Client(
        url=WEAVIATE_URL,
        auth_client_secret=weaviate.AuthApiKey(WEAVIATE_KEY),
    )
    # NOTE(review): LangChain's Weaviate store expects a LangChain Embeddings
    # object (embed_query / embed_documents); a raw SentenceTransformer does
    # not provide that interface and appears to go unused only while the store
    # queries by text (its default). Confirm against how ingest.py writes vectors.
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    vectorstore = Weaviate(client, "Paper", "text", embedding=embedder)
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

    # NOTE(review): confirm this id is in Groq's model list (Groq ids look like
    # "llama-3.3-70b-versatile"); the GROQ_MODEL env var overrides it.
    llm = ChatGroq(
        model=os.environ.get("GROQ_MODEL", "llama-3.1-70b-instruct"),
        temperature=0,
    )
    prompt = PromptTemplate.from_template(
        "Answer using only this context:\n{context}\n\nQuestion: {question}\nAnswer:"
    )
    return RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type_kwargs={"prompt": prompt},
        # search() renders the retrieved papers, so the chain must return them;
        # without this flag the result has no "source_documents" key.
        return_source_documents=True,
    )

def _format_sources(docs, max_chars=300):
    """Render retrieved documents as Markdown snippets separated by blank lines.

    Each snippet is the bold ``metadata["title"]`` (or "No Title") followed by
    the first ``max_chars`` characters of the document text; "..." is appended
    only when the text was actually truncated.
    """
    snippets = []
    for doc in docs:
        title = doc.metadata.get("title", "No Title")
        text = doc.page_content
        preview = f"{text[:max_chars]}..." if len(text) > max_chars else text
        snippets.append(f"**{title}**\n{preview}")
    return "\n\n".join(snippets)


def search(query):
    """Answer *query* with the RAG chain and return ``(answer, sources)`` Markdown.

    Blank queries are rejected without contacting any service. This is the
    Gradio callback boundary, so any failure while building or running the
    chain is printed with its traceback and shown in the UI instead of raised.
    """
    if not query or not query.strip():
        return "Please enter a question.", ""
    try:
        qa = get_rag_chain()
        result = qa.invoke({"query": query})
        answer = result["result"]
        # "source_documents" is only present when the chain was built with
        # return_source_documents=True; treat its absence as "no sources".
        sources = _format_sources(result.get("source_documents", []))
        return answer, sources or "No sources found."
    except Exception as e:
        traceback.print_exc()  # the "Check logs." hint points at this output
        return f"Error: {e}", "Check logs."

# === GRADIO UI ===
with gr.Blocks(title="ArXiv RAG") as demo:
    gr.Markdown("# ArXiv RAG Search")
    gr.Markdown("10K+ papers • Llama-3.1 • Weaviate")

    txt = gr.Textbox(label="Ask", placeholder="What is attention?", lines=2)
    btn = gr.Button("Search", variant="primary")
    out1 = gr.Markdown()  # answer text returned by search()
    out2 = gr.Markdown()  # formatted source snippets returned by search()

    # Trigger the same search from the button or by submitting the textbox
    # with its keyboard shortcut.
    btn.click(search, txt, [out1, out2])
    txt.submit(search, txt, [out1, out2])

demo.launch()