Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from gradio_client import Client, handle_file | |
| import seaborn as sns | |
| import matplotlib.pyplot as plt | |
| import os | |
| import pandas as pd | |
| from io import StringIO, BytesIO | |
| import base64 | |
| # from linePlot import plot_stacked_time_series, plot_emotion_topic_grid | |
# Hugging Face access token taken from the environment; evaluates to None when
# the HF_TOKEN variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Remote Space that hosts the analysis endpoints used by every callback below.
_AGENT_SPACE = "mangoesai/Elections_Comparison_Agent_V4.1"

# One shared client; all callbacks send their requests through it.
client = Client(_AGENT_SPACE, hf_token=HF_TOKEN)
| # client_name = ['2016 Election','2024 Election', 'Comparison two years'] | |
def stream_chat_with_rag(
    message: str,
    client_name: str
):
    """Ask the remote agent a question and return its text answer.

    Args:
        message: The user's question, sent as ``query``.
        client_name: The selected election option (e.g. "2016 Election"),
            sent as ``election_year``.

    Returns:
        The first element of the ``/process_query`` response (the answer).
        The second element is printed for debugging and otherwise discarded,
        because this callback is wired to a single text output.
    """
    response = client.predict(
        query=message,
        election_year=client_name,
        api_name="/process_query"
    )
    answer, fig = response

    # Debug logging of both parts of the raw response.
    for heading, payload in (("Raw answer from API:", answer),
                             ("top works from API:", fig)):
        print(heading)
        print(payload)

    return answer
def heatmap(top_n):
    """Fetch the emotion-vs-topic pivot table remotely and render it as a heatmap.

    Args:
        top_n: Number of top items to request. It is forwarded unchanged to the
            ``/get_heatmap_pivot_table`` endpoint and shown in the plot title.

    Returns:
        A matplotlib ``Figure`` suitable for a ``gr.Plot`` output.

    The endpoint's response was observed to be a dict like::

        {'headers': ['Index', 'economy', 'human rights', 'immigrant', 'politics'],
         'data': [['anger', 55880.0, 557679.0, 147766.0, 180094.0],
                  ['disgust', 26911.0, 123112.0, 64567.0, 46460.0],
                  ['fear', 51466.0, 188898.0, 113174.0, 150578.0],
                  ['neutral', 77005.0, 192945.0, 20549.0, 190793.0]],
         'metadata': None}

    The ``'Index'`` column holds the emotion labels (heatmap rows). The remaining
    columns are topics (heatmap columns).
    """
    pivot_table = client.predict(
        top_n=top_n,
        api_name="/get_heatmap_pivot_table"
    )
    print(pivot_table)
    print(type(pivot_table))

    # Convert the dataframe payload into a DataFrame indexed by emotion.
    df = pd.DataFrame(pivot_table['data'], columns=pivot_table['headers'])
    df = df.set_index('Index')

    # Draw on an explicit figure/axes rather than pyplot's implicit "current"
    # figure. Otherwise repeated or concurrent callbacks can draw into each
    # other's figures.
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(df,
                ax=ax,
                cmap='YlOrRd',
                cbar_kws={'label': 'Weighted Frequency'},
                square=True)
    ax.set_title(f'Top {top_n} Emotions vs Topics Weighted Frequency')
    ax.set_xlabel('Topics')
    ax.set_ylabel('Emotions')
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    fig.tight_layout()

    # Release pyplot's reference so figures do not accumulate across clicks in
    # this long-running server. The Figure object can still be rendered by
    # gr.Plot.
    plt.close(fig)
    return fig
| # def linePlot_time_series(viz_type, weight, top_n): | |
| # result = client.predict( | |
| # viz_type=viz_type, | |
| # weight=weight, | |
| # top_n=top_n, | |
| # api_name="/linePlot_time_series" | |
| # ) | |
| # print("============== timeseries df transfer from pivate to public ===============") | |
| # print(result) | |
| # print(type(result)) | |
| # df = pd.DataFrame(result['data'], columns=result['headers']) | |
| # df.set_index('Index', inplace=True) | |
| # return df | |
| # def update_visualization(viz_type, weight, top_n): | |
| # """ | |
| # Update visualization based on user inputs and selected visualization type | |
| # Parameters: | |
| # ----------- | |
| # viz_type : str | |
| # Type of visualization to show ('emotions', 'topics', or 'grid') | |
| # weight : float | |
| # Weight for scoring (0-1) | |
| # top_n : int | |
| # Number of top items to show | |
| # """ | |
| # try: | |
| # # return None, "Error: Start date must be before end date" | |
| # series = linePlot_time_series(viz_type, weight, top_n) | |
| # if viz_type == "emotions": | |
| # # Create emotion time series | |
| # # series = linePlot_time_series(viz_type, weight, top_n) | |
| # fig = plot_stacked_time_series( | |
| # series, | |
| # f'Top {top_n} Emotions Popularity' | |
| # ) | |
| # message = "Emotion time series updated" | |
| # elif viz_type == "topics": | |
| # # Create topic time series | |
| # # series = linePlot_time_series(viz_type, weight, top_n) | |
| # fig = plot_stacked_time_series( | |
| # series, | |
| # f'Top {top_n} Topics Popularity' | |
| # ) | |
| # message = "Topic time series updated" | |
| # else: # viz_type == "grid" | |
| # # Create emotion-topic grid | |
| # # pair_series = linePlot_time_series(viz_type, weight, top_n) | |
| # fig = plot_emotion_topic_grid(series, top_n) | |
| # message = "Emotion-Topic grid updated" | |
| # return fig, message | |
| # except Exception as e: | |
| # return None, f"Error: {str(e)}" | |
def decode_plot(plot_base64):
    """Decode a base64-encoded PNG payload into a new matplotlib Figure.

    Args:
        plot_base64: Mapping whose ``'plot'`` value is the PNG encoded as base64
            text. It may carry a header ending in a comma (presumably a data URL
            such as ``data:image/png;base64,``). Everything up to the first
            comma is dropped; a bare base64 string is used as-is.

    Returns:
        A freshly created matplotlib ``Figure`` showing the image with axes hidden.
    """
    encoded = plot_base64['plot']
    # Base64 text never contains ',', so the piece after the first comma (or the
    # whole string when there is no header) is the payload.
    payload = encoded.split(',', 1)[-1]
    img = plt.imread(BytesIO(base64.b64decode(payload)), format='PNG')

    # Use a dedicated figure. Drawing with plt.imshow on the implicit current
    # figure would paint onto whatever figure was last active (e.g. a heatmap)
    # and stack images across calls.
    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.axis('off')

    # No plt.show(): this runs on a server. Closing only detaches the figure
    # from pyplot's registry so figures do not pile up; it stays renderable.
    plt.close(fig)
    return fig
