Spaces:
Runtime error
Runtime error
| """Ask a question to the netspresso database.""" | |
| import json | |
| import sys | |
| import argparse | |
| from typing import List | |
| from langchain.chat_models import ChatOpenAI # for `gpt-3.5-turbo` & `gpt-4` | |
| from langchain.chains import RetrievalQAWithSourcesChain | |
| from langchain.prompts import ( | |
| ChatPromptTemplate, | |
| SystemMessagePromptTemplate, | |
| HumanMessagePromptTemplate, | |
| ) | |
| from langchain.schema import BaseRetriever, Document | |
| import gradio as gr | |
| from search_online import OnlineSearcher | |
| # DEFAULT_QUESTION = "모델 경량화 및 최적화와 관련하여 Netspresso bot에게 물어보세요.\n예를들어 \n\n- Why do I need to use Netspresso?\n- Summarize how to compress the model with netspresso.\n- Tell me what the pruning is.\n- What kinds of hardware can I use with this toolkit?\n- Can I use YOLOv8 with this tool? If so, tell me the examples." | |
# Placeholder/help text shown in the gradio question box (a Korean variant
# is kept commented out above).
DEFAULT_QUESTION = "Ask the Netspresso bot about model lightweighting and optimization.\nFor example \n\n- Why do I need to use Netspresso?\n- Summarize how to compress the model with netspresso.\n- Tell me what the pruning is.\n- What kinds of hardware can I use with this toolkit?\n- Can I use YOLOv8 with this tool? If so, tell me the examples."
# Sampling temperature for the OpenAI chat model (0 = deterministic output).
TEMPERATURE = 0
# manual arguments (FIXME)
# NOTE(review): `argparse.Namespace` below is the CLASS, not an instance —
# the parentheses are missing. Every assignment that follows therefore sets a
# CLASS attribute on argparse.Namespace. This is accidentally load-bearing:
# `__main__` later rebinds the global `args` to `parser.parse_args()` (an
# instance that only carries --model_name), and reads such as `args.K` in
# get_relevant_documents() still resolve through the class attributes set
# here. "Fixing" this single line to `argparse.Namespace()` would break those
# lookups after the rebind — any real fix must also stop the retriever from
# reading the global `args`.
args = argparse.Namespace
args.index_type = "hybrid"  # use both sparse (BM25) and dense retrieval
# comma-separated "sparse,dense" index paths consumed by OnlineSearcher
args.index = (
    "/root/indexes/docs-netspresso-ai/sparse,/root/indexes/docs-netspresso-ai/dense"
)
if isinstance(
    args.index, tuple
):  # black extension automatically convert long str to tuple
    assert len(args.index) == 1
    args.index = args.index[0]
args.encoder = "castorini/mdpr-question-nq"  # dense query-encoder checkpoint
args.device = "cuda:0"  # assumes a CUDA GPU is available — TODO confirm
args.alpha = 0.5  # presumably sparse/dense score interpolation weight — verify in OnlineSearcher
args.normalization = True  # presumably score normalization before fusion — verify in OnlineSearcher
args.lang_abbr = "en"
args.K = 10  # number of documents retrieved per query
# initialize qabot
print("initialize NP doc retrieval bot")
# Module-level singleton searcher. The retriever wrapper below uses this
# global directly because, per its TODOs, the BaseRetriever object cannot
# carry extra instance fields.
RETRIEVER = OnlineSearcher(args)
class LangChainCustomRetrieverWrapper(BaseRetriever):
    """LangChain retriever backed by the module-level OnlineSearcher.

    Search state lives in the module globals (RETRIEVER, args) rather than on
    the instance, because the BaseRetriever object cannot carry extra fields
    (see the TODOs below).
    """

    def __init__(self, args):
        super().__init__()
        # self.retriever = RETRIEVER # TODO. should be initialize from args
        # self.args = args
        print("Initialize LangChainCustomRetrieverWrapper, TODO: fix minor bug")

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get texts relevant for a query.
        Args:
            query: string to find relevant texts for
        Returns:
            List of relevant documents
        """
        print(f"query = {query}")

        # retrieve the top-K hits for the query
        # hits = self.retriever.search(query, self.args.K)
        hits = RETRIEVER.search(
            query, args.K
        )  # TODO: fix bug that BaseRetriever object cannot have extra field

        # Wrap each hit in a langchain Document: the raw sparse-index record
        # is JSON whose "contents" field holds the passage text; the docid is
        # kept as the "source" so answers can cite it.
        documents = []
        for hit in hits:
            # self.retriever.searcher.sparse_searcher.doc(hit.docid).raw() # TODO: fix bug that BaseRetriever object cannot have extra field
            raw_record = RETRIEVER.searcher.sparse_searcher.doc(hit.docid).raw()
            passage = json.loads(raw_record)["contents"]
            documents.append(
                Document(page_content=passage, metadata={"source": hit.docid})
            )
        return documents

    async def aget_relevant_documents(
        self, query: str
    ) -> List[Document]:  # abstractmethod
        raise NotImplementedError
class RaLM:
    """Retrieval-augmented LM over the netspresso docs.

    Wires the custom retriever into a RetrievalQAWithSourcesChain driven by
    an OpenAI chat model, and post-processes the chain's answer/sources.
    """

    def __init__(self, args):
        # args must provide `model_name` (OpenAI model id).
        self.args = args
        self.initialize_ralm()

    def initialize_ralm(self):
        """Build the retriever, the chat prompt, and the QA chain."""
        # initialize custom retriever
        self.retriever = LangChainCustomRetrieverWrapper(self.args)

        # prompt for RaLM: the system message carries the retrieved context
        # through the {summaries} slot expected by RetrievalQAWithSourcesChain
        system_template = """Use the following pieces of context to answer the users question.
