Vera-ZWY commited on
Commit
4ebe04e
·
verified ·
1 Parent(s): 1a69598

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -41
app.py CHANGED
@@ -97,54 +97,73 @@ def linePlot_time_series(viz_type, weight, top_n):
97
  return df
98
 
99
 
100
- def update_visualization(viz_type, weight, top_n):
101
- """
102
- Update visualization based on user inputs and selected visualization type
103
 
104
- Parameters:
105
- -----------
106
- viz_type : str
107
- Type of visualization to show ('emotions', 'topics', or 'grid')
108
- weight : float
109
- Weight for scoring (0-1)
110
- top_n : int
111
- Number of top items to show
112
- """
113
- try:
114
-
115
- # return None, "Error: Start date must be before end date"
116
- series = linePlot_time_series(viz_type, weight, top_n)
117
- if viz_type == "emotions":
118
- # Create emotion time series
119
- # series = linePlot_time_series(viz_type, weight, top_n)
120
- fig = plot_stacked_time_series(
121
- series,
122
- f'Top {top_n} Emotions Popularity'
123
- )
124
- message = "Emotion time series updated"
125
 
126
- elif viz_type == "topics":
127
- # Create topic time series
128
- # series = linePlot_time_series(viz_type, weight, top_n)
129
- fig = plot_stacked_time_series(
130
- series,
131
- f'Top {top_n} Topics Popularity'
132
- )
133
- message = "Topic time series updated"
134
 
135
- else: # viz_type == "grid"
136
- # Create emotion-topic grid
137
- # pair_series = linePlot_time_series(viz_type, weight, top_n)
138
- fig = plot_emotion_topic_grid(series, top_n)
139
- message = "Emotion-Topic grid updated"
140
 
141
- return fig, message
142
 
143
- except Exception as e:
144
- return None, f"Error: {str(e)}"
145
 
 
 
 
 
 
 
 
146
 
147
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
 
150
 
@@ -276,7 +295,7 @@ with gr.Blocks(title="Reddit Election Analysis") as demo:
276
  )
277
 
278
  linePlot_btn.click(
279
- fn = update_visualization,
280
  inputs = [viz_dropdown,weight_slider,top_n_slider],
281
  outputs = [time_series_fig, linePlot_status_text]
282
  )
 
97
  return df
98
 
99
 
100
+ # def update_visualization(viz_type, weight, top_n):
101
+ # """
102
+ # Update visualization based on user inputs and selected visualization type
103
 
104
+ # Parameters:
105
+ # -----------
106
+ # viz_type : str
107
+ # Type of visualization to show ('emotions', 'topics', or 'grid')
108
+ # weight : float
109
+ # Weight for scoring (0-1)
110
+ # top_n : int
111
+ # Number of top items to show
112
+ # """
113
+ # try:
114
+
115
+ # # return None, "Error: Start date must be before end date"
116
+ # series = linePlot_time_series(viz_type, weight, top_n)
117
+ # if viz_type == "emotions":
118
+ # # Create emotion time series
119
+ # # series = linePlot_time_series(viz_type, weight, top_n)
120
+ # fig = plot_stacked_time_series(
121
+ # series,
122
+ # f'Top {top_n} Emotions Popularity'
123
+ # )
124
+ # message = "Emotion time series updated"
125
 
126
+ # elif viz_type == "topics":
127
+ # # Create topic time series
128
+ # # series = linePlot_time_series(viz_type, weight, top_n)
129
+ # fig = plot_stacked_time_series(
130
+ # series,
131
+ # f'Top {top_n} Topics Popularity'
132
+ # )
133
+ # message = "Topic time series updated"
134
 
135
+ # else: # viz_type == "grid"
136
+ # # Create emotion-topic grid
137
+ # # pair_series = linePlot_time_series(viz_type, weight, top_n)
138
+ # fig = plot_emotion_topic_grid(series, top_n)
139
+ # message = "Emotion-Topic grid updated"
140
 
141
+ # return fig, message
142
 
143
+ # except Exception as e:
144
+ # return None, f"Error: {str(e)}"
145
 
146
+ def decode_plot(plot_base64):
147
+ plot_bytes = base64.b64decode(plot_base64['plot'].split(',')[1])
148
+ img = plt.imread(BytesIO(plot_bytes), format='PNG')
149
+ plt.imshow(img)
150
+ plt.axis('off')
151
+ plt.show()
152
+ return plt.gcf()
153
 
154
 
155
+ def linePlot(viz_type, weight, top_n):
156
+ client = Client("mangoesai/Elections_Comparison_Agent_V4.1")
157
+ result = client.predict(
158
+ viz_type=viz_type,
159
+ weight=weight,
160
+ top_n=top_n,
161
+ api_name="/linePlot_3C1"
162
+ )
163
+
164
+ return decode_plot(result)
165
+
166
+
167
 
168
 
169
 
 
295
  )
296
 
297
  linePlot_btn.click(
298
+ fn = linePlot,
299
  inputs = [viz_dropdown,weight_slider,top_n_slider],
300
  outputs = [time_series_fig, linePlot_status_text]
301
  )