fmegahed committed on
Commit
325c300
·
verified ·
1 Parent(s): 8dfaf6f

Some polishing with Claude's Help

Browse files
Files changed (1) hide show
  1. app.py +232 -67
app.py CHANGED
@@ -2,6 +2,8 @@ import pandas as pd
2
  import matplotlib.pyplot as plt
3
  import gradio as gr
4
  import tempfile
 
 
5
 
6
  from statsforecast import StatsForecast
7
  from statsforecast.models import (
@@ -27,59 +29,122 @@ def load_data(file):
27
  missing_cols = [col for col in required_cols if col not in df.columns]
28
  if missing_cols:
29
  return None, f"Missing required columns: {', '.join(missing_cols)}"
 
30
  df['ds'] = pd.to_datetime(df['ds'])
31
- df = df.sort_values(['unique_id', 'ds'])
 
 
 
 
 
32
  return df, "Data loaded successfully!"
33
  except Exception as e:
34
  return None, f"Error loading data: {str(e)}"
35
 
36
  # Function to generate and return a plot
37
  def create_forecast_plot(forecast_df, original_df, title="Forecasting Results"):
38
- plt.figure(figsize=(10, 6))
39
  unique_ids = forecast_df['unique_id'].unique()
40
  forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff']]
41
-
42
- for unique_id in unique_ids:
 
 
43
  original_data = original_df[original_df['unique_id'] == unique_id]
44
- plt.plot(original_data['ds'], original_data['y'], 'k-', label='Actual')
 
45
  forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
46
- for col in forecast_cols:
47
  if col in forecast_data.columns:
48
- plt.plot(forecast_data['ds'], forecast_data[col], label=col)
 
 
 
 
49
 
50
- plt.title(title)
51
- plt.xlabel('Date')
52
- plt.ylabel('Value')
53
- plt.legend()
54
- plt.grid(True)
 
 
 
55
  fig = plt.gcf()
 
 
 
56
  return fig
57
 
58
  # Function to create a plot for future forecasts
59
  def create_future_forecast_plot(forecast_df, original_df):
60
- plt.figure(figsize=(10, 6))
61
  unique_ids = forecast_df['unique_id'].unique()
62
  forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds']]
63
-
64
- for unique_id in unique_ids:
 
 
65
  # Plot historical data
66
  original_data = original_df[original_df['unique_id'] == unique_id]
67
- plt.plot(original_data['ds'], original_data['y'], 'k-', label='Historical')
68
 
69
- # Plot forecast data
70
  forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
71
- for col in forecast_cols:
 
 
 
 
 
 
72
  if col in forecast_data.columns:
73
- plt.plot(forecast_data['ds'], forecast_data[col], label=col)
 
 
 
 
74
 
75
- plt.title('Future Forecast')
76
- plt.xlabel('Date')
77
- plt.ylabel('Value')
78
- plt.legend()
79
- plt.grid(True)
 
 
 
80
  fig = plt.gcf()
 
 
 
81
  return fig
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  # Main forecasting logic
84
  def run_forecast(
85
  file,
@@ -102,7 +167,7 @@ def run_forecast(
102
  ):
103
  df, message = load_data(file)
104
  if df is None:
105
- return None, None, None, None, None, message
106
 
107
  models = []
108
  model_aliases = []
@@ -130,7 +195,7 @@ def run_forecast(
130
  model_aliases.append('autoarima')
131
 
132
  if not models:
133
- return None, None, None, None, None, "Please select at least one forecasting model"
134
 
135
  sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
136
 
@@ -148,81 +213,130 @@ def run_forecast(
148
  fig_validation = create_forecast_plot(cv_results, df, "Fixed Window Validation Results")
149
 
150
  # Generate future forecasts
151
- future_forecasts = sf.forecast(df = df, h=horizon)
152
  fig_future = create_future_forecast_plot(future_forecasts, df)
153
 
154
- return eval_df, cv_results, fig_validation, future_forecasts, fig_future, "Analysis completed successfully!"
 
 
 
155
 
156
  except Exception as e:
157
- return None, None, None, None, None, f"Error during forecasting: {str(e)}"
158
 
159
  # Sample CSV file generation
160
  def download_sample():
161
  sample_data = """unique_id,ds,y
162
- series1,2023-01-01,100
163
- series1,2023-01-02,105
164
- series1,2023-01-03,102
165
- series1,2023-01-04,107
166
- series1,2023-01-05,104
167
- series1,2023-01-06,110
168
- series1,2023-01-07,108
169
- series1,2023-01-08,112
170
- series1,2023-01-09,115
171
- series1,2023-01-10,118
172
- series1,2023-01-11,120
173
- series1,2023-01-12,123
174
- series1,2023-01-13,126
175
- series1,2023-01-14,129
176
- series1,2023-01-15,131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  """
178
  temp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode='w', newline='')
179
  temp.write(sample_data)
180
  temp.close()
181
  return temp.name
182
 
 
 
 
 
 
 
 
183
  # Gradio interface
184
- with gr.Blocks(title="Extrapolative Forecasts for One or Many Time Series") as app:
185
- gr.Markdown("# 📈 Baselining without Exogenous Variables")
186
  gr.Markdown("Upload a CSV with `unique_id`, `ds`, and `y` columns to apply forecasting models.")
187
 
188
  with gr.Row():
189
  with gr.Column(scale=2):
190
  file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
191
 
192
- download_btn = gr.Button("Download Sample Data")
193
  download_output = gr.File(label="Click to download", visible=True)
194
  download_btn.click(fn=download_sample, outputs=download_output)
195
 
196
  with gr.Accordion("Data & Validation Settings", open=True):
197
- frequency = gr.Dropdown(choices=["H", "D", "WS", "MS", "QS", "YS"], label="Frequency", value="D")
198
- eval_strategy = gr.Radio(choices=["Fixed Window", "Cross Validation"], label="Evaluation Strategy", value="Cross Validation")
199
- horizon = gr.Slider(1, 100, value=10, step=1, label="Validation Horizon")
200
- step_size = gr.Slider(1, 50, value=10, step=1, label="Step Size")
201
- num_windows = gr.Slider(1, 20, value=3, step=1, label="Number of Windows")
202
-
203
- with gr.Accordion("Forecast Settings", open=True):
204
- future_horizon = gr.Slider(1, 100, value=20, step=1, label="Future Forecast Horizon")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
  with gr.Accordion("Model Configuration", open=True):
207
- use_historical_avg = gr.Checkbox(label="Use Historical Average", value=True)
208
- use_naive = gr.Checkbox(label="Use Naive", value=True)
 
 
209
 
 
210
  with gr.Row():
211
- use_seasonal_naive = gr.Checkbox(label="Use Seasonal Naive")
212
- seasonality = gr.Number(label="Seasonality", value=10)
213
 
 
214
  with gr.Row():
215
- use_window_avg = gr.Checkbox(label="Use Window Average")
216
  window_size = gr.Number(label="Window Size", value=3)
217
 
218
  with gr.Row():
219
- use_seasonal_window_avg = gr.Checkbox(label="Use Seasonal Window Average")
220
  seasonal_window_size = gr.Number(label="Seasonal Window Size", value=2)
221
 
222
- use_autoets = gr.Checkbox(label="Use AutoETS")
223
- use_autoarima = gr.Checkbox(label="Use AutoARIMA")
 
 
224
 
225
- submit_btn = gr.Button("Run Forecast", variant="primary")
226
 
227
  with gr.Column(scale=3):
228
  message_output = gr.Textbox(label="Status Message")
@@ -230,13 +344,64 @@ with gr.Blocks(title="Extrapolative Forecasts for One or Many Time Series") as a
230
  with gr.Tabs() as tabs:
231
  with gr.TabItem("Validation Results"):
232
  eval_output = gr.Dataframe(label="Evaluation Metrics")
233
- validation_output = gr.Dataframe(label="Validation Data")
234
  validation_plot = gr.Plot(label="Validation Plot")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
  with gr.TabItem("Future Forecast"):
237
- forecast_output = gr.Dataframe(label="Future Forecast Data")
238
  forecast_plot = gr.Plot(label="Future Forecast Plot")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
 
 
240
  submit_btn.click(
241
  fn=run_forecast,
242
  inputs=[
@@ -245,7 +410,7 @@ with gr.Blocks(title="Extrapolative Forecasts for One or Many Time Series") as a
245
  use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
246
  use_autoets, use_autoarima, future_horizon
247
  ],
248
- outputs=[eval_output, validation_output, validation_plot, forecast_output, forecast_plot, message_output]
249
  )
250
 
251
  if __name__ == "__main__":
 
2
  import matplotlib.pyplot as plt
3
  import gradio as gr
4
  import tempfile
5
+ import os
6
+ from datetime import datetime
7
 
8
  from statsforecast import StatsForecast
9
  from statsforecast.models import (
 
29
  missing_cols = [col for col in required_cols if col not in df.columns]
30
  if missing_cols:
31
  return None, f"Missing required columns: {', '.join(missing_cols)}"
32
+
33
  df['ds'] = pd.to_datetime(df['ds'])
34
+ df = df.sort_values(['unique_id', 'ds']).reset_index(drop=True)
35
+
36
+ # Check for NaN values
37
+ if df['y'].isna().any():
38
+ return None, "Data contains missing values in the 'y' column"
39
+
40
  return df, "Data loaded successfully!"
41
  except Exception as e:
42
  return None, f"Error loading data: {str(e)}"
43
 
44
  # Function to generate and return a plot
45
def create_forecast_plot(forecast_df, original_df, title="Forecasting Results"):
    """Plot observed values and model forecasts for every series.

    Args:
        forecast_df: DataFrame with ``unique_id`` and ``ds`` columns; every
            other column except ``cutoff`` is drawn as a dashed line.
        original_df: DataFrame with ``unique_id``, ``ds`` and ``y`` columns
            holding the observed data, drawn as a solid black line.
        title: Text shown above the axes.

    Returns:
        The matplotlib figure holding the plot.
    """
    # Draw on an explicit figure/axes pair instead of pyplot's implicit
    # "current figure", which is process-wide state.
    fig, ax = plt.subplots(figsize=(12, 7))
    forecast_cols = [
        col for col in forecast_df.columns
        if col not in ('unique_id', 'ds', 'cutoff')
    ]
    colors = plt.cm.tab10.colors

    for unique_id in forecast_df['unique_id'].unique():
        original_data = original_df[original_df['unique_id'] == unique_id]
        ax.plot(original_data['ds'], original_data['y'], 'k-', linewidth=2,
                label=f'{unique_id} (Actual)')

        forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
        for j, col in enumerate(forecast_cols):
            ax.plot(forecast_data['ds'], forecast_data[col],
                    color=colors[j % len(colors)],
                    linestyle='--',
                    linewidth=1.5,
                    label=f'{col}')

    ax.set_title(title, fontsize=16)
    ax.set_xlabel('Date', fontsize=12)
    ax.set_ylabel('Value', fontsize=12)
    ax.grid(True, alpha=0.3)

    # Each column is plotted once per series with the same label and colour;
    # keep a single legend entry per label.
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    if by_label:
        ax.legend(list(by_label.values()), list(by_label.keys()), loc='best')

    # Rotate the date labels first so tight_layout can account for them.
    fig.autofmt_xdate()
    fig.tight_layout()

    return fig
78
 
79
  # Function to create a plot for future forecasts
80
def create_future_forecast_plot(forecast_df, original_df):
    """Plot historical values followed by the future forecasts of every series.

    Args:
        forecast_df: DataFrame with ``unique_id`` and ``ds`` columns; every
            other column is drawn as a dashed line.
        original_df: DataFrame with ``unique_id``, ``ds`` and ``y`` columns
            holding the historical data, drawn as a solid black line.

    Returns:
        The matplotlib figure holding the plot.
    """
    # Draw on an explicit figure/axes pair instead of pyplot's implicit
    # "current figure", which is process-wide state.
    fig, ax = plt.subplots(figsize=(12, 7))
    forecast_cols = [
        col for col in forecast_df.columns
        if col not in ('unique_id', 'ds')
    ]
    colors = plt.cm.tab10.colors

    for unique_id in forecast_df['unique_id'].unique():
        # Historical data
        original_data = original_df[original_df['unique_id'] == unique_id]
        ax.plot(original_data['ds'], original_data['y'], 'k-', linewidth=2,
                label=f'{unique_id} (Historical)')

        forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]

        # Dashed vertical separator at the first forecast timestamp.
        if not forecast_data.empty and not original_data.empty:
            forecast_start = forecast_data['ds'].min()
            ax.axvline(x=forecast_start, color='gray', linestyle='--', alpha=0.5)

        for j, col in enumerate(forecast_cols):
            ax.plot(forecast_data['ds'], forecast_data[col],
                    color=colors[j % len(colors)],
                    linestyle='--',
                    linewidth=1.5,
                    label=f'{col}')

    ax.set_title('Future Forecast', fontsize=16)
    ax.set_xlabel('Date', fontsize=12)
    ax.set_ylabel('Value', fontsize=12)
    ax.grid(True, alpha=0.3)

    # Each column is plotted once per series with the same label and colour;
    # keep a single legend entry per label.
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    if by_label:
        ax.legend(list(by_label.values()), list(by_label.keys()), loc='best')

    # Rotate the date labels first so tight_layout can account for them.
    fig.autofmt_xdate()
    fig.tight_layout()

    return fig
121
 
122
+ # Function to export results to CSV
123
def export_results(eval_df, cv_results, future_forecasts):
    """Write the available result tables to CSV files for download.

    Every argument that is not ``None`` is written, without its index, to a
    timestamped CSV file inside a newly created temporary directory.

    Args:
        eval_df: Evaluation-metrics DataFrame, or ``None`` to skip it.
        cv_results: Validation-results DataFrame, or ``None`` to skip it.
        future_forecasts: Future-forecast DataFrame, or ``None`` to skip it.

    Returns:
        A list with the paths of the files that were written, in the order
        evaluation, validation, forecast. The list is empty when every
        argument is ``None``. A list (not a dict) is returned because the
        value is handed to a multi-file download component, which takes a
        list of file paths.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # mkdtemp always creates a fresh, uniquely named directory, so files from
    # separate runs never overwrite each other.
    temp_dir = tempfile.mkdtemp()

    tables = (
        ("evaluation_metrics", eval_df),
        ("cross_validation_results", cv_results),
        ("forecasts", future_forecasts),
    )

    paths = []
    for stem, table in tables:
        if table is None:
            continue
        path = os.path.join(temp_dir, f"{stem}_{timestamp}.csv")
        table.to_csv(path, index=False)
        paths.append(path)

    return paths
147
+
148
  # Main forecasting logic
149
  def run_forecast(
150
  file,
 
167
  ):
168
  df, message = load_data(file)
169
  if df is None:
170
+ return None, None, None, None, None, None, message
171
 
172
  models = []
173
  model_aliases = []
 
195
  model_aliases.append('autoarima')
196
 
197
  if not models:
198
+ return None, None, None, None, None, None, "Please select at least one forecasting model"
199
 
200
  sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
201
 
 
213
  fig_validation = create_forecast_plot(cv_results, df, "Fixed Window Validation Results")
214
 
215
  # Generate future forecasts
216
+ future_forecasts = sf.forecast(df=df, h=future_horizon)
217
  fig_future = create_future_forecast_plot(future_forecasts, df)
218
 
219
+ # Export results
220
+ export_files = export_results(eval_df, cv_results, future_forecasts)
221
+
222
+ return eval_df, cv_results, fig_validation, future_forecasts, fig_future, export_files, "Analysis completed successfully!"
223
 
224
  except Exception as e:
225
+ return None, None, None, None, None, None, f"Error during forecasting: {str(e)}"
226
 
227
  # Sample CSV file generation
228
def download_sample():
    """Write a two-series sample CSV to a temporary file and return its path.

    The file has the ``unique_id``, ``ds`` and ``y`` columns and 15 daily
    rows for each of ``series1`` and ``series2``. It is created with
    ``delete=False`` so it is still on disk after this function returns.

    Returns:
        Path of the temporary ``.csv`` file.
    """
    sample_data = """unique_id,ds,y
series1,2025-01-01,100
series1,2025-01-02,105
series1,2025-01-03,102
series1,2025-01-04,107
series1,2025-01-05,104
series1,2025-01-06,110
series1,2025-01-07,108
series1,2025-01-08,112
series1,2025-01-09,115
series1,2025-01-10,118
series1,2025-01-11,120
series1,2025-01-12,123
series1,2025-01-13,126
series1,2025-01-14,129
series1,2025-01-15,131
series2,2025-01-01,200
series2,2025-01-02,195
series2,2025-01-03,205
series2,2025-01-04,210
series2,2025-01-05,215
series2,2025-01-06,212
series2,2025-01-07,208
series2,2025-01-08,215
series2,2025-01-09,220
series2,2025-01-10,218
series2,2025-01-11,225
series2,2025-01-12,230
series2,2025-01-13,235
series2,2025-01-14,232
series2,2025-01-15,240
"""
    # The context manager closes the handle even if the write fails;
    # delete=False keeps the file on disk for the download.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode='w', newline='') as temp:
        temp.write(sample_data)
    return temp.name
265
 
266
# Global theme: Gradio's Soft theme with a blue / indigo / gray palette,
# passed to gr.Blocks when the interface is built.
theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="indigo",
    neutral_hue="gray"
)
272
+
273
  # Gradio interface
274
+ with gr.Blocks(title="Time Series Forecasting App", theme=theme) as app:
275
+ gr.Markdown("# 📈 Time Series Forecasting App")
276
  gr.Markdown("Upload a CSV with `unique_id`, `ds`, and `y` columns to apply forecasting models.")
277
 
278
  with gr.Row():
279
  with gr.Column(scale=2):
280
  file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
281
 
282
+ download_btn = gr.Button("Download Sample Data", variant="secondary")
283
  download_output = gr.File(label="Click to download", visible=True)
284
  download_btn.click(fn=download_sample, outputs=download_output)
285
 
286
  with gr.Accordion("Data & Validation Settings", open=True):
287
+ frequency = gr.Dropdown(
288
+ choices=[
289
+ ("Hourly", "H"),
290
+ ("Daily", "D"),
291
+ ("Weekly", "WS"),
292
+ ("Monthly", "MS"),
293
+ ("Quarterly", "QS"),
294
+ ("Yearly", "YS")
295
+ ],
296
+ label="Data Frequency",
297
+ value="D"
298
+ )
299
+
300
+ eval_strategy = gr.Radio(
301
+ choices=["Fixed Window", "Cross Validation"],
302
+ label="Evaluation Strategy",
303
+ value="Cross Validation"
304
+ )
305
+
306
+ with gr.Row():
307
+ horizon = gr.Slider(1, 100, value=10, step=1, label="Validation Horizon")
308
+ future_horizon = gr.Slider(1, 100, value=20, step=1, label="Future Forecast Horizon")
309
+
310
# Cross-validation-only sliders. `visible` takes a plain bool, not a
# callable, so the row uses the default (visible); that matches the
# initial "Cross Validation" evaluation strategy.
with gr.Row():
    step_size = gr.Slider(1, 50, value=10, step=1, label="Step Size")
    num_windows = gr.Slider(1, 20, value=3, step=1, label="Number of Windows")
313
 
314
  with gr.Accordion("Model Configuration", open=True):
315
+ gr.Markdown("### Basic Models")
316
+ with gr.Row():
317
+ use_historical_avg = gr.Checkbox(label="Historical Average", value=True)
318
+ use_naive = gr.Checkbox(label="Naive", value=True)
319
 
320
+ gr.Markdown("### Seasonal Models")
321
  with gr.Row():
322
+ use_seasonal_naive = gr.Checkbox(label="Seasonal Naive")
323
+ seasonality = gr.Number(label="Seasonality Period", value=7)
324
 
325
+ gr.Markdown("### Window-based Models")
326
  with gr.Row():
327
+ use_window_avg = gr.Checkbox(label="Window Average")
328
  window_size = gr.Number(label="Window Size", value=3)
329
 
330
  with gr.Row():
331
+ use_seasonal_window_avg = gr.Checkbox(label="Seasonal Window Average")
332
  seasonal_window_size = gr.Number(label="Seasonal Window Size", value=2)
333
 
334
+ gr.Markdown("### Advanced Models")
335
+ with gr.Row():
336
+ use_autoets = gr.Checkbox(label="AutoETS (Exponential Smoothing)")
337
+ use_autoarima = gr.Checkbox(label="AutoARIMA")
338
 
339
+ submit_btn = gr.Button("Run Forecast", variant="primary", size="lg")
340
 
341
  with gr.Column(scale=3):
342
  message_output = gr.Textbox(label="Status Message")
 
344
  with gr.Tabs() as tabs:
345
  with gr.TabItem("Validation Results"):
346
  eval_output = gr.Dataframe(label="Evaluation Metrics")
 
347
  validation_plot = gr.Plot(label="Validation Plot")
348
+ validation_output = gr.Dataframe(label="Validation Data", visible=False)
349
+
350
+ with gr.Row():
351
+ show_data_btn = gr.Button("Show Validation Data")
352
+ hide_data_btn = gr.Button("Hide Validation Data", visible=False)
353
+
354
+ def show_data():
355
+ return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
356
+
357
+ def hide_data():
358
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
359
+
360
+ show_data_btn.click(
361
+ fn=show_data,
362
+ outputs=[validation_output, hide_data_btn, show_data_btn]
363
+ )
364
+
365
+ hide_data_btn.click(
366
+ fn=hide_data,
367
+ outputs=[validation_output, hide_data_btn, show_data_btn]
368
+ )
369
 
370
  with gr.TabItem("Future Forecast"):
 
371
  forecast_plot = gr.Plot(label="Future Forecast Plot")
372
+ forecast_output = gr.Dataframe(label="Future Forecast Data", visible=False)
373
+
374
+ with gr.Row():
375
+ show_forecast_btn = gr.Button("Show Forecast Data")
376
+ hide_forecast_btn = gr.Button("Hide Forecast Data", visible=False)
377
+
378
+ def show_forecast():
379
+ return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
380
+
381
+ def hide_forecast():
382
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
383
+
384
+ show_forecast_btn.click(
385
+ fn=show_forecast,
386
+ outputs=[forecast_output, hide_forecast_btn, show_forecast_btn]
387
+ )
388
+
389
+ hide_forecast_btn.click(
390
+ fn=hide_forecast,
391
+ outputs=[forecast_output, hide_forecast_btn, show_forecast_btn]
392
+ )
393
+
394
+ with gr.TabItem("Export Results"):
395
+ export_files = gr.Files(label="Download Results")
396
+
397
+ # Update visibility of step_size and num_windows based on eval_strategy
398
def toggle_cv_controls(strategy):
    """Show the step-size and window-count sliders only for cross validation."""
    show = strategy == "Cross Validation"
    return gr.update(visible=show), gr.update(visible=show)

# `outputs` must list component instances; the handler returns one
# gr.update per listed component.
eval_strategy.change(
    fn=toggle_cv_controls,
    inputs=[eval_strategy],
    outputs=[step_size, num_windows]
)
403
 
404
+ # Run forecast when button is clicked
405
  submit_btn.click(
406
  fn=run_forecast,
407
  inputs=[
 
410
  use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
411
  use_autoets, use_autoarima, future_horizon
412
  ],
413
+ outputs=[eval_output, validation_output, validation_plot, forecast_output, forecast_plot, export_files, message_output]
414
  )
415
 
416
  if __name__ == "__main__":