|
|
import streamlit as st
|
|
|
from datasets import load_dataset
|
|
|
from langchain.text_splitter import CharacterTextSplitter
|
|
|
from langchain_huggingface import HuggingFaceEmbeddings
|
|
|
from langchain.vectorstores import Chroma
|
|
|
from langchain.chains import RetrievalQA
|
|
|
from langchain.llms import HuggingFacePipeline
|
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
|
|
|
|
|
# Page chrome: browser-tab title and wide layout first, then the on-page heading.
# NOTE(review): the leading "π" in the heading looks like a mis-decoded emoji —
# confirm the intended glyph before changing the string.
st.set_page_config(layout="wide", page_title="Agentic RAG Demo")
st.title("π Financial Agentic RAG Demo")
|
|
|
|
|
|
@st.cache_resource(show_spinner=True)
def load_pipeline_and_retriever():
    """Build the retrieval-QA chain that answers user queries.

    Steps:
      1. Load the financial phrasebank dataset and collect the ``sentence``
         field of every row in its ``train`` split.
      2. Turn the sentences into documents and index them in a Chroma vector
         store using MiniLM sentence embeddings.
      3. Wrap a FLAN-T5 ``text2text-generation`` pipeline as a LangChain LLM.
      4. Combine the LLM and the store's retriever in a ``"stuff"``
         RetrievalQA chain.

    The function is wrapped in ``st.cache_resource`` so the expensive setup
    is not repeated on every Streamlit rerun.

    Returns:
        The RetrievalQA chain produced by ``RetrievalQA.from_chain_type``.
    """
    # Corpus: one string per row of the train split.
    phrasebank = load_dataset(
        "gtfintechlab/financial_phrasebank_sentences_allagree", "5768"
    )
    sentences = [row["sentence"] for row in phrasebank["train"]]

    # Documents + embeddings -> vector store.
    documents = CharacterTextSplitter(chunk_size=200).create_documents(sentences)
    embedder = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    vector_store = Chroma.from_documents(documents, embedder)

    # Local seq2seq model exposed to LangChain through a transformers pipeline.
    checkpoint = "google/flan-t5-small"
    seq2seq_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    generator = pipeline(
        "text2text-generation",
        model=seq2seq_model,
        tokenizer=seq2seq_tokenizer,
        max_length=256,
    )

    return RetrievalQA.from_chain_type(
        llm=HuggingFacePipeline(pipeline=generator),
        retriever=vector_store.as_retriever(),
        chain_type="stuff",
    )
|
|
|
|
|
|
# Module-level QA chain used by the "Run Agent" handler further down the script.
# load_pipeline_and_retriever is decorated with st.cache_resource, so this call
# is expected to reuse the cached chain on Streamlit reruns.
qa_chain = load_pipeline_and_retriever()
|
|
|
|
|
|
|
|
|
def finance_calculator(query: str) -> str:
    """Return a canned financial figure matching a keyword in *query*.

    The match is a case-insensitive substring test. ``"growth"`` is checked
    before ``"revenue"``, so a query containing both gets the growth answer.
    The returned figures are hard-coded strings, not computed values.

    Args:
        query: Free-text user question.

    Returns:
        The growth message, the revenue message, or a "no data" message when
        neither keyword occurs in the query.
    """
    # Normalise once instead of lowercasing for every keyword test.
    lowered = query.lower()

    if "growth" in lowered:
        return "Company growth rate estimated at 7.5% YoY."
    if "revenue" in lowered:
        return "Revenue in Q2 increased by 12%."
    return "No relevant financial data found."
|
|
|
|
|
|
# Sidebar: lets the user switch the keyword-based calculator tool on or off.
st.sidebar.header("Tools")
use_calculator = st.sidebar.checkbox("Use Finance Calculator", value=True)

query = st.text_area("Enter your query:", "")

# Only run when the button was clicked and the query has real content; a
# whitespace-only query would otherwise trigger a pointless retrieval and
# generation pass.
if st.button("Run Agent") and query.strip():
    with st.spinner("Thinking..."):
        # Tool step: canned calculator answer, skipped when the box is unticked.
        calc_result = finance_calculator(query) if use_calculator else ""

        # Retrieval step. Chain.run() is deprecated in LangChain; invoke()
        # takes the input mapping and returns the output mapping, in which
        # RetrievalQA stores the answer under "result".
        doc_result = qa_chain.invoke({"query": query})["result"]

    # The header emoji was mis-decoded in the file ("π‘"); restored to the
    # light-bulb character.
    st.subheader("💡 Results")
    if calc_result:
        st.markdown(f"**Finance Calculator:** {calc_result}")
    st.markdown(f"**Document Retrieval:** {doc_result}")
|
|
|
|
|
|
|
|
|
|
|
|
|