File size: 10,090 Bytes
9455ec6
 
 
fb6ca91
9455ec6
 
 
06ed069
9455ec6
 
 
 
 
 
 
 
 
c9b451d
 
97fbbe3
9455ec6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97fbbe3
d88cede
fb6ca91
 
6f155ab
 
 
 
 
 
 
 
 
fb6ca91
d88cede
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb6ca91
 
 
 
97fbbe3
 
 
 
9455ec6
 
6f155ab
9455ec6
 
 
 
 
 
 
 
 
 
 
 
 
d88cede
 
9455ec6
 
 
d88cede
9455ec6
 
fb6ca91
 
9455ec6
06ed069
fb6ca91
9455ec6
 
fb6ca91
9455ec6
188cf42
fb6ca91
9455ec6
 
fb6ca91
9455ec6
188cf42
fb6ca91
9455ec6
 
fb6ca91
9455ec6
 
fb6ca91
9455ec6
 
d88cede
9455ec6
 
fb6ca91
9455ec6
d88cede
9455ec6
 
9696d75
9455ec6
d88cede
97fbbe3
d88cede
30c7366
9455ec6
d88cede
 
 
 
 
 
 
 
 
fb6ca91
9455ec6
d88cede
c9b451d
97fbbe3
9455ec6
fb6ca91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97fbbe3
9455ec6
 
fb6ca91
9455ec6
 
 
 
97fbbe3
9455ec6
fb6ca91
9455ec6
 
d88cede
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9455ec6
 
d88cede
 
 
 
 
 
 
 
 
 
 
776c727
9455ec6
 
 
6f155ab
9455ec6
 
d88cede
9455ec6
d88cede
9455ec6
57f0f2b
9455ec6
97fbbe3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
import pandas as pd
import matplotlib.pyplot as plt
import gradio as gr
import tempfile

from statsforecast import StatsForecast
from statsforecast.models import (
    HistoricAverage,
    Naive,
    SeasonalNaive,
    WindowAverage,
    SeasonalWindowAverage,
    AutoETS,
    AutoARIMA
)

from utilsforecast.evaluation import evaluate
from utilsforecast.losses import *

# Function to load and process uploaded CSV
def load_data(file):
    """Load an uploaded CSV into a DataFrame ready for StatsForecast.

    Args:
        file: The upload from ``gr.File``: a filesystem path (``str`` or
            ``os.PathLike``), an object exposing its path as ``.name`` (such as
            a temp-file wrapper), or a readable file-like object. ``None``
            means nothing was uploaded.

    Returns:
        ``(df, message)``. On success ``df`` has the columns ``unique_id``,
        ``ds`` (parsed to datetimes) and ``y``, is sorted by ``unique_id`` then
        ``ds``, and has a fresh 0..n-1 index. On failure ``df`` is ``None`` and
        ``message`` describes the problem.
    """
    import os

    if file is None:
        return None, "Please upload a CSV file"

    if isinstance(file, (str, os.PathLike)):
        source = file
    else:
        # The upload may arrive as a temp-file wrapper whose handle is already
        # closed; reading through its path avoids "I/O on closed file".
        # In-memory buffers have no string ``name`` and are read directly.
        name = getattr(file, "name", None)
        source = name if isinstance(name, str) else file

    try:
        df = pd.read_csv(source)
        required_cols = ['unique_id', 'ds', 'y']
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            return None, f"Missing required columns: {', '.join(missing_cols)}"
        df['ds'] = pd.to_datetime(df['ds'])
        # Reset the index so row labels are not the pre-sort positions.
        df = df.sort_values(['unique_id', 'ds']).reset_index(drop=True)
        return df, "Data loaded successfully!"
    except Exception as e:
        # UI boundary: surface any parsing problem as a status message.
        return None, f"Error loading data: {str(e)}"