Take note of the sources and include them in the answer in the format: "SOURCES: source1 source2", use "SOURCES" in capital letters regardless of the number of sources.
Always try to generate answer from source.
----------------
{summaries}"""
        messages = [
            SystemMessagePromptTemplate.from_template(system_template),
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
        prompt = ChatPromptTemplate.from_messages(messages)
        chain_type_kwargs = {"prompt": prompt}
        llm = ChatOpenAI(model_name=self.args.model_name, temperature=TEMPERATURE)
        self.chain = RetrievalQAWithSourcesChain.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True,
            reduce_k_below_max_tokens=True,  # trim retrieved docs to fit the context window
            chain_type_kwargs=chain_type_kwargs,
        )

    def run_chain(self, question, force_korean=False):
        """Run the QA chain on `question` and return the post-processed result.

        Args:
            question: the user question.
            force_korean: if True, append a Korean instruction asking the
                model to answer in Korean based on the retrieved text.
        Returns:
            The chain's result dict with "answer" cleaned up and "sources"
            normalized to a list of docid strings.
        """
        if force_korean:
            question = f"{question} 본문을 참고해서 한글로 대답해줘"
        result = self.chain({"question": question})

        # postprocess: strip stray trailing brackets; split the "a, b" source
        # string the chain emits into a list of docids
        result["answer"] = self.postprocess(result["answer"])
        if isinstance(result["sources"], str):
            sources = self.postprocess(result["sources"])
            result["sources"] = [src.strip() for src in sources.split(", ")]

        # print result
        self.print_result(result)
        return result

    def print_result(self, result):
        """Print a RetrievalQAWithSourcesChain result: the answer, the cited
        source ids, and the retrieved passage matching each source."""
        print(f"Answer: {result['answer']}")
        print("Sources: ")  # plain string: the old f-string had no placeholder
        print(result["sources"])
        assert isinstance(result["sources"], list)
        for source_title in result["sources"]:
            print(f"{source_title}: ")
            # show the first retrieved document whose docid matches the source
            if "source_documents" in result:
                for doc in result["source_documents"]:
                    if doc.metadata["source"] == source_title:
                        print(doc.page_content)
                        break

    def postprocess(self, text):
        """Strip one stray trailing bracket (bug with unknown cause) and
        surrounding whitespace from model output."""
        # endswith accepts a tuple of suffixes — one call instead of four
        if text.endswith((")", "(", "[", "]")):
            text = text[:-1]
        return text.strip()
if __name__ == "__main__":
    # CLI: only the generator model is configurable here; retriever settings
    # are fixed in the module-level `args` namespace above.
    parser = argparse.ArgumentParser(
        description="Ask a question to the netspresso docs."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="gpt-3.5-turbo-16k-0613",
        help="model name for openai api",
    )
    args = parser.parse_args()

    # to prevent collision with DensePhrase native argparser
    sys.argv = [sys.argv[0]]

    # build the retrieval-augmented QA app
    app = RaLM(args)

    def question_answer(question):
        """Gradio callback: return (answer, concatenated source passages)."""
        result = app.run_chain(question=question, force_korean=False)
        separator = "\n######################################################\n\n"
        passages = separator.join(
            f"Source {idx}\n{doc.page_content}"
            for idx, doc in enumerate(result["source_documents"])
        )
        return result["answer"], passages

    # launch gradio UI with one input box and two output boxes
    gr.Interface(
        fn=question_answer,
        inputs=gr.inputs.Textbox(default=DEFAULT_QUESTION, label="Question"),
        outputs=[
            gr.inputs.Textbox(default="", label="Bot response"),
            gr.inputs.Textbox(default="", label="Search result used by bot"),
        ],
        title="Netspresso Q&A bot",
        theme="dark-grass",
        description="Ask the Netspresso bot about model lightweighting and optimization.",  # simplified version, hide detail version
    ).launch(share=True)