Spaces:
Running
Running
Fixing the indentation
Browse files
app.py
CHANGED
|
@@ -15,7 +15,7 @@ from statsforecast.models import (
|
|
| 15 |
)
|
| 16 |
|
| 17 |
from utilsforecast.evaluation import evaluate
|
| 18 |
-
from utilsforecast.losses import *
|
| 19 |
|
| 20 |
# Function to load and process uploaded CSV
|
| 21 |
def load_data(file):
|
|
@@ -33,20 +33,15 @@ def load_data(file):
|
|
| 33 |
except Exception as e:
|
| 34 |
return None, f"Error loading data: {str(e)}"
|
| 35 |
|
| 36 |
-
|
| 37 |
-
# Global store to hold cross-validation forecasts
|
| 38 |
-
forecast_store = {}
|
| 39 |
-
|
| 40 |
-
# Function to generate and return a plot
|
| 41 |
-
|
| 42 |
def create_forecast_plot(forecast_df, original_df, window=None):
|
| 43 |
plt.figure(figsize=(10, 6))
|
| 44 |
unique_ids = forecast_df['unique_id'].unique()
|
| 45 |
-
if window is not None and 'cutoff' in forecast_df.columns:
|
| 46 |
-
forecast_df = forecast_df[forecast_df['cutoff'] == window]
|
| 47 |
-
|
| 48 |
forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff']]
|
| 49 |
|
|
|
|
|
|
|
|
|
|
| 50 |
for unique_id in unique_ids:
|
| 51 |
original_data = original_df[original_df['unique_id'] == unique_id]
|
| 52 |
plt.plot(original_data['ds'], original_data['y'], 'k-', label='Actual')
|
|
@@ -55,7 +50,7 @@ def create_forecast_plot(forecast_df, original_df, window=None):
|
|
| 55 |
if col in forecast_data.columns:
|
| 56 |
plt.plot(forecast_data['ds'], forecast_data[col], label=col)
|
| 57 |
|
| 58 |
-
plt.title('Forecasting Results')
|
| 59 |
plt.xlabel('Date')
|
| 60 |
plt.ylabel('Value')
|
| 61 |
plt.legend()
|
|
@@ -84,7 +79,7 @@ def run_forecast(
|
|
| 84 |
):
|
| 85 |
df, message = load_data(file)
|
| 86 |
if df is None:
|
| 87 |
-
return None, None, None, message
|
| 88 |
|
| 89 |
models = []
|
| 90 |
model_aliases = []
|
|
@@ -112,7 +107,7 @@ def run_forecast(
|
|
| 112 |
model_aliases.append('autoarima')
|
| 113 |
|
| 114 |
if not models:
|
| 115 |
-
return None, None, None, "Please select at least one forecasting model"
|
| 116 |
|
| 117 |
sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
|
| 118 |
|
|
@@ -121,7 +116,6 @@ def run_forecast(
|
|
| 121 |
cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
|
| 122 |
evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
|
| 123 |
eval_df = pd.DataFrame(evaluation).reset_index()
|
| 124 |
-
forecast_store['cv'] = {'forecast': cv_results, 'original': df}
|
| 125 |
unique_cutoffs = sorted(cv_results['cutoff'].unique())
|
| 126 |
fig_forecast = create_forecast_plot(cv_results, df, window=unique_cutoffs[0])
|
| 127 |
return eval_df, cv_results, fig_forecast, "Cross validation completed successfully!", unique_cutoffs
|
|
@@ -129,7 +123,7 @@ def run_forecast(
|
|
| 129 |
else: # Fixed window
|
| 130 |
train_size = len(df) - horizon
|
| 131 |
if train_size <= 0:
|
| 132 |
-
return None, None, None, f"Not enough data for horizon={horizon}"
|
| 133 |
|
| 134 |
train_df = df.iloc[:train_size]
|
| 135 |
test_df = df.iloc[train_size:]
|
|
@@ -141,16 +135,7 @@ def run_forecast(
|
|
| 141 |
return eval_df, forecast, fig_forecast, "Fixed window evaluation completed successfully!", []
|
| 142 |
|
| 143 |
except Exception as e:
|
| 144 |
-
return None, None, None, f"Error during forecasting: {str(e)}"
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
# Function to update forecast plot for selected CV window
|
| 148 |
-
def update_forecast_plot(selected_window):
|
| 149 |
-
data = forecast_store.get('cv')
|
| 150 |
-
if not data:
|
| 151 |
-
return None
|
| 152 |
-
return create_forecast_plot(data['forecast'], data['original'], window=selected_window)
|
| 153 |
-
|
| 154 |
|
| 155 |
# Sample CSV file generation
|
| 156 |
def download_sample():
|
|
@@ -211,14 +196,11 @@ with gr.Blocks(title="StatsForecast Demo") as app:
|
|
| 211 |
submit_btn = gr.Button("Run Forecast")
|
| 212 |
|
| 213 |
with gr.Column(scale=3):
|
| 214 |
-
window_selector = gr.Dropdown(label='Select CV Window', choices=[], visible=False)
|
| 215 |
eval_output = gr.Dataframe(label="Evaluation Results")
|
| 216 |
forecast_output = gr.Dataframe(label="Forecast Data")
|
| 217 |
plot_output = gr.Plot(label="Forecast Plot")
|
| 218 |
message_output = gr.Textbox(label="Message")
|
| 219 |
-
|
| 220 |
-
def handle_forecast_output(eval_df, forecast_df, plot, msg, windows):
|
| 221 |
-
return eval_df, forecast_df, plot, msg, gr.update(choices=[str(w) for w in windows], visible=bool(windows), value=str(windows[0]) if windows else None)
|
| 222 |
|
| 223 |
submit_btn.click(
|
| 224 |
fn=run_forecast,
|
|
@@ -231,8 +213,7 @@ with gr.Blocks(title="StatsForecast Demo") as app:
|
|
| 231 |
outputs=[eval_output, forecast_output, plot_output, message_output, window_selector]
|
| 232 |
)
|
| 233 |
|
|
|
|
|
|
|
| 234 |
if __name__ == "__main__":
|
| 235 |
app.launch(share=False)
|
| 236 |
-
|
| 237 |
-
# Update plot when a window is selected
|
| 238 |
-
window_selector.change(fn=update_forecast_plot, inputs=window_selector, outputs=plot_output)
|
|
|
|
| 15 |
)
|
| 16 |
|
| 17 |
from utilsforecast.evaluation import evaluate
|
| 18 |
+
from utilsforecast.losses import * # Assuming you need the metrics like bias, mae, rmse, mape
|
| 19 |
|
| 20 |
# Function to load and process uploaded CSV
|
| 21 |
def load_data(file):
|
|
|
|
| 33 |
except Exception as e:
|
| 34 |
return None, f"Error loading data: {str(e)}"
|
| 35 |
|
| 36 |
+
# Function to generate and return a plot for a specific cross-validation window
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
def create_forecast_plot(forecast_df, original_df, window=None):
|
| 38 |
plt.figure(figsize=(10, 6))
|
| 39 |
unique_ids = forecast_df['unique_id'].unique()
|
|
|
|
|
|
|
|
|
|
| 40 |
forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff']]
|
| 41 |
|
| 42 |
+
if window is not None and 'cutoff' in forecast_df.columns:
|
| 43 |
+
forecast_df = forecast_df[forecast_df['cutoff'] == window]
|
| 44 |
+
|
| 45 |
for unique_id in unique_ids:
|
| 46 |
original_data = original_df[original_df['unique_id'] == unique_id]
|
| 47 |
plt.plot(original_data['ds'], original_data['y'], 'k-', label='Actual')
|
|
|
|
| 50 |
if col in forecast_data.columns:
|
| 51 |
plt.plot(forecast_data['ds'], forecast_data[col], label=col)
|
| 52 |
|
| 53 |
+
plt.title(f'Forecasting Results{" (Window: " + str(window) + ")" if window else ""}')
|
| 54 |
plt.xlabel('Date')
|
| 55 |
plt.ylabel('Value')
|
| 56 |
plt.legend()
|
|
|
|
| 79 |
):
|
| 80 |
df, message = load_data(file)
|
| 81 |
if df is None:
|
| 82 |
+
return None, None, None, message, []
|
| 83 |
|
| 84 |
models = []
|
| 85 |
model_aliases = []
|
|
|
|
| 107 |
model_aliases.append('autoarima')
|
| 108 |
|
| 109 |
if not models:
|
| 110 |
+
return None, None, None, "Please select at least one forecasting model", []
|
| 111 |
|
| 112 |
sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
|
| 113 |
|
|
|
|
| 116 |
cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
|
| 117 |
evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
|
| 118 |
eval_df = pd.DataFrame(evaluation).reset_index()
|
|
|
|
| 119 |
unique_cutoffs = sorted(cv_results['cutoff'].unique())
|
| 120 |
fig_forecast = create_forecast_plot(cv_results, df, window=unique_cutoffs[0])
|
| 121 |
return eval_df, cv_results, fig_forecast, "Cross validation completed successfully!", unique_cutoffs
|
|
|
|
| 123 |
else: # Fixed window
|
| 124 |
train_size = len(df) - horizon
|
| 125 |
if train_size <= 0:
|
| 126 |
+
return None, None, None, f"Not enough data for horizon={horizon}", []
|
| 127 |
|
| 128 |
train_df = df.iloc[:train_size]
|
| 129 |
test_df = df.iloc[train_size:]
|
|
|
|
| 135 |
return eval_df, forecast, fig_forecast, "Fixed window evaluation completed successfully!", []
|
| 136 |
|
| 137 |
except Exception as e:
|
| 138 |
+
return None, None, None, f"Error during forecasting: {str(e)}", []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
# Sample CSV file generation
|
| 141 |
def download_sample():
|
|
|
|
| 196 |
submit_btn = gr.Button("Run Forecast")
|
| 197 |
|
| 198 |
with gr.Column(scale=3):
|
|
|
|
| 199 |
eval_output = gr.Dataframe(label="Evaluation Results")
|
| 200 |
forecast_output = gr.Dataframe(label="Forecast Data")
|
| 201 |
plot_output = gr.Plot(label="Forecast Plot")
|
| 202 |
message_output = gr.Textbox(label="Message")
|
| 203 |
+
window_selector = gr.Dropdown(label="Select Forecast Window", choices=[], visible=False)
|
|
|
|
|
|
|
| 204 |
|
| 205 |
submit_btn.click(
|
| 206 |
fn=run_forecast,
|
|
|
|
| 213 |
outputs=[eval_output, forecast_output, plot_output, message_output, window_selector]
|
| 214 |
)
|
| 215 |
|
| 216 |
+
window_selector.change(fn=create_forecast_plot, inputs=window_selector, outputs=plot_output)
|
| 217 |
+
|
| 218 |
if __name__ == "__main__":
|
| 219 |
app.launch(share=False)
|
|
|
|
|
|
|
|
|