import gradio as gr
from agent import Agent, process_data, process_retrievers
import torch
from markdown_it import MarkdownIt
import json


def init_agent():
    print("init agent ......")
    DATA_DIR = "data_websites"
    PATH_SAVE_CHUNKS = "chunks_saved.json"
    PATH_SAVE_CONTEXT = "chunks_with_context.json"
    PATH_IDX = "index_faiss_data_sh"
    PATH_IDX_CONTEXT = "index_faiss_context_sh"
    PATH_IDX_CONTEXT_AND_WT = "index_faiss_context_and_wt_sh"

    # embedding_model_names = ["Geotrend/distilbert-base-en-fr-cased"]
    embedding_model_names = []
    use_context = False
    reformulation = False
    use_HyDE = False
    use_HyDE_cut = False
    use_context_and_wt = True
    ask_again = False
    TOP_K = 10

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    list_dir, chunks = process_data(DATA_DIR, PATH_SAVE_CHUNKS, PATH_SAVE_CONTEXT,
                                    use_context_and_wt, use_context)
    embedding_models, BM25_retriever = process_retrievers(embedding_model_names, chunks, TOP_K,
                                                          use_context, use_context_and_wt,
                                                          PATH_IDX, PATH_IDX_CONTEXT,
                                                          PATH_IDX_CONTEXT_AND_WT, device)
    agent = Agent(list_dir, chunks, embedding_models, BM25_retriever, TOP_K,
                  reformulation, use_HyDE, use_HyDE_cut, ask_again)
    return agent


def respond(agent, user_input, history, comments, likes):
    reply, sources = agent.get_a_reply(user_input)
    # Appends "Find more information on the university's website page:" followed by the source links.
    bot_response = (str(reply)
                    + "\n\n\nℹ️ *Retrouvez davantage d'informations sur la page du site internet de l'université*:\n"
                    + "\n".join(f"- [{src}]({src})" for src in sources))
    # bot_response = f"Echo: {user_input}"
    history.append(("user", user_input))
    history.append(("bot", bot_response))
    return agent, "", update_html(history, comments, likes), history, comments


def add_comment(comment_input, comments, history, likes):
    # The hidden comment textbox receives strings of the form "<message index> : <comment text>".
    com = comment_input.split(" : ")
    message_index = int(com[0])
    comment_text = com[1]
    if message_index in comments:
        comments[message_index] += [comment_text]
    else:
        comments[message_index] = [comment_text]
    print(data)
    return update_html(history, comments, likes), comments


def handle_feedback(history, comments, feedback, likes):
    # The hidden feedback textbox receives strings of the form "<message index> : <like value>".
    com = feedback.split(" : ")
    message_index = int(com[0])
    like_val = int(com[1])
    if message_index != -1:
        if message_index in likes and likes[message_index] == like_val:
            # Clicking the same button again withdraws the vote.
            # likes[message_index] = 0
            likes.pop(message_index)
        else:
            likes[message_index] = like_val
    print(data)
    feedback = "-1 : 0"
    return update_html(history, comments, likes), feedback, likes


def update_html(history, comments, likes):
    # Inline SVG markup for the like/dislike and send icons (content not reproduced here).
    like_empty = """ """
    unlike_empty = """ """
    like_filled = """ """
    unlike_filled = """ """
    send = """ """

    html = ""
    md = MarkdownIt()
    # NOTE: the wrapper markup used below is a simplified placeholder; the original
    # attributes, classes and inline styles are not reproduced in this file.
    for i, (sender, msg) in reversed(list(enumerate(history))):
        if sender == "user":
            html += f'<div class="user-msg">{md.render(msg)}</div>'
        else:
            html += f'<div class="bot-msg">{md.render(msg)}'
            # Add like/dislike buttons, whose clicks are relayed as
            # "<message index> : <1 or -1>" to the hidden feedback textbox.
            val_like = likes[i] if i in likes else 0
            # val_like = 0
            # if i in likes:
            #     val_like = likes[i]
            like = like_filled if val_like == 1 else like_empty
            unlike = unlike_filled if val_like == -1 else unlike_empty
            html += f"""<div class="feedback">{like}{unlike}</div>"""
            # Add comment box, whose submissions are relayed as
            # "<message index> : <comment text>" to the hidden comment textbox.
            html += f"""<div class="comment-box">{send}</div>"""
            # Add other users' comments
            if i in comments:
                for comment in comments[i]:
                    html += f'<div class="comment">💬 {comment}</div>'
            html += "</div>"
    return f"""<div class="chat">{html}</div>"""


# JS to hide the like/comment relay textboxes.
js = """
document.getElementById("like").style.visibility = "hidden";
document.getElementById("comm").style.visibility = "hidden";
"""

# Extra markup injected into the page <head> via gr.Blocks(head=...).
head = """ """

# CSS for the hidden row of relay textboxes.
css = """
#hidden {{
    visibility: hidden !important;
    background: white !important;
}};
"""

# agent = gr.State(init_agent())
# txtBox.innerHTML='reply '+nb_rep+": "+val;

# Per-session data, keyed by Gradio session hash; appended to log.json when the
# session ends (if the chat history is non-empty).
data = {}


def store_data(request: gr.Request, state, likes, comment_state):
    data[request.session_hash] = {"state": state, "likes": likes, "comment_state": comment_state}


def save_and_clean_data(request: gr.Request):
    # print(data[request.session_hash])
    # global data
    print(data)
    if data[request.session_hash]['state']:
        with open('log.json', 'a') as fp:
            json.dump(data[request.session_hash], fp, ensure_ascii=False, indent=2)
            fp.write(',\n')
    data.pop(request.session_hash, None)


with gr.Blocks(head=head) as demo:
    # Inline SVG markup for the university logo (content not reproduced here).
    UB_logo = """ """
    gr.HTML(""" """)

    # Header banner (placeholder wrapper: the original markup/styling is not reproduced here).
    gr.HTML(f"<div>{UB_logo} Chatbot UBX</div>")

    gr.HTML("""<div>
        Ce chatbot répond à des questions administratives de l'Université de Bordeaux.<br>
        Les réponses seront enregistrées dans le cadre d'une étude d'amélioration de cet outil.
    </div>""")

    gr.HTML("<div>Commencez une discussion</div>")

    # with gr.Row():
    #     with gr.Row(scale=9):
    chat_output = gr.HTML()
    # Placeholder text: "Write your question here ..."
    user_input = gr.Textbox(placeholder="Ecrivez votre question ici ...", show_label=False)

    state = gr.State([])
    likes = gr.State({})
    comment_state = gr.State({})

    demo.load(store_data, [state, likes, comment_state], None)
    demo.unload(save_and_clean_data)

    agent = gr.State(init_agent())
    user_input.submit(respond,
                      [agent, user_input, state, comment_state, likes],
                      [agent, user_input, chat_output, state, comment_state])

    # Hidden components to receive JS feedback via endpoint
    with gr.Row(scale=0, elem_id="hidden"):
        feedback_input = gr.Textbox(elem_id='like')
        feedback_input.change(handle_feedback,
                              [state, comment_state, feedback_input, likes],
                              [chat_output, feedback_input, likes])
        comment_text = gr.Textbox(elem_id='comm')
        comment_text.change(add_comment,
                            [comment_text, comment_state, state, likes],
                            [chat_output, comment_state],
                            show_progress=False)

if __name__ == "__main__":
    demo.launch(share=True)