File size: 3,717 Bytes
1adbdd7
 
18e2196
1adbdd7
 
 
 
 
 
 
 
 
 
18e2196
e0eeaa3
 
 
 
 
 
 
 
 
1adbdd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc5a4e3
1adbdd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0eeaa3
 
 
 
 
9a1a3d0
e0eeaa3
1adbdd7
 
 
 
 
 
 
 
e0eeaa3
 
1adbdd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import streamlit as st 
from urllib.parse import urlparse

from dotenv import load_dotenv
import os
import openai
from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain, BaseCombineDocumentsChain
from langchain.tools.base import BaseTool
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pydantic import Field
import os, asyncio, trafilatura
from langchain.docstore.document import Document
import requests
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI



# Supported LLM backends for the QA chain.
options = ["openai", "gemini"]
# Dropdown so the user can pick which provider answers queries.
selected_option = st.selectbox("select your llm:", options)

# API key for whichever provider was selected above; read later by run_llm.
api_key = st.text_input("enter your llm api key:")



@st.cache_resource
def get_url_name(url):
    """Return the network-location (host) part of *url*, cached across reruns."""
    return urlparse(url).netloc

def _get_text_splitter():
    """Build the splitter used to chunk fetched pages: 500-char chunks, 20-char overlap."""
    return RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=20,
        length_function=len,
    )

class WebpageQATool(BaseTool):
    """Tool that downloads a webpage and answers a question about its content.

    The page text is split into chunks; the QA chain runs over sliding
    windows of 4 chunks, and the per-window answers are then combined by a
    final QA pass over the concatenated window results.
    """

    name = "query_webpage"
    description = "Browse a webpage and retrieve the information and answers relevant to the question. Please use bullet points to list the answers"
    # Shared splitter: 500-char chunks with 20-char overlap.
    text_splitter: RecursiveCharacterTextSplitter = _get_text_splitter()
    # Combine-documents chain supplied by the caller (see run_llm).
    qa_chain: BaseCombineDocumentsChain

    def _run(self, url: str, question: str) -> str:
        # A timeout prevents the Streamlit app from hanging indefinitely on
        # an unresponsive host (the original call had none).
        response = requests.get(url, timeout=30)
        # Fail loudly on HTTP errors instead of running QA over an error page.
        response.raise_for_status()
        docs = [Document(page_content=response.text, metadata={"source": url})]
        web_docs = self.text_splitter.split_documents(docs)
        results = []
        # First pass: answer the question over sliding windows of 4 chunks.
        for i in range(0, len(web_docs), 4):
            input_docs = web_docs[i:i+4]
            window_result = self.qa_chain({"input_documents": input_docs, "question": question}, return_only_outputs=True)
            results.append(f"Response from window {i} - {window_result}")
        # Second pass: combine the per-window answers into one final answer.
        results_docs = [Document(page_content="\n".join(results), metadata={"source": url})]
        return self.qa_chain({"input_documents": results_docs, "question": question}, return_only_outputs=True)

    async def _arun(self, url: str, question: str) -> str:
        # Async execution is not supported by this tool.
        raise NotImplementedError

def run_llm(url, query):
    """Answer *query* about the page at *url* with the user-selected LLM.

    Reads the module-level `selected_option` and `api_key` widget values to
    decide which backend to construct, then runs the webpage-QA tool.

    Raises:
        ValueError: if `selected_option` is not a known provider (the
            original non-exclusive `if` chain would leave `llm` unbound and
            crash with a NameError instead).
    """
    if selected_option == "openai":
        os.environ["OPENAI_API_KEY"] = api_key
        llm = ChatOpenAI(temperature=0.5)
    elif selected_option == "gemini":
        os.environ["GOOGLE_API_KEY"] = api_key
        # The Google model id uses a hyphen ("gemini-pro"); the original
        # underscore form ("gemini_pro") is not a valid model name.
        llm = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.5)
    else:
        raise ValueError(f"Unsupported LLM provider: {selected_option!r}")
    query_website_tool = WebpageQATool(qa_chain=load_qa_with_sources_chain(llm))
    # Pass the URL and query straight through to the tool.
    return query_website_tool._run(url, query)

# Page header / branding, rendered as raw HTML.
_TITLE_HTML = "<h1 style='text-align: center; color: green;'>Info Retrieval from Website 🦜 </h1>"
_CREDIT_HTML = "<h3 style='text-align: center; color: green;'>Developed by <a href='https://github.com/AIAnytime'>AI Anytime with ❤️ </a></h3>"
_URL_PROMPT_HTML = "<h2 style='text-align: center; color:red;'>Enter the Website URL 👇</h2>"
st.markdown(_TITLE_HTML, unsafe_allow_html=True)
st.markdown(_CREDIT_HTML, unsafe_allow_html=True)
st.markdown(_URL_PROMPT_HTML, unsafe_allow_html=True)



input_url = st.text_input("Enter the URL")

# Only show the query UI once a URL has been entered.
if input_url:
    url_name = get_url_name(input_url)
    st.info("Your URL is: 👇")
    st.write(url_name)

    st.markdown("<h4 style='text-align: center; color:green;'>Enter Your Query 👇</h4>", unsafe_allow_html=True)
    your_query = st.text_area("Query the Website")
    # Run the QA pipeline only when the button is pressed AND a query exists.
    if st.button("Get Answers") and your_query:
        st.info("Your query is: " + your_query)
        st.write(run_llm(input_url, your_query))