def linePlot(viz_type, weight, top_n):
    """Fetch a time-series visualization from the remote ``/linePlot_3C1`` endpoint.

    Args:
        viz_type: Visualization type selected in the UI (e.g. "emotions").
        weight: Score-vs-frequency weight from the slider.
        top_n: Number of top items to include.

    Returns:
        A ``(figure, message)`` pair: the decoded plot Figure and the endpoint's
        description string. This matches the two outputs this callback is wired
        to (the plot and the status textbox); returning only the figure makes
        Gradio fail with "didn't return enough output values".
    """
    result = client.predict(
        viz_type=viz_type,
        weight=weight,
        top_n=top_n,
        api_name="/linePlot_3C1"
    )
    # result is a tuple: (dict holding the base64 plot, description message).
    message = result[1] if len(result) > 1 else ""
    return decode_plot(result[0]), message
# Create Gradio interface
with gr.Blocks(title="Reddit Election Analysis") as demo:
    gr.Markdown("# Reddit Public sentiment & Social topic distribution ")
    with gr.Row():
        with gr.Column():
            # A default value is required: without one the dropdown starts
            # empty and "Refresh Heatmap" would send top_n=None to the endpoint.
            top_n = gr.Dropdown(
                choices=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                value=5,
                label="Top N",
            )
            fresh_btn = gr.Button("Refresh Heatmap")
        with gr.Column():
            output_heatmap = gr.Plot(
                label="Top Public sentiment & Social topic Heatmap",
                container=True,  # Ensures the plot is contained within its area
                elem_classes="heatmap-plot"  # Custom class, styled by the CSS below
            )

    gr.Markdown("# Get the time series of the Public sentiment & Social topic")
    with gr.Row():
        with gr.Column(scale=1):
            # Control panel
            # NOTE(review): lineGraph_type is not wired to any event; the
            # callback reads viz_dropdown instead. Confirm whether it should
            # be removed or replace viz_dropdown.
            lineGraph_type = gr.Dropdown(choices=['emotions', 'topics', '2Dmatrix'])
            weight_slider = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                step=0.1,
                label="Weight (Score vs. Frequency)"
            )
            top_n_slider = gr.Slider(
                minimum=2,
                maximum=10,
                value=5,
                step=1,
                label="Top N Items"
            )
            viz_dropdown = gr.Dropdown(
                choices=["emotions", "topics", "grid"],
                value="emotions",
                label="Visualization Type",
                info="Select the type of visualization to display"
            )
            linePlot_btn = gr.Button("Update Visualizations")
            linePlot_status_text = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=3):
            time_series_fig = gr.Plot()

    gr.Markdown("# Reddit Election Posts/Comments Analysis")
    gr.Markdown("Ask questions about election-related comments and posts")
    with gr.Row():
        with gr.Column():
            year_selector = gr.Radio(
                choices=["2016 Election", "2024 Election", "Comparison two years"],
                label="Select Election Year",
                value="2016 Election"
            )
            query_input = gr.Textbox(
                label="Your Question",
                placeholder="Ask about election comments or posts..."
            )
            submit_btn = gr.Button("Submit")
            gr.Markdown("""
            ## Example Questions:
            - Is there any comments don't like the election results
            - Summarize the main discussions about voting process
            - What are the common opinions about candidates?
            """)
        with gr.Column():
            output_text = gr.Textbox(
                label="Response",
                lines=20
            )

    gr.Markdown("## Top works of the relevant Q&A")
    with gr.Row():
        # NOTE(review): output_plot is never used as an event output, so it
        # always stays empty; the chat callback currently discards the plot
        # part of the /process_query response.
        output_plot = gr.Plot(
            label="Topic Distribution",
            container=True,  # Ensures the plot is contained within its area
            elem_classes="topic-plot"  # Custom class, styled by the CSS below
        )

    # Add custom CSS to ensure proper plot sizing
    gr.HTML("""
    <style>
    .topic-plot {
        min-height: 600px;
        width: 100%;
        margin: auto;
    }
    .heatmap-plot {
        min-height: 400px;
        width: 100%;
        margin: auto;
    }
    </style>
    """)

    # Event wiring: each callback's return values must match its outputs.
    fresh_btn.click(
        fn=heatmap,
        inputs=top_n,
        outputs=output_heatmap
    )
    linePlot_btn.click(
        fn=linePlot,
        inputs=[viz_dropdown, weight_slider, top_n_slider],
        outputs=[time_series_fig, linePlot_status_text]
    )
    submit_btn.click(
        fn=stream_chat_with_rag,
        inputs=[query_input, year_selector],
        outputs=output_text
    )

if __name__ == "__main__":
    demo.launch(share=True)