fmegahed commited on
Commit
cecf5f5
·
verified ·
1 Parent(s): 380ba72

first attempt to include timegpt

Browse files
Files changed (1) hide show
  1. app.py +191 -340
app.py CHANGED
@@ -18,213 +18,15 @@ from statsforecast.models import (
18
 
19
  from utilsforecast.evaluation import evaluate
20
  from utilsforecast.losses import *
 
21
 
22
- # Function to load and process uploaded CSV
23
- def load_data(file):
24
- if file is None:
25
- return None, "Please upload a CSV file"
26
- try:
27
- df = pd.read_csv(file)
28
- required_cols = ['unique_id', 'ds', 'y']
29
- missing_cols = [col for col in required_cols if col not in df.columns]
30
- if missing_cols:
31
- return None, f"Missing required columns: {', '.join(missing_cols)}"
32
-
33
- df['ds'] = pd.to_datetime(df['ds'])
34
- df = df.sort_values(['unique_id', 'ds']).reset_index(drop=True)
35
-
36
- # Check for NaN values
37
- if df['y'].isna().any():
38
- return None, "Data contains missing values in the 'y' column"
39
-
40
- return df, "Data loaded successfully!"
41
- except Exception as e:
42
- return None, f"Error loading data: {str(e)}"
43
-
44
- # Function to generate and return a plot
45
- def create_forecast_plot(forecast_df, original_df, title="Forecasting Results"):
46
- plt.figure(figsize=(12, 7))
47
- unique_ids = forecast_df['unique_id'].unique()
48
- forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff']]
49
-
50
- colors = plt.cm.tab10.colors
51
-
52
- for i, unique_id in enumerate(unique_ids):
53
- original_data = original_df[original_df['unique_id'] == unique_id]
54
- plt.plot(original_data['ds'], original_data['y'], 'k-', linewidth=2, label=f'{unique_id} (Actual)')
55
-
56
- forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
57
- for j, col in enumerate(forecast_cols):
58
- if col in forecast_data.columns:
59
- plt.plot(forecast_data['ds'], forecast_data[col],
60
- color=colors[j % len(colors)],
61
- linestyle='--',
62
- linewidth=1.5,
63
- label=f'{col}')
64
 
65
- plt.title(title, fontsize=16)
66
- plt.xlabel('Date', fontsize=12)
67
- plt.ylabel('Value', fontsize=12)
68
- plt.grid(True, alpha=0.3)
69
- plt.legend(loc='best')
70
- plt.tight_layout()
71
-
72
- # Format date labels better
73
- fig = plt.gcf()
74
- ax = plt.gca()
75
- fig.autofmt_xdate()
76
-
77
- return fig
78
 
79
- # Function to create a plot for future forecasts
80
- def create_future_forecast_plot(forecast_df, original_df):
81
- plt.figure(figsize=(12, 7))
82
- unique_ids = forecast_df['unique_id'].unique()
83
- forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds']]
84
-
85
- colors = plt.cm.tab10.colors
86
-
87
- for i, unique_id in enumerate(unique_ids):
88
- # Plot historical data
89
- original_data = original_df[original_df['unique_id'] == unique_id]
90
- plt.plot(original_data['ds'], original_data['y'], 'k-', linewidth=2, label=f'{unique_id} (Historical)')
91
-
92
- # Plot forecast data with shaded vertical line separator
93
- forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
94
-
95
- # Add vertical line at the forecast start
96
- if not forecast_data.empty and not original_data.empty:
97
- forecast_start = forecast_data['ds'].min()
98
- plt.axvline(x=forecast_start, color='gray', linestyle='--', alpha=0.5)
99
-
100
- for j, col in enumerate(forecast_cols):
101
- if col in forecast_data.columns:
102
- plt.plot(forecast_data['ds'], forecast_data[col],
103
- color=colors[j % len(colors)],
104
- linestyle='--',
105
- linewidth=1.5,
106
- label=f'{col}')
107
-
108
- plt.title('Future Forecast', fontsize=16)
109
- plt.xlabel('Date', fontsize=12)
110
- plt.ylabel('Value', fontsize=12)
111
- plt.grid(True, alpha=0.3)
112
- plt.legend(loc='best')
113
- plt.tight_layout()
114
-
115
- # Format date labels better
116
- fig = plt.gcf()
117
- ax = plt.gca()
118
- fig.autofmt_xdate()
119
-
120
- return fig
121
-
122
- # Function to export results to CSV
123
- def export_results(eval_df, cv_results, future_forecasts):
124
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
125
-
126
- # Create temp directory if it doesn't exist
127
- temp_dir = tempfile.mkdtemp()
128
-
129
- result_files = []
130
-
131
- if eval_df is not None:
132
- eval_path = os.path.join(temp_dir, f"evaluation_metrics_{timestamp}.csv")
133
- eval_df.to_csv(eval_path, index=False)
134
- result_files.append(eval_path)
135
-
136
- if cv_results is not None:
137
- cv_path = os.path.join(temp_dir, f"cross_validation_results_{timestamp}.csv")
138
- cv_results.to_csv(cv_path, index=False)
139
- result_files.append(cv_path)
140
-
141
- if future_forecasts is not None:
142
- forecast_path = os.path.join(temp_dir, f"forecasts_{timestamp}.csv")
143
- future_forecasts.to_csv(forecast_path, index=False)
144
- result_files.append(forecast_path)
145
-
146
- return result_files
147
-
148
- # Main forecasting logic
149
- def run_forecast(
150
- file,
151
- frequency,
152
- eval_strategy,
153
- horizon,
154
- step_size,
155
- num_windows,
156
- use_historical_avg,
157
- use_naive,
158
- use_seasonal_naive,
159
- seasonality,
160
- use_window_avg,
161
- window_size,
162
- use_seasonal_window_avg,
163
- seasonal_window_size,
164
- use_autoets,
165
- use_autoarima,
166
- future_horizon
167
- ):
168
- df, message = load_data(file)
169
- if df is None:
170
- return None, None, None, None, None, None, message
171
-
172
- models = []
173
- model_aliases = []
174
-
175
- if use_historical_avg:
176
- models.append(HistoricAverage(alias='historical_average'))
177
- model_aliases.append('historical_average')
178
- if use_naive:
179
- models.append(Naive(alias='naive'))
180
- model_aliases.append('naive')
181
- if use_seasonal_naive:
182
- models.append(SeasonalNaive(season_length=seasonality, alias='seasonal_naive'))
183
- model_aliases.append('seasonal_naive')
184
- if use_window_avg:
185
- models.append(WindowAverage(window_size=window_size, alias='window_average'))
186
- model_aliases.append('window_average')
187
- if use_seasonal_window_avg:
188
- models.append(SeasonalWindowAverage(season_length=seasonality, window_size=seasonal_window_size, alias='seasonal_window_average'))
189
- model_aliases.append('seasonal_window_average')
190
- if use_autoets:
191
- models.append(AutoETS(alias='autoets', season_length=seasonality))
192
- model_aliases.append('autoets')
193
- if use_autoarima:
194
- models.append(AutoARIMA(alias='autoarima', season_length=seasonality))
195
- model_aliases.append('autoarima')
196
-
197
- if not models:
198
- return None, None, None, None, None, None, "Please select at least one forecasting model"
199
-
200
- sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
201
-
202
- try:
203
- # Run cross-validation
204
- if eval_strategy == "Cross Validation":
205
- cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
206
- evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
207
- eval_df = pd.DataFrame(evaluation).reset_index()
208
- fig_validation = create_forecast_plot(cv_results, df, "Cross Validation Results")
209
- else: # Fixed window
210
- cv_results = sf.cross_validation(df=df, h=horizon, step_size=10, n_windows=1) # any step size for 1 window
211
- evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
212
- eval_df = pd.DataFrame(evaluation).reset_index()
213
- fig_validation = create_forecast_plot(cv_results, df, "Fixed Window Validation Results")
214
-
215
- # Generate future forecasts
216
- future_forecasts = sf.forecast(df=df, h=future_horizon)
217
- fig_future = create_future_forecast_plot(future_forecasts, df)
218
-
219
- # Export results
220
- export_files = export_results(eval_df, cv_results, future_forecasts)
221
-
222
- return eval_df, cv_results, fig_validation, future_forecasts, fig_future, export_files, "Analysis completed successfully!"
223
-
224
- except Exception as e:
225
- return None, None, None, None, None, None, f"Error during forecasting: {str(e)}"
226
-
227
- # Sample CSV file generation
228
  def download_sample():
229
  sample_data = """unique_id,ds,y
230
  ^GSPC,2023-01-03,3824.139892578125
@@ -795,175 +597,224 @@ def download_sample():
795
  temp.close()
796
  return temp.name
797
 
798
- # Global theme
799
- theme = gr.themes.Soft(
800
- primary_hue="blue",
801
- secondary_hue="indigo",
802
- neutral_hue="gray"
803
- )
 
 
 
 
 
 
 
 
 
 
 
804
 
805
- # Gradio interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
806
  with gr.Blocks(title="Time Series Forecasting App", theme=theme) as app:
807
  gr.Markdown("# 📈 Time Series Forecasting App")
808
- gr.Markdown("Upload a CSV with `unique_id`, `ds`, and `y` columns to apply forecasting models.")
 
 
 
809
 
810
  with gr.Row():
811
  with gr.Column(scale=2):
812
  file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
813
-
814
  download_btn = gr.Button("Download Sample Data", variant="secondary")
815
- download_output = gr.File(label="Click to download", visible=True)
816
  download_btn.click(fn=download_sample, outputs=download_output)
817
 
818
  with gr.Accordion("Data & Validation Settings", open=True):
819
  frequency = gr.Dropdown(
820
- choices=[
821
- ("Hourly", "H"),
822
- ("Daily", "D"),
823
- ("Weekly", "WS"),
824
- ("Monthly", "MS"),
825
- ("Quarterly", "QS"),
826
- ("Yearly", "YS")
827
- ],
828
- label="Data Frequency",
829
- value="D"
830
  )
831
-
832
- # Evaluation Strategy
833
  eval_strategy = gr.Radio(
834
- choices=["Fixed Window", "Cross Validation"],
835
- label="Evaluation Strategy",
836
- value="Cross Validation"
837
  )
838
-
839
- # Fixed Window settings
840
- with gr.Group(visible=True) as fixed_window_box:
841
- gr.Markdown("### Fixed Window Settings")
842
- horizon = gr.Slider(1, 100, value=10, step=1, label="Validation Horizon (steps ahead to predict)")
843
-
844
- # Cross Validation settings
845
- with gr.Group(visible=True) as cv_box:
846
- gr.Markdown("### Cross Validation Settings")
847
- with gr.Row():
848
- step_size = gr.Slider(1, 50, value=10, step=1, label="Step Size (distance between windows)")
849
- num_windows = gr.Slider(1, 20, value=5, step=1, label="Number of Windows")
850
-
851
- # Future forecast settings (always visible)
852
- with gr.Group():
853
- gr.Markdown("### Future Forecast Settings")
854
- future_horizon = gr.Slider(1, 100, value=10, step=1, label="Future Forecast Horizon (steps to predict)")
855
 
856
- with gr.Accordion("Model Configuration", open=True):
857
- gr.Markdown("## Basic Models")
858
- with gr.Row():
859
- use_historical_avg = gr.Checkbox(label="Historical Average", value=True)
860
- use_naive = gr.Checkbox(label="Naive", value=True)
861
-
862
- # Common seasonality parameter at the top level
863
- with gr.Group():
864
- gr.Markdown("### Seasonality Configuration")
865
- gr.Markdown("This seasonality period affects Seasonal Naive, Seasonal Window Average, AutoETS, and AutoARIMA models")
866
- seasonality = gr.Number(label="Seasonality Period", value=5)
867
-
868
- gr.Markdown("### Seasonal Models")
869
- with gr.Row():
870
- use_seasonal_naive = gr.Checkbox(label="Seasonal Naive", value=True)
871
-
872
- gr.Markdown("### Window-based Models")
873
- with gr.Row():
874
- use_window_avg = gr.Checkbox(label="Window Average", value=True)
875
- window_size = gr.Number(label="Window Size", value=10)
876
-
877
- with gr.Row():
878
- use_seasonal_window_avg = gr.Checkbox(label="Seasonal Window Average", value=True)
879
- seasonal_window_size = gr.Number(label="Seasonal Window Size", value=2)
880
-
881
- gr.Markdown("### Advanced Models (use seasonality from above)")
882
- with gr.Row():
883
- use_autoets = gr.Checkbox(label="AutoETS (Exponential Smoothing)", value=True)
884
- use_autoarima = gr.Checkbox(label="AutoARIMA", value=True)
885
 
886
  with gr.Column(scale=3):
887
- message_output = gr.Textbox(label="Status Message")
888
-
889
- with gr.Tabs() as tabs:
890
  with gr.TabItem("Validation Results"):
891
- eval_output = gr.Dataframe(label="Evaluation Metrics")
892
  validation_plot = gr.Plot(label="Validation Plot")
893
- validation_output = gr.Dataframe(label="Validation Data", visible=False)
894
-
895
- with gr.Row():
896
- show_data_btn = gr.Button("Show Validation Data")
897
- hide_data_btn = gr.Button("Hide Validation Data", visible=False)
898
-
899
- def show_data():
900
- return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
901
-
902
- def hide_data():
903
- return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
904
-
905
- show_data_btn.click(
906
- fn=show_data,
907
- outputs=[validation_output, hide_data_btn, show_data_btn]
908
- )
909
-
910
- hide_data_btn.click(
911
- fn=hide_data,
912
- outputs=[validation_output, hide_data_btn, show_data_btn]
913
- )
914
-
915
  with gr.TabItem("Future Forecast"):
 
916
  forecast_plot = gr.Plot(label="Future Forecast Plot")
917
- forecast_output = gr.Dataframe(label="Future Forecast Data", visible=False)
918
-
919
- with gr.Row():
920
- show_forecast_btn = gr.Button("Show Forecast Data")
921
- hide_forecast_btn = gr.Button("Hide Forecast Data", visible=False)
922
-
923
- def show_forecast():
924
- return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
925
-
926
- def hide_forecast():
927
- return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
928
-
929
- show_forecast_btn.click(
930
- fn=show_forecast,
931
- outputs=[forecast_output, hide_forecast_btn, show_forecast_btn]
932
- )
933
-
934
- hide_forecast_btn.click(
935
- fn=hide_forecast,
936
- outputs=[forecast_output, hide_forecast_btn, show_forecast_btn]
937
- )
938
-
939
  with gr.TabItem("Export Results"):
940
  export_files = gr.Files(label="Download Results")
941
 
942
- with gr.Row(visible=True) as run_row:
943
- submit_btn = gr.Button("Run Validation and Forecast", variant="primary", size="lg")
944
-
945
- # Update visibility of the appropriate box based on evaluation strategy
946
- def update_eval_boxes(strategy):
947
- return (gr.update(visible=strategy == "Fixed Window"),
948
- gr.update(visible=strategy == "Cross Validation"))
949
-
950
- eval_strategy.change(
951
- fn=update_eval_boxes,
952
- inputs=[eval_strategy],
953
- outputs=[fixed_window_box, cv_box]
954
- )
955
 
956
- # Run forecast when button is clicked
957
  submit_btn.click(
958
  fn=run_forecast,
959
  inputs=[
960
  file_input, frequency, eval_strategy, horizon, step_size, num_windows,
961
  use_historical_avg, use_naive, use_seasonal_naive, seasonality,
962
  use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
963
- use_autoets, use_autoarima, future_horizon
964
  ],
965
- outputs=[eval_output, validation_output, validation_plot, forecast_output, forecast_plot, export_files, message_output]
 
 
 
 
 
 
 
966
  )
967
 
968
  if __name__ == "__main__":
969
- app.launch(share=False)
 
18
 
19
  from utilsforecast.evaluation import evaluate
20
  from utilsforecast.losses import *
21
+ from utilsforecast.plotting import plot_series
22
 
23
+ from nixtla import NixtlaClient
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ # Initialize TimeGPT client using Hugging Face secret
26
+ nixtla_api_key = os.getenv("NIXTLA_API_KEY")
27
+ nixtla_client = NixtlaClient(api_key=nixtla_api_key)
 
 
 
 
 
 
 
 
 
 
28
 
29
+ # Sample Dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def download_sample():
31
  sample_data = """unique_id,ds,y
32
  ^GSPC,2023-01-03,3824.139892578125
 
597
  temp.close()
598
  return temp.name
599
 
600
+ # Load and validate user data
601
+ def load_data(file):
602
+ if file is None:
603
+ return None, "Please upload a CSV file"
604
+ try:
605
+ df = pd.read_csv(file)
606
+ required_cols = ['unique_id', 'ds', 'y']
607
+ missing = [c for c in required_cols if c not in df.columns]
608
+ if missing:
609
+ return None, f"Missing required columns: {', '.join(missing)}"
610
+ df['ds'] = pd.to_datetime(df['ds'])
611
+ df = df.sort_values(['unique_id', 'ds']).reset_index(drop=True)
612
+ if df['y'].isna().any():
613
+ return None, "Data contains missing values in 'y'"
614
+ return df, "Data loaded successfully!"
615
+ except Exception as e:
616
+ return None, f"Error loading data: {e}"
617
 
618
+ # Export results to CSV files for download
619
+ def export_results(eval_df, validation_df, future_df):
620
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
621
+ temp_dir = tempfile.mkdtemp()
622
+ result_files = []
623
+ if eval_df is not None:
624
+ path = os.path.join(temp_dir, f"evaluation_metrics_{timestamp}.csv")
625
+ eval_df.to_csv(path, index=False)
626
+ result_files.append(path)
627
+ if validation_df is not None:
628
+ path = os.path.join(temp_dir, f"validation_results_{timestamp}.csv")
629
+ validation_df.to_csv(path, index=False)
630
+ result_files.append(path)
631
+ if future_df is not None:
632
+ path = os.path.join(temp_dir, f"future_forecasts_{timestamp}.csv")
633
+ future_df.to_csv(path, index=False)
634
+ result_files.append(path)
635
+ return result_files
636
+
637
+ # Main forecasting logic
638
+ def run_forecast(
639
+ file,
640
+ frequency,
641
+ eval_strategy,
642
+ horizon,
643
+ step_size,
644
+ num_windows,
645
+ use_historical_avg,
646
+ use_naive,
647
+ use_seasonal_naive,
648
+ seasonality,
649
+ use_window_avg,
650
+ window_size,
651
+ use_seasonal_window_avg,
652
+ seasonal_window_size,
653
+ use_autoets,
654
+ use_autoarima,
655
+ use_timegpt,
656
+ future_horizon
657
+ ):
658
+ df, msg = load_data(file)
659
+ if df is None:
660
+ # return placeholders plus message
661
+ return [None]*9 + [msg]
662
+
663
+ # Build model list
664
+ models = []
665
+ if use_historical_avg:
666
+ models.append(HistoricAverage())
667
+ if use_naive:
668
+ models.append(Naive())
669
+ if use_seasonal_naive:
670
+ models.append(SeasonalNaive(season_length=seasonality))
671
+ if use_window_avg:
672
+ models.append(WindowAverage(window_size=window_size))
673
+ if use_seasonal_window_avg:
674
+ models.append(SeasonalWindowAverage(season_length=seasonal_window_size))
675
+ if use_autoets:
676
+ models.append(AutoETS(season_length=seasonality))
677
+ if use_autoarima:
678
+ models.append(AutoARIMA(season_length=seasonality))
679
+
680
+ if not models and not use_timegpt:
681
+ return [None]*9 + ["Please select at least one forecasting model"]
682
+
683
+ # StatsForecast run
684
+ sf = StatsForecast(models=models, freq=frequency, n_jobs=-1) if models else None
685
+
686
+ # Cross validation or fixed-window evaluation
687
+ validation_df, fig_val = None, None
688
+ if sf is not None:
689
+ if eval_strategy == "Cross Validation":
690
+ validation_df = sf.cross_validation(
691
+ df=df, h=horizon, step_size=step_size, periods=num_windows
692
+ )
693
+ else:
694
+ # Fixed window splits
695
+ cutoff = df['ds'].max() - pd.to_timedelta(horizon, unit=frequency)
696
+ validation_df = sf.forecast(df[df['ds'] <= cutoff], h=horizon)
697
+ eval_df = evaluate(validation_df)
698
+ fig_val = plot_series(df=df, forecast_df=validation_df, title="Validation Results")
699
+ else:
700
+ eval_df = None
701
+
702
+ # Future forecast with StatsForecast
703
+ future_df, fig_future = None, None
704
+ if sf is not None:
705
+ future_df = sf.forecast(df=df, h=future_horizon)
706
+ fig_future = plot_series(df=df, forecast_df=future_df, title="Future Forecast")
707
+
708
+ # TimeGPT / Transformer forecast
709
+ tg_df, fig_tg = None, None
710
+ if use_timegpt:
711
+ tdf = df[['unique_id', 'ds', 'y']]
712
+ tg_df = nixtla_client.forecast(
713
+ df=tdf, h=future_horizon, freq=frequency, level=[95]
714
+ )
715
+ fig_tg = nixtla_client.plot(
716
+ df=tdf, forecasts_df=tg_df, level=[95]
717
+ )
718
+
719
+ # Export all results
720
+ files = export_results(
721
+ eval_df if sf is not None else None,
722
+ validation_df,
723
+ future_df
724
+ )
725
+
726
+ return (
727
+ eval_df,
728
+ validation_df,
729
+ fig_val,
730
+ future_df,
731
+ fig_future,
732
+ tg_df,
733
+ fig_tg,
734
+ files,
735
+ "Analysis completed successfully!"
736
+ )
737
+
738
+ # Build Gradio interface
739
+ theme = None # adjust or import your theme
740
  with gr.Blocks(title="Time Series Forecasting App", theme=theme) as app:
741
  gr.Markdown("# 📈 Time Series Forecasting App")
742
+ gr.Markdown(
743
+ "> **Disclaimer:** For simplicity, external predictors (covariates) are not supported in this demo. "
744
+ "However, you can include them by passing an `X_df` to StatsForecast (via `sf.forecast(...)`) "
745
+ "or to TimeGPT via `X_df=` in `nixtla_client.forecast(...)`.")
746
 
747
  with gr.Row():
748
  with gr.Column(scale=2):
749
  file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
 
750
  download_btn = gr.Button("Download Sample Data", variant="secondary")
751
+ download_output = gr.File(label="Download Sample", visible=True)
752
  download_btn.click(fn=download_sample, outputs=download_output)
753
 
754
  with gr.Accordion("Data & Validation Settings", open=True):
755
  frequency = gr.Dropdown(
756
+ choices=["D", "W", "M", "H"], value="D", label="Frequency"
 
 
 
 
 
 
 
 
 
757
  )
 
 
758
  eval_strategy = gr.Radio(
759
+ choices=["Cross Validation", "Fixed Window"],
760
+ value="Cross Validation",
761
+ label="Validation Strategy"
762
  )
763
+ horizon = gr.Slider(1, 100, value=10, label="CV Horizon")
764
+ step_size = gr.Slider(1, 100, value=1, label="CV Step Size")
765
+ num_windows = gr.Number(value=3, label="Number of CV Windows")
766
+
767
+ with gr.Accordion("Models", open=True):
768
+ use_historical_avg = gr.Checkbox(label="Historical Average", value=True)
769
+ use_naive = gr.Checkbox(label="Naive", value=False)
770
+ use_seasonal_naive = gr.Checkbox(label="Seasonal Naive", value=False)
771
+ seasonality = gr.Number(value=12, label="Seasonality")
772
+ use_window_avg = gr.Checkbox(label="Window Average", value=False)
773
+ window_size = gr.Number(value=3, label="Window Size")
774
+ use_seasonal_window_avg = gr.Checkbox(label="Seasonal Window Average", value=False)
775
+ seasonal_window_size = gr.Number(value=12, label="Seasonal Window Size")
776
+ use_autoets = gr.Checkbox(label="AutoETS", value=False)
777
+ use_autoarima = gr.Checkbox(label="AutoARIMA", value=False)
778
+ gr.Markdown("### Transformer Models")
779
+ use_timegpt = gr.Checkbox(label="TimeGPT (Transformer)", value=False)
780
 
781
+ future_horizon = gr.Slider(1, 100, value=12, label="Future Forecast Horizon")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
782
 
783
  with gr.Column(scale=3):
784
+ eval_output = gr.Dataframe(label="Evaluation Metrics")
785
+ with gr.Tabs():
 
786
  with gr.TabItem("Validation Results"):
787
+ validation_output = gr.Dataframe(label="Validation Data")
788
  validation_plot = gr.Plot(label="Validation Plot")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
789
  with gr.TabItem("Future Forecast"):
790
+ forecast_output = gr.Dataframe(label="Future Forecast Data")
791
  forecast_plot = gr.Plot(label="Future Forecast Plot")
792
+ with gr.TabItem("Transformer Forecast"):
793
+ tg_output = gr.Dataframe(label="TimeGPT Forecast Data")
794
+ tg_plot = gr.Plot(label="TimeGPT Forecast Plot")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
795
  with gr.TabItem("Export Results"):
796
  export_files = gr.Files(label="Download Results")
797
 
798
+ message_output = gr.Markdown()
 
 
 
 
 
 
 
 
 
 
 
 
799
 
800
+ submit_btn = gr.Button("Run Validation and Forecast", variant="primary", size="lg")
801
  submit_btn.click(
802
  fn=run_forecast,
803
  inputs=[
804
  file_input, frequency, eval_strategy, horizon, step_size, num_windows,
805
  use_historical_avg, use_naive, use_seasonal_naive, seasonality,
806
  use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
807
+ use_autoets, use_autoarima, use_timegpt, future_horizon
808
  ],
809
+ outputs=[
810
+ eval_output,
811
+ validation_output, validation_plot,
812
+ forecast_output, forecast_plot,
813
+ tg_output, tg_plot,
814
+ export_files,
815
+ message_output
816
+ ]
817
  )
818
 
819
  if __name__ == "__main__":
820
+ app.launch(share=False)