krishnasimha's picture
Upload 3 files
7da7d62 verified
raw
history blame
2.57 kB
import streamlit as st
from datasets import load_dataset
from langchain.text_splitter import CharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# Page chrome. set_page_config is called before any other Streamlit element
# is rendered, which is where Streamlit requires it.
st.set_page_config(page_title="Agentic RAG Demo", layout="wide")
# The emoji was stored as mojibake (UTF-8 bytes decoded with the wrong
# codec); restored to the intended bar-chart character.
st.title("📊 Financial Agentic RAG Demo")
@st.cache_resource(show_spinner=True)
def load_pipeline_and_retriever():
    """Build the retrieval-augmented QA chain used by the app.

    Steps, in order:
      1. Load the ``train`` split of the financial phrasebank dataset and
         collect its ``sentence`` column.
      2. Wrap the sentences in documents via ``CharacterTextSplitter``.
      3. Embed the documents with a MiniLM sentence-transformer and index
         them in a Chroma vector store.
      4. Wrap a ``google/flan-t5-small`` text2text pipeline as the LLM.
      5. Combine retriever and LLM into a "stuff" ``RetrievalQA`` chain.

    The function is decorated with ``st.cache_resource`` so Streamlit can
    reuse the returned chain instead of rebuilding it on every script rerun.

    Returns:
        The ``RetrievalQA`` chain produced by ``RetrievalQA.from_chain_type``.
    """
    # Load dataset
    dataset = load_dataset(
        "gtfintechlab/financial_phrasebank_sentences_allagree", "5768"
    )
    texts = [row["sentence"] for row in dataset["train"]]

    # Split texts
    splitter = CharacterTextSplitter(chunk_size=200)
    docs = splitter.create_documents(texts)

    # Embeddings
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    db = Chroma.from_documents(docs, embeddings)
    retriever = db.as_retriever()

    # LLM pipeline
    model_name = "google/flan-t5-small"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=256,
    )
    llm = HuggingFacePipeline(pipeline=pipe)

    # QA chain
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type="stuff",
    )
    return qa_chain
# Module-level QA chain used by the request handler further down. The loader
# is wrapped in st.cache_resource, so this call runs on every script rerun
# but the expensive construction is served from Streamlit's cache.
qa_chain = load_pipeline_and_retriever()
# Tool: simple financial calculator
def finance_calculator(query: str) -> str:
    """Return a canned financial statement chosen by keyword in *query*.

    Matching is case-insensitive substring matching. "growth" is checked
    first, so a query mentioning both keywords gets the growth answer.

    Args:
        query: Free-text user question.

    Returns:
        A fixed growth message, a fixed revenue message, or a fallback
        message when neither keyword is present.
    """
    lowered = query.lower()  # normalise once instead of per branch
    if "growth" in lowered:
        return "Company growth rate estimated at 7.5% YoY."
    if "revenue" in lowered:
        return "Revenue in Q2 increased by 12%."
    return "No relevant financial data found."
# Sidebar: lets the user switch the keyword-based calculator tool on or off.
st.sidebar.header("Tools")
use_calculator = st.sidebar.checkbox("Use Finance Calculator", value=True)

# User query
query = st.text_area("Enter your query:", "")

if st.button("Run Agent"):
    if not query.strip():
        # An empty query used to do nothing silently, and a whitespace-only
        # query was sent to the chain; tell the user what is missing instead.
        st.warning("Please enter a query before running the agent.")
    else:
        with st.spinner("Thinking..."):
            # Step 1: check calculator (empty string means the tool is off)
            calc_result = finance_calculator(query) if use_calculator else ""
            # Step 2: retrieve documents
            # NOTE(review): Chain.run is deprecated in newer LangChain
            # releases in favour of invoke -- confirm against the pinned
            # version before migrating.
            doc_result = qa_chain.run(query)

        # The emoji was stored as mojibake; restored to the light-bulb
        # character.
        st.subheader("💡 Results")
        if calc_result:
            st.markdown(f"**Finance Calculator:** {calc_result}")
        st.markdown(f"**Document Retrieval:** {doc_result}")