# Function to generate and return a plot
def create_forecast_plot(forecast_df, original_df, title="Forecasting Results"):
    """Plot actual values against model predictions from a validation run.

    Args:
        forecast_df: Validation output with ``unique_id`` and ``ds`` columns
            and one column per model. Cross-validation output additionally
            carries ``cutoff`` and the actual ``y``; neither is drawn as a model.
        original_df: Input data with ``unique_id``, ``ds`` and ``y`` columns.
        title: Plot title.

    Returns:
        The matplotlib ``Figure`` containing the plot.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    # ``y`` holds the actuals in cross-validation output; plotting it as a
    # "model" would duplicate the black actuals line.
    non_model_cols = {'unique_id', 'ds', 'cutoff', 'y'}
    forecast_cols = [col for col in forecast_df.columns if col not in non_model_cols]
    # One fixed colour per model so all of a model's lines match its legend entry.
    colors = {col: f"C{i}" for i, col in enumerate(forecast_cols)}

    for unique_id in forecast_df['unique_id'].unique():
        original_data = original_df[original_df['unique_id'] == unique_id]
        ax.plot(original_data['ds'], original_data['y'], 'k-', label='Actual')

        forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
        # Draw each validation window separately so overlapping windows are
        # not joined into one line that jumps back in time.
        if 'cutoff' in forecast_data.columns:
            windows = [window for _, window in forecast_data.groupby('cutoff')]
        else:
            windows = [forecast_data]
        for window in windows:
            for col in forecast_cols:
                ax.plot(window['ds'], window[col], color=colors[col], label=col)

    ax.set_title(title)
    ax.set_xlabel('Date')
    ax.set_ylabel('Value')
    # Every series/window re-adds the same labels; list each label once.
    handles, labels = ax.get_legend_handles_labels()
    unique_handles = dict(zip(labels, handles))
    ax.legend(list(unique_handles.values()), list(unique_handles.keys()))
    ax.grid(True)
    # Unregister from pyplot's global figure list so repeated requests in the
    # long-running app do not accumulate open figures; the Figure object itself
    # stays usable for rendering.
    plt.close(fig)
    return fig

# Function to create a plot for future forecasts
def create_future_forecast_plot(forecast_df, original_df):
    """Plot historical values followed by each model's future forecast.

    Args:
        forecast_df: Forecast output with ``unique_id`` and ``ds`` columns and
            one column per model.
        original_df: Input data with ``unique_id``, ``ds`` and ``y`` columns.

    Returns:
        The matplotlib ``Figure`` containing the plot.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds']]
    # One fixed colour per model so every series' line matches its legend entry.
    colors = {col: f"C{i}" for i, col in enumerate(forecast_cols)}

    for unique_id in forecast_df['unique_id'].unique():
        # Historical data
        original_data = original_df[original_df['unique_id'] == unique_id]
        ax.plot(original_data['ds'], original_data['y'], 'k-', label='Historical')

        # Forecast data
        forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
        for col in forecast_cols:
            ax.plot(forecast_data['ds'], forecast_data[col], color=colors[col], label=col)

    ax.set_title('Future Forecast')
    ax.set_xlabel('Date')
    ax.set_ylabel('Value')
    # Each series re-adds the same labels; list each label once.
    handles, labels = ax.get_legend_handles_labels()
    unique_handles = dict(zip(labels, handles))
    ax.legend(list(unique_handles.values()), list(unique_handles.keys()))
    ax.grid(True)
    # Unregister from pyplot's global figure list to avoid accumulating open
    # figures across requests; the returned Figure can still be rendered.
    plt.close(fig)
    return fig

# Main forecasting logic
def _to_int(value, name):
    """Coerce a UI numeric input to ``int``, raising ``ValueError`` with a readable message."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        raise ValueError(f"{name} must be a whole number, got {value!r}") from None


def run_forecast(
    file,
    frequency,
    eval_strategy,
    horizon,
    step_size,
    num_windows,
    use_historical_avg,
    use_naive,
    use_seasonal_naive,
    seasonality,
    use_window_avg,
    window_size,
    use_seasonal_window_avg,
    seasonal_window_size,
    use_autoets,
    use_autoarima,
    future_horizon
):
    """Validate the selected models on the uploaded data and forecast ahead.

    Args:
        file: Uploaded CSV (see ``load_data``).
        frequency: Frequency string passed to ``StatsForecast`` as ``freq``.
        eval_strategy: ``"Cross Validation"`` uses ``step_size`` and
            ``num_windows``; any other value runs a single hold-out window.
        horizon: Validation horizon.
        step_size, num_windows: Cross-validation settings.
        use_*: Flags selecting which models to run.
        seasonality: Season length for the seasonal models.
        window_size: Window for ``WindowAverage``.
        seasonal_window_size: Window for ``SeasonalWindowAverage``.
        future_horizon: Number of future steps to forecast.

    Returns:
        ``(eval_df, cv_results, fig_validation, future_forecasts, fig_future,
        message)``. On failure the first five items are ``None`` and
        ``message`` explains why.
    """
    empty = (None, None, None, None, None)

    df, message = load_data(file)
    if df is None:
        return (*empty, message)

    # Numeric UI widgets may deliver floats such as 10.0 (or None when cleared);
    # statsforecast needs ints for horizons, season lengths and window sizes.
    # Model-specific values are only converted when that model is selected.
    try:
        horizon = _to_int(horizon, "Validation Horizon")
        step_size = _to_int(step_size, "Step Size")
        num_windows = _to_int(num_windows, "Number of Windows")
        future_horizon = _to_int(future_horizon, "Future Forecast Horizon")

        models = []
        if use_historical_avg:
            models.append(HistoricAverage(alias='historical_average'))
        if use_naive:
            models.append(Naive(alias='naive'))
        if use_seasonal_naive:
            models.append(SeasonalNaive(
                season_length=_to_int(seasonality, "Seasonality"),
                alias='seasonal_naive',
            ))
        if use_window_avg:
            models.append(WindowAverage(
                window_size=_to_int(window_size, "Window Size"),
                alias='window_average',
            ))
        if use_seasonal_window_avg:
            models.append(SeasonalWindowAverage(
                season_length=_to_int(seasonality, "Seasonality"),
                window_size=_to_int(seasonal_window_size, "Seasonal Window Size"),
                alias='seasonal_window_average',
            ))
        if use_autoets:
            models.append(AutoETS(alias='autoets'))
        if use_autoarima:
            models.append(AutoARIMA(alias='autoarima'))
    except ValueError as e:
        return (*empty, str(e))

    if not models:
        return (*empty, "Please select at least one forecasting model")

    # Output column names are the aliases assigned above.
    model_aliases = [model.alias for model in models]

    if eval_strategy == "Cross Validation":
        cv_step, cv_windows = step_size, num_windows
        plot_title = "Cross Validation Results"
    else:  # Fixed window: one hold-out window of length `horizon`.
        # step_size has no effect with a single window; any positive value works.
        cv_step, cv_windows = horizon, 1
        plot_title = "Fixed Window Validation Results"

    try:
        sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)

        cv_results = sf.cross_validation(df=df, h=horizon, step_size=cv_step, n_windows=cv_windows)
        evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
        eval_df = pd.DataFrame(evaluation).reset_index()
        fig_validation = create_forecast_plot(cv_results, df, plot_title)

        # forecast() fits and predicts in one call and must be given the
        # training data; a separate fit() beforehand would be wasted work.
        future_forecasts = sf.forecast(df=df, h=future_horizon)
        fig_future = create_future_forecast_plot(future_forecasts, df)
    except Exception as e:
        # UI boundary: report modelling failures in the status box.
        return (*empty, f"Error during forecasting: {str(e)}")

    return eval_df, cv_results, fig_validation, future_forecasts, fig_future, "Analysis completed successfully!"

# Sample CSV file generation
def download_sample():
    """Write a sample CSV in the format the app expects and return its path.

    The file contains one daily series, ``series1``, starting on 2023-01-01.
    Values follow an upward trend plus a repeating 7-day pattern.

    It has 60 rows so the app's default validation settings still work:
    horizon 10, step 10 and 3 windows hold out 30 points, leaving 30 for
    training.

    Returns:
        Path of a ``.csv`` file. The file is created with ``delete=False`` so it
        stays on disk for download.
    """
    import datetime

    n_days = 60
    start = datetime.date(2023, 1, 1)
    weekly_pattern = (0, 5, 2, 7, 4, 10, 8)

    rows = ["unique_id,ds,y"]
    for t in range(n_days):
        day = start + datetime.timedelta(days=t)
        rows.append(f"series1,{day.isoformat()},{100 + 2 * t + weekly_pattern[t % 7]}")

    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".csv", mode="w", newline="", encoding="utf-8"
    ) as temp:
        temp.write("\n".join(rows) + "\n")
    return temp.name

# Gradio interface
with gr.Blocks(title="StatsForecast Demo") as app:
    gr.Markdown("# 📈 StatsForecast Demo App")
    gr.Markdown("Upload a CSV with `unique_id`, `ds`, and `y` columns to apply forecasting models.")

    with gr.Row():
        with gr.Column(scale=2):
            file_input = gr.File(label="Upload CSV file", file_types=[".csv"])

            download_btn = gr.Button("Download Sample Data")
            download_output = gr.File(label="Click to download", visible=True)
            download_btn.click(fn=download_sample, outputs=download_output)

            with gr.Accordion("Data & Validation Settings", open=True):
                # Choices are pandas offset aliases passed to StatsForecast as `freq`.
                # Weekly is "W"; "WS" is not a valid pandas alias.
                frequency = gr.Dropdown(choices=["H", "D", "W", "MS", "QS", "YS"], label="Frequency", value="D")
                eval_strategy = gr.Radio(choices=["Fixed Window", "Cross Validation"], label="Evaluation Strategy", value="Cross Validation")
                horizon = gr.Slider(1, 100, value=10, step=1, label="Validation Horizon")
                step_size = gr.Slider(1, 50, value=10, step=1, label="Step Size")
                num_windows = gr.Slider(1, 20, value=3, step=1, label="Number of Windows")

            with gr.Accordion("Forecast Settings", open=True):
                future_horizon = gr.Slider(1, 100, value=20, step=1, label="Future Forecast Horizon")

            with gr.Accordion("Model Configuration", open=True):
                use_historical_avg = gr.Checkbox(label="Use Historical Average", value=True)
                use_naive = gr.Checkbox(label="Use Naive", value=True)

                # precision=0 makes these Number inputs deliver ints, which the
                # season-length / window-size model parameters require.
                with gr.Row():
                    use_seasonal_naive = gr.Checkbox(label="Use Seasonal Naive")
                    seasonality = gr.Number(label="Seasonality", value=10, precision=0)

                with gr.Row():
                    use_window_avg = gr.Checkbox(label="Use Window Average")
                    window_size = gr.Number(label="Window Size", value=3, precision=0)

                with gr.Row():
                    use_seasonal_window_avg = gr.Checkbox(label="Use Seasonal Window Average")
                    seasonal_window_size = gr.Number(label="Seasonal Window Size", value=2, precision=0)

                use_autoets = gr.Checkbox(label="Use AutoETS")
                use_autoarima = gr.Checkbox(label="Use AutoARIMA")

            submit_btn = gr.Button("Run Forecast", variant="primary")

        with gr.Column(scale=3):
            message_output = gr.Textbox(label="Status Message")

            with gr.Tabs() as tabs:
                with gr.TabItem("Validation Results"):
                    eval_output = gr.Dataframe(label="Evaluation Metrics")
                    validation_output = gr.Dataframe(label="Validation Data")
                    validation_plot = gr.Plot(label="Validation Plot")

                with gr.TabItem("Future Forecast"):
                    forecast_output = gr.Dataframe(label="Future Forecast Data")
                    forecast_plot = gr.Plot(label="Future Forecast Plot")

    # Input order must match run_forecast's positional parameters; output
    # order must match its 6-tuple return value.
    submit_btn.click(
        fn=run_forecast,
        inputs=[
            file_input, frequency, eval_strategy, horizon, step_size, num_windows,
            use_historical_avg, use_naive, use_seasonal_naive, seasonality,
            use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
            use_autoets, use_autoarima, future_horizon
        ],
        outputs=[eval_output, validation_output, validation_plot, forecast_output, forecast_plot, message_output]
    )

if __name__ == "__main__":
    # share=False: serve locally without creating a public share link.
    app.launch(share=False)