fmegahed committed on
Commit
277fd87
·
verified ·
1 Parent(s): a6c4361

Fixed issues with TimeGPT

Browse files
Files changed (1) hide show
  1. app.py +503 -189
app.py CHANGED
@@ -18,15 +18,341 @@ from statsforecast.models import (
18
 
19
  from utilsforecast.evaluation import evaluate
20
  from utilsforecast.losses import *
21
- from utilsforecast.plotting import plot_series
22
 
 
23
  from nixtla import NixtlaClient
24
 
25
- # Initialize TimeGPT client using Hugging Face secret
26
- nixtla_api_key = os.getenv("NIXTLA_API_KEY")
27
- nixtla_client = NixtlaClient(api_key=nixtla_api_key)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- # Sample Dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def download_sample():
31
  sample_data = """unique_id,ds,y
32
  ^GSPC,2023-01-03,3824.139892578125
@@ -597,223 +923,211 @@ def download_sample():
597
  temp.close()
598
  return temp.name
599
 
600
- # Load and validate user data
601
- def load_data(file):
602
- if file is None:
603
- return None, "Please upload a CSV file"
604
- try:
605
- df = pd.read_csv(file)
606
- required_cols = ['unique_id', 'ds', 'y']
607
- missing = [c for c in required_cols if c not in df.columns]
608
- if missing:
609
- return None, f"Missing required columns: {', '.join(missing)}"
610
- df['ds'] = pd.to_datetime(df['ds'])
611
- df = df.sort_values(['unique_id', 'ds']).reset_index(drop=True)
612
- if df['y'].isna().any():
613
- return None, "Data contains missing values in 'y'"
614
- return df, "Data loaded successfully!"
615
- except Exception as e:
616
- return None, f"Error loading data: {e}"
617
-
618
- # Export results to CSV files for download
619
- def export_results(eval_df, validation_df, future_df):
620
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
621
- temp_dir = tempfile.mkdtemp()
622
- result_files = []
623
- if eval_df is not None:
624
- path = os.path.join(temp_dir, f"evaluation_metrics_{timestamp}.csv")
625
- eval_df.to_csv(path, index=False)
626
- result_files.append(path)
627
- if validation_df is not None:
628
- path = os.path.join(temp_dir, f"validation_results_{timestamp}.csv")
629
- validation_df.to_csv(path, index=False)
630
- result_files.append(path)
631
- if future_df is not None:
632
- path = os.path.join(temp_dir, f"future_forecasts_{timestamp}.csv")
633
- future_df.to_csv(path, index=False)
634
- result_files.append(path)
635
- return result_files
636
-
637
- # Main forecasting logic
638
- def run_forecast(
639
- file,
640
- frequency,
641
- eval_strategy,
642
- horizon,
643
- step_size,
644
- num_windows,
645
- use_historical_avg,
646
- use_naive,
647
- use_seasonal_naive,
648
- seasonality,
649
- use_window_avg,
650
- window_size,
651
- use_seasonal_window_avg,
652
- seasonal_window_size,
653
- use_autoets,
654
- use_autoarima,
655
- use_timegpt,
656
- future_horizon
657
- ):
658
- df, msg = load_data(file)
659
- if df is None:
660
- # return placeholders plus message
661
- return [None]*9 + [msg]
662
-
663
- # Build model list
664
- models = []
665
- if use_historical_avg:
666
- models.append(HistoricAverage())
667
- if use_naive:
668
- models.append(Naive())
669
- if use_seasonal_naive:
670
- models.append(SeasonalNaive(season_length=seasonality))
671
- if use_window_avg:
672
- models.append(WindowAverage(window_size=window_size))
673
- if use_seasonal_window_avg:
674
- models.append(SeasonalWindowAverage(season_length=seasonal_window_size))
675
- if use_autoets:
676
- models.append(AutoETS(season_length=seasonality))
677
- if use_autoarima:
678
- models.append(AutoARIMA(season_length=seasonality))
679
-
680
- if not models and not use_timegpt:
681
- return [None]*9 + ["Please select at least one forecasting model"]
682
-
683
- # StatsForecast run
684
- sf = StatsForecast(models=models, freq=frequency, n_jobs=-1) if models else None
685
-
686
- # Cross validation or fixed-window evaluation
687
- validation_df, fig_val = None, None
688
- if sf is not None:
689
- if eval_strategy == "Cross Validation":
690
- validation_df = sf.cross_validation(
691
- df=df, h=horizon, step_size=step_size, periods=num_windows
692
- )
693
- else:
694
- # Fixed window splits
695
- cutoff = df['ds'].max() - pd.to_timedelta(horizon, unit=frequency)
696
- validation_df = sf.forecast(df[df['ds'] <= cutoff], h=horizon)
697
- eval_df = evaluate(validation_df)
698
- fig_val = plot_series(df=df, forecast_df=validation_df, title="Validation Results")
699
- else:
700
- eval_df = None
701
-
702
- # Future forecast with StatsForecast
703
- future_df, fig_future = None, None
704
- if sf is not None:
705
- future_df = sf.forecast(df=df, h=future_horizon)
706
- fig_future = plot_series(df=df, forecast_df=future_df, title="Future Forecast")
707
-
708
- # TimeGPT / Transformer forecast
709
- tg_df, fig_tg = None, None
710
- if use_timegpt:
711
- tdf = df[['unique_id', 'ds', 'y']]
712
- tg_df = nixtla_client.forecast(
713
- df=tdf, h=future_horizon, freq=frequency, level=[95]
714
- )
715
- fig_tg = nixtla_client.plot(
716
- df=tdf, forecasts_df=tg_df, level=[95]
717
- )
718
-
719
- # Export all results
720
- files = export_results(
721
- eval_df if sf is not None else None,
722
- validation_df,
723
- future_df
724
- )
725
-
726
- return (
727
- eval_df,
728
- validation_df,
729
- fig_val,
730
- future_df,
731
- fig_future,
732
- tg_df,
733
- fig_tg,
734
- files,
735
- "Analysis completed successfully!"
736
- )
737
 
738
- # Build Gradio interface
739
- theme = None # adjust or import your theme
740
  with gr.Blocks(title="Time Series Forecasting App", theme=theme) as app:
741
  gr.Markdown("# 📈 Time Series Forecasting App")
742
- gr.Markdown(
743
- "> **Disclaimer:** For simplicity, external predictors (covariates) are not supported in this app. "
744
- "However, they are supported by the `AutoARIMA()` model in *statsforecast* and the `TimeGPT` model from *nixtla*.")
 
 
 
 
 
 
 
745
 
746
  with gr.Row():
747
  with gr.Column(scale=2):
748
  file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
 
749
  download_btn = gr.Button("Download Sample Data", variant="secondary")
750
- download_output = gr.File(label="Download Sample", visible=True)
751
  download_btn.click(fn=download_sample, outputs=download_output)
752
 
753
  with gr.Accordion("Data & Validation Settings", open=True):
754
  frequency = gr.Dropdown(
755
- choices=["D", "W", "M", "H"], value="D", label="Frequency"
 
 
 
 
 
 
 
 
 
756
  )
 
 
757
  eval_strategy = gr.Radio(
758
- choices=["Cross Validation", "Fixed Window"],
759
- value="Cross Validation",
760
- label="Validation Strategy"
761
  )
762
- horizon = gr.Slider(1, 100, value=10, label="CV Horizon")
763
- step_size = gr.Slider(1, 100, value=1, label="CV Step Size")
764
- num_windows = gr.Number(value=3, label="Number of CV Windows")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
765
 
766
- with gr.Accordion("Models", open=True):
767
- use_historical_avg = gr.Checkbox(label="Historical Average", value=True)
768
- use_naive = gr.Checkbox(label="Naive", value=False)
769
- use_seasonal_naive = gr.Checkbox(label="Seasonal Naive", value=False)
770
- seasonality = gr.Number(value=12, label="Seasonality")
771
- use_window_avg = gr.Checkbox(label="Window Average", value=False)
772
- window_size = gr.Number(value=3, label="Window Size")
773
- use_seasonal_window_avg = gr.Checkbox(label="Seasonal Window Average", value=False)
774
- seasonal_window_size = gr.Number(value=12, label="Seasonal Window Size")
775
- use_autoets = gr.Checkbox(label="AutoETS", value=False)
776
- use_autoarima = gr.Checkbox(label="AutoARIMA", value=False)
777
- gr.Markdown("### Transformer Models")
778
- use_timegpt = gr.Checkbox(label="TimeGPT (Transformer)", value=False)
779
-
780
- future_horizon = gr.Slider(1, 100, value=12, label="Future Forecast Horizon")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
781
 
782
  with gr.Column(scale=3):
783
- eval_output = gr.Dataframe(label="Evaluation Metrics")
784
- with gr.Tabs():
 
785
  with gr.TabItem("Validation Results"):
786
- validation_output = gr.Dataframe(label="Validation Data")
787
  validation_plot = gr.Plot(label="Validation Plot")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
788
  with gr.TabItem("Future Forecast"):
789
- forecast_output = gr.Dataframe(label="Future Forecast Data")
790
  forecast_plot = gr.Plot(label="Future Forecast Plot")
791
- with gr.TabItem("Transformer Forecast"):
792
- tg_output = gr.Dataframe(label="TimeGPT Forecast Data")
793
- tg_plot = gr.Plot(label="TimeGPT Forecast Plot")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
794
  with gr.TabItem("Export Results"):
795
  export_files = gr.Files(label="Download Results")
796
 
797
- message_output = gr.Markdown()
 
 
 
 
 
 
 
 
 
 
 
 
798
 
799
- submit_btn = gr.Button("Run Validation and Forecast", variant="primary", size="lg")
800
  submit_btn.click(
801
  fn=run_forecast,
802
  inputs=[
803
  file_input, frequency, eval_strategy, horizon, step_size, num_windows,
804
  use_historical_avg, use_naive, use_seasonal_naive, seasonality,
805
  use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
806
- use_autoets, use_autoarima, use_timegpt, future_horizon
 
807
  ],
808
- outputs=[
809
- eval_output,
810
- validation_output, validation_plot,
811
- forecast_output, forecast_plot,
812
- tg_output, tg_plot,
813
- export_files,
814
- message_output
815
- ]
816
  )
817
 
818
  if __name__ == "__main__":
819
- app.launch(share=False)
 
18
 
19
  from utilsforecast.evaluation import evaluate
20
  from utilsforecast.losses import *
 
21
 
22
+ # Import for TimeGPT
23
  from nixtla import NixtlaClient
24
 
25
+ # Function to load and process uploaded CSV
26
+ def load_data(file):
27
+ if file is None:
28
+ return None, "Please upload a CSV file"
29
+ try:
30
+ df = pd.read_csv(file)
31
+ required_cols = ['unique_id', 'ds', 'y']
32
+ missing_cols = [col for col in required_cols if col not in df.columns]
33
+ if missing_cols:
34
+ return None, f"Missing required columns: {', '.join(missing_cols)}"
35
+
36
+ df['ds'] = pd.to_datetime(df['ds'])
37
+ df = df.sort_values(['unique_id', 'ds']).reset_index(drop=True)
38
+
39
+ # Check for NaN values
40
+ if df['y'].isna().any():
41
+ return None, "Data contains missing values in the 'y' column"
42
+
43
+ return df, "Data loaded successfully!"
44
+ except Exception as e:
45
+ return None, f"Error loading data: {str(e)}"
46
+
47
+ # Function to generate and return a plot
48
+ def create_forecast_plot(forecast_df, original_df, title="Forecasting Results"):
49
+ plt.figure(figsize=(12, 7))
50
+ unique_ids = forecast_df['unique_id'].unique()
51
+ forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff']]
52
+
53
+ colors = plt.cm.tab10.colors
54
+
55
+ for i, unique_id in enumerate(unique_ids):
56
+ original_data = original_df[original_df['unique_id'] == unique_id]
57
+ plt.plot(original_data['ds'], original_data['y'], 'k-', linewidth=2, label=f'{unique_id} (Actual)')
58
+
59
+ forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
60
+ for j, col in enumerate(forecast_cols):
61
+ if col in forecast_data.columns:
62
+ plt.plot(forecast_data['ds'], forecast_data[col],
63
+ color=colors[j % len(colors)],
64
+ linestyle='--',
65
+ linewidth=1.5,
66
+ label=f'{col}')
67
+
68
+ plt.title(title, fontsize=16)
69
+ plt.xlabel('Date', fontsize=12)
70
+ plt.ylabel('Value', fontsize=12)
71
+ plt.grid(True, alpha=0.3)
72
+ plt.legend(loc='best')
73
+ plt.tight_layout()
74
+
75
+ # Format date labels better
76
+ fig = plt.gcf()
77
+ ax = plt.gca()
78
+ fig.autofmt_xdate()
79
+
80
+ return fig
81
 
82
+ # Function to create a plot for future forecasts
83
+ def create_future_forecast_plot(forecast_df, original_df):
84
+ plt.figure(figsize=(12, 7))
85
+ unique_ids = forecast_df['unique_id'].unique()
86
+ forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds']]
87
+
88
+ colors = plt.cm.tab10.colors
89
+
90
+ for i, unique_id in enumerate(unique_ids):
91
+ # Plot historical data
92
+ original_data = original_df[original_df['unique_id'] == unique_id]
93
+ plt.plot(original_data['ds'], original_data['y'], 'k-', linewidth=2, label=f'{unique_id} (Historical)')
94
+
95
+ # Plot forecast data with shaded vertical line separator
96
+ forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
97
+
98
+ # Add vertical line at the forecast start
99
+ if not forecast_data.empty and not original_data.empty:
100
+ forecast_start = forecast_data['ds'].min()
101
+ plt.axvline(x=forecast_start, color='gray', linestyle='--', alpha=0.5)
102
+
103
+ for j, col in enumerate(forecast_cols):
104
+ if col in forecast_data.columns:
105
+ plt.plot(forecast_data['ds'], forecast_data[col],
106
+ color=colors[j % len(colors)],
107
+ linestyle='--',
108
+ linewidth=1.5,
109
+ label=f'{col}')
110
+
111
+ plt.title('Future Forecast', fontsize=16)
112
+ plt.xlabel('Date', fontsize=12)
113
+ plt.ylabel('Value', fontsize=12)
114
+ plt.grid(True, alpha=0.3)
115
+ plt.legend(loc='best')
116
+ plt.tight_layout()
117
+
118
+ # Format date labels better
119
+ fig = plt.gcf()
120
+ ax = plt.gca()
121
+ fig.autofmt_xdate()
122
+
123
+ return fig
124
+
125
+ # Function to export results to CSV
126
+ def export_results(eval_df, cv_results, future_forecasts):
127
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
128
+
129
+ # Create temp directory if it doesn't exist
130
+ temp_dir = tempfile.mkdtemp()
131
+
132
+ result_files = []
133
+
134
+ if eval_df is not None:
135
+ eval_path = os.path.join(temp_dir, f"evaluation_metrics_{timestamp}.csv")
136
+ eval_df.to_csv(eval_path, index=False)
137
+ result_files.append(eval_path)
138
+
139
+ if cv_results is not None:
140
+ cv_path = os.path.join(temp_dir, f"cross_validation_results_{timestamp}.csv")
141
+ cv_results.to_csv(cv_path, index=False)
142
+ result_files.append(cv_path)
143
+
144
+ if future_forecasts is not None:
145
+ forecast_path = os.path.join(temp_dir, f"forecasts_{timestamp}.csv")
146
+ future_forecasts.to_csv(forecast_path, index=False)
147
+ result_files.append(forecast_path)
148
+
149
+ return result_files
150
+
151
+ # Main forecasting logic
152
+ def run_forecast(
153
+ file,
154
+ frequency,
155
+ eval_strategy,
156
+ horizon,
157
+ step_size,
158
+ num_windows,
159
+ use_historical_avg,
160
+ use_naive,
161
+ use_seasonal_naive,
162
+ seasonality,
163
+ use_window_avg,
164
+ window_size,
165
+ use_seasonal_window_avg,
166
+ seasonal_window_size,
167
+ use_autoets,
168
+ use_autoarima,
169
+ use_timegpt,
170
+ finetune_loss,
171
+ confidence_level,
172
+ future_horizon
173
+ ):
174
+ df, message = load_data(file)
175
+ if df is None:
176
+ return None, None, None, None, None, None, message
177
+
178
+ # Initialize results
179
+ eval_df = None
180
+ cv_results = None
181
+ future_forecasts = None
182
+
183
+ # Set up traditional statistical models
184
+ models = []
185
+ model_aliases = []
186
+
187
+ if use_historical_avg:
188
+ models.append(HistoricAverage(alias='historical_average'))
189
+ model_aliases.append('historical_average')
190
+ if use_naive:
191
+ models.append(Naive(alias='naive'))
192
+ model_aliases.append('naive')
193
+ if use_seasonal_naive:
194
+ models.append(SeasonalNaive(season_length=seasonality, alias='seasonal_naive'))
195
+ model_aliases.append('seasonal_naive')
196
+ if use_window_avg:
197
+ models.append(WindowAverage(window_size=window_size, alias='window_average'))
198
+ model_aliases.append('window_average')
199
+ if use_seasonal_window_avg:
200
+ models.append(SeasonalWindowAverage(season_length=seasonality, window_size=seasonal_window_size, alias='seasonal_window_average'))
201
+ model_aliases.append('seasonal_window_average')
202
+ if use_autoets:
203
+ models.append(AutoETS(alias='autoets', season_length=seasonality))
204
+ model_aliases.append('autoets')
205
+ if use_autoarima:
206
+ models.append(AutoARIMA(alias='autoarima', season_length=seasonality))
207
+ model_aliases.append('autoarima')
208
+
209
+ if not models and not use_timegpt:
210
+ return None, None, None, None, None, None, "Please select at least one forecasting model"
211
+
212
+ try:
213
+ # Initialize results with empty DataFrames
214
+ combined_eval_df = pd.DataFrame()
215
+ combined_cv_results = pd.DataFrame()
216
+ combined_future_forecasts = pd.DataFrame()
217
+
218
+ # Run traditional statistical models if any are selected
219
+ if models:
220
+ sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
221
+
222
+ # Run cross-validation for traditional models
223
+ if eval_strategy == "Cross Validation":
224
+ cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
225
+ evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
226
+ eval_df = pd.DataFrame(evaluation).reset_index()
227
+ else: # Fixed window
228
+ cv_results = sf.cross_validation(df=df, h=horizon, step_size=10, n_windows=1) # any step size for 1 window
229
+ evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
230
+ eval_df = pd.DataFrame(evaluation).reset_index()
231
+
232
+ # Generate future forecasts
233
+ future_forecasts = sf.forecast(df=df, h=future_horizon)
234
+
235
+ # Store results
236
+ combined_eval_df = eval_df.copy() if eval_df is not None else pd.DataFrame()
237
+ combined_cv_results = cv_results.copy() if cv_results is not None else pd.DataFrame()
238
+ combined_future_forecasts = future_forecasts.copy() if future_forecasts is not None else pd.DataFrame()
239
+
240
+ # Run TimeGPT if selected
241
+ if use_timegpt:
242
+ try:
243
+ # Get API key from environment variables
244
+ nixtla_api_key = os.getenv("NIXTLA_API_KEY")
245
+ if not nixtla_api_key:
246
+ return None, None, None, None, None, None, "TimeGPT API key not found. Please set the NIXTLA_API_KEY environment variable."
247
+
248
+ # Initialize Nixtla client
249
+ nixtla_client = NixtlaClient(api_key=nixtla_api_key)
250
+
251
+ # Convert confidence_level to list format
252
+ level = [float(confidence_level)]
253
+
254
+ # Run cross-validation for TimeGPT
255
+ if eval_strategy == "Cross Validation":
256
+ timegpt_cv_df = nixtla_client.cross_validation(
257
+ df=df,
258
+ h=horizon,
259
+ freq=frequency,
260
+ level=level,
261
+ n_windows=num_windows,
262
+ step_size=step_size
263
+ )
264
+ timegpt_cv_eval = evaluate(
265
+ df=timegpt_cv_df,
266
+ metrics=[mape, mae, rmse, bias],
267
+ models=['TimeGPT'],
268
+ level=level
269
+ )
270
+ timegpt_eval_df = pd.DataFrame(timegpt_cv_eval).reset_index()
271
+ else: # Fixed window
272
+ timegpt_cv_df = nixtla_client.cross_validation(
273
+ df=df,
274
+ h=horizon,
275
+ freq=frequency,
276
+ level=level,
277
+ n_windows=1,
278
+ step_size=10
279
+ )
280
+ timegpt_cv_eval = evaluate(
281
+ df=timegpt_cv_df,
282
+ metrics=[mape, mae, rmse, bias],
283
+ models=['TimeGPT'],
284
+ level=level
285
+ )
286
+ timegpt_eval_df = pd.DataFrame(timegpt_cv_eval).reset_index()
287
+
288
+ # Generate future forecasts with TimeGPT
289
+ forecast_timegpt = nixtla_client.forecast(
290
+ df=df,
291
+ h=future_horizon,
292
+ freq=frequency,
293
+ level=level,
294
+ finetune_loss=finetune_loss
295
+ )
296
+
297
+ # Combine results
298
+ if not combined_eval_df.empty and not timegpt_eval_df.empty:
299
+ combined_eval_df = pd.concat([combined_eval_df, timegpt_eval_df], ignore_index=True)
300
+ else:
301
+ combined_eval_df = timegpt_eval_df if not timegpt_eval_df.empty else combined_eval_df
302
+
303
+ if not combined_cv_results.empty and not timegpt_cv_df.empty:
304
+ # Make sure we're not duplicating the 'y' column
305
+ if 'y' in combined_cv_results.columns and 'y' in timegpt_cv_df.columns:
306
+ timegpt_cv_df_no_y = timegpt_cv_df.drop(columns=['y'])
307
+ combined_cv_results = pd.merge(
308
+ combined_cv_results,
309
+ timegpt_cv_df_no_y,
310
+ on=['unique_id', 'ds', 'cutoff'],
311
+ how='outer'
312
+ )
313
+ else:
314
+ combined_cv_results = pd.concat([combined_cv_results, timegpt_cv_df], ignore_index=True)
315
+ else:
316
+ combined_cv_results = timegpt_cv_df if not timegpt_cv_df.empty else combined_cv_results
317
+
318
+ if not combined_future_forecasts.empty and not forecast_timegpt.empty:
319
+ # Make sure we're merging on common columns
320
+ combined_future_forecasts = pd.merge(
321
+ combined_future_forecasts,
322
+ forecast_timegpt,
323
+ on=['unique_id', 'ds'],
324
+ how='outer'
325
+ )
326
+ else:
327
+ combined_future_forecasts = forecast_timegpt if not forecast_timegpt.empty else combined_future_forecasts
328
+
329
+ except Exception as e:
330
+ return None, None, None, None, None, None, f"Error with TimeGPT: {str(e)}"
331
+
332
+ # Create plots
333
+ if not combined_cv_results.empty:
334
+ fig_validation = create_forecast_plot(
335
+ combined_cv_results,
336
+ df,
337
+ f"{eval_strategy} Results"
338
+ )
339
+ else:
340
+ fig_validation = None
341
+
342
+ if not combined_future_forecasts.empty:
343
+ fig_future = create_future_forecast_plot(combined_future_forecasts, df)
344
+ else:
345
+ fig_future = None
346
+
347
+ # Export results
348
+ export_files = export_results(combined_eval_df, combined_cv_results, combined_future_forecasts)
349
+
350
+ return combined_eval_df, combined_cv_results, fig_validation, combined_future_forecasts, fig_future, export_files, "Analysis completed successfully!"
351
+
352
+ except Exception as e:
353
+ return None, None, None, None, None, None, f"Error during forecasting: {str(e)}"
354
+
355
+ # Sample CSV file generation
356
  def download_sample():
357
  sample_data = """unique_id,ds,y
358
  ^GSPC,2023-01-03,3824.139892578125
 
923
  temp.close()
924
  return temp.name
925
 
926
+ # Global theme
927
+ theme = gr.themes.Soft(
928
+ primary_hue="blue",
929
+ secondary_hue="indigo",
930
+ neutral_hue="gray"
931
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
932
 
933
+ # Gradio interface
 
934
  with gr.Blocks(title="Time Series Forecasting App", theme=theme) as app:
935
  gr.Markdown("# 📈 Time Series Forecasting App")
936
+ gr.Markdown("Upload a CSV with `unique_id`, `ds`, and `y` columns to apply forecasting models.")
937
+
938
+ # Disclaimer about external predictors
939
+ with gr.Accordion("Disclaimer", open=True):
940
+ gr.Markdown("""
941
+ **Disclaimer:** For simplicity, this app does not allow the use of external predictors.
942
+ However, they can be easily included in the underlying statsforecast (for AutoARIMA)
943
+ and the TimeGPT implementation by Nixtla. To use external predictors, you would need to modify
944
+ the code to include them in your forecasting models.
945
+ """)
946
 
947
  with gr.Row():
948
  with gr.Column(scale=2):
949
  file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
950
+
951
  download_btn = gr.Button("Download Sample Data", variant="secondary")
952
+ download_output = gr.File(label="Click to download", visible=True)
953
  download_btn.click(fn=download_sample, outputs=download_output)
954
 
955
  with gr.Accordion("Data & Validation Settings", open=True):
956
  frequency = gr.Dropdown(
957
+ choices=[
958
+ ("Hourly", "H"),
959
+ ("Daily", "D"),
960
+ ("Weekly", "WS"),
961
+ ("Monthly", "MS"),
962
+ ("Quarterly", "QS"),
963
+ ("Yearly", "YS")
964
+ ],
965
+ label="Data Frequency",
966
+ value="D"
967
  )
968
+
969
+ # Evaluation Strategy
970
  eval_strategy = gr.Radio(
971
+ choices=["Fixed Window", "Cross Validation"],
972
+ label="Evaluation Strategy",
973
+ value="Cross Validation"
974
  )
975
+
976
+ # Fixed Window settings
977
+ with gr.Group(visible=True) as fixed_window_box:
978
+ gr.Markdown("### Fixed Window Settings")
979
+ horizon = gr.Slider(1, 100, value=10, step=1, label="Validation Horizon (steps ahead to predict)")
980
+
981
+ # Cross Validation settings
982
+ with gr.Group(visible=True) as cv_box:
983
+ gr.Markdown("### Cross Validation Settings")
984
+ with gr.Row():
985
+ step_size = gr.Slider(1, 50, value=10, step=1, label="Step Size (distance between windows)")
986
+ num_windows = gr.Slider(1, 20, value=5, step=1, label="Number of Windows")
987
+
988
+ # Future forecast settings (always visible)
989
+ with gr.Group():
990
+ gr.Markdown("### Future Forecast Settings")
991
+ future_horizon = gr.Slider(1, 100, value=10, step=1, label="Future Forecast Horizon (steps to predict)")
992
 
993
+ with gr.Accordion("Model Configuration", open=True):
994
+ with gr.Tabs() as model_tabs:
995
+ # Traditional Statistical Models Tab
996
+ with gr.TabItem("Statistical Models"):
997
+ gr.Markdown("## Basic Models")
998
+ with gr.Row():
999
+ use_historical_avg = gr.Checkbox(label="Historical Average", value=True)
1000
+ use_naive = gr.Checkbox(label="Naive", value=True)
1001
+
1002
+ # Common seasonality parameter at the top level
1003
+ with gr.Group():
1004
+ gr.Markdown("### Seasonality Configuration")
1005
+ gr.Markdown("This seasonality period affects Seasonal Naive, Seasonal Window Average, AutoETS, and AutoARIMA models")
1006
+ seasonality = gr.Number(label="Seasonality Period", value=5)
1007
+
1008
+ gr.Markdown("### Seasonal Models")
1009
+ with gr.Row():
1010
+ use_seasonal_naive = gr.Checkbox(label="Seasonal Naive", value=True)
1011
+
1012
+ gr.Markdown("### Window-based Models")
1013
+ with gr.Row():
1014
+ use_window_avg = gr.Checkbox(label="Window Average", value=True)
1015
+ window_size = gr.Number(label="Window Size", value=10)
1016
+
1017
+ with gr.Row():
1018
+ use_seasonal_window_avg = gr.Checkbox(label="Seasonal Window Average", value=True)
1019
+ seasonal_window_size = gr.Number(label="Seasonal Window Size", value=2)
1020
+
1021
+ gr.Markdown("### Advanced Models (use seasonality from above)")
1022
+ with gr.Row():
1023
+ use_autoets = gr.Checkbox(label="AutoETS (Exponential Smoothing)", value=True)
1024
+ use_autoarima = gr.Checkbox(label="AutoARIMA", value=True)
1025
+
1026
+ # Transformer Models Tab (TimeGPT)
1027
+ with gr.TabItem("Transformer Models"):
1028
+ gr.Markdown("## TimeGPT Model")
1029
+ gr.Markdown("TimeGPT uses a transformer architecture for state-of-the-art time series forecasting")
1030
+
1031
+ with gr.Row():
1032
+ use_timegpt = gr.Checkbox(label="Use TimeGPT", value=False)
1033
+
1034
+ with gr.Group():
1035
+ gr.Markdown("### TimeGPT Configuration")
1036
+ with gr.Row():
1037
+ finetune_loss = gr.Dropdown(
1038
+ choices=["mape", "mae", "rmse", "smape"],
1039
+ label="Finetune Loss Metric",
1040
+ value="mape"
1041
+ )
1042
+ confidence_level = gr.Slider(50, 99, value=95, step=1, label="Confidence Level (%)")
1043
+
1044
+ gr.Markdown("""
1045
+ **Note:** Using TimeGPT requires a valid API key. The API key should
1046
+ be set as an environment variable named `NIXTLA_API_KEY`. This space uses a trial key, which is rate limited.
1047
+ """)
1048
 
1049
  with gr.Column(scale=3):
1050
+ message_output = gr.Textbox(label="Status Message")
1051
+
1052
+ with gr.Tabs() as tabs:
1053
  with gr.TabItem("Validation Results"):
1054
+ eval_output = gr.Dataframe(label="Evaluation Metrics")
1055
  validation_plot = gr.Plot(label="Validation Plot")
1056
+ validation_output = gr.Dataframe(label="Validation Data", visible=False)
1057
+
1058
+ with gr.Row():
1059
+ show_data_btn = gr.Button("Show Validation Data")
1060
+ hide_data_btn = gr.Button("Hide Validation Data", visible=False)
1061
+
1062
+ def show_data():
1063
+ return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
1064
+
1065
+ def hide_data():
1066
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
1067
+
1068
+ show_data_btn.click(
1069
+ fn=show_data,
1070
+ outputs=[validation_output, hide_data_btn, show_data_btn]
1071
+ )
1072
+
1073
+ hide_data_btn.click(
1074
+ fn=hide_data,
1075
+ outputs=[validation_output, hide_data_btn, show_data_btn]
1076
+ )
1077
+
1078
  with gr.TabItem("Future Forecast"):
 
1079
  forecast_plot = gr.Plot(label="Future Forecast Plot")
1080
+ forecast_output = gr.Dataframe(label="Future Forecast Data", visible=False)
1081
+
1082
+ with gr.Row():
1083
+ show_forecast_btn = gr.Button("Show Forecast Data")
1084
+ hide_forecast_btn = gr.Button("Hide Forecast Data", visible=False)
1085
+
1086
+ def show_forecast():
1087
+ return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
1088
+
1089
+ def hide_forecast():
1090
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
1091
+
1092
+ show_forecast_btn.click(
1093
+ fn=show_forecast,
1094
+ outputs=[forecast_output, hide_forecast_btn, show_forecast_btn]
1095
+ )
1096
+
1097
+ hide_forecast_btn.click(
1098
+ fn=hide_forecast,
1099
+ outputs=[forecast_output, hide_forecast_btn, show_forecast_btn]
1100
+ )
1101
+
1102
  with gr.TabItem("Export Results"):
1103
  export_files = gr.Files(label="Download Results")
1104
 
1105
+ with gr.Row(visible=True) as run_row:
1106
+ submit_btn = gr.Button("Run Validation and Forecast", variant="primary", size="lg")
1107
+
1108
+ # Update visibility of the appropriate box based on evaluation strategy
1109
+ def update_eval_boxes(strategy):
1110
+ return (gr.update(visible=strategy == "Fixed Window"),
1111
+ gr.update(visible=strategy == "Cross Validation"))
1112
+
1113
+ eval_strategy.change(
1114
+ fn=update_eval_boxes,
1115
+ inputs=[eval_strategy],
1116
+ outputs=[fixed_window_box, cv_box]
1117
+ )
1118
 
1119
+ # Run forecast when button is clicked
1120
  submit_btn.click(
1121
  fn=run_forecast,
1122
  inputs=[
1123
  file_input, frequency, eval_strategy, horizon, step_size, num_windows,
1124
  use_historical_avg, use_naive, use_seasonal_naive, seasonality,
1125
  use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
1126
+ use_autoets, use_autoarima, use_timegpt, finetune_loss, confidence_level,
1127
+ future_horizon
1128
  ],
1129
+ outputs=[eval_output, validation_output, validation_plot, forecast_output, forecast_plot, export_files, message_output]
 
 
 
 
 
 
 
1130
  )
1131
 
1132
  if __name__ == "__main__":
1133
+ app.launch(share=False)