Spaces:
Sleeping
Sleeping
Updated plotting function
Browse files
app.py
CHANGED
|
@@ -45,13 +45,37 @@ def load_data(file):
|
|
| 45 |
except Exception as e:
|
| 46 |
return None, f"Error loading data: {str(e)}"
|
| 47 |
|
| 48 |
-
#
|
| 49 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
plt.figure(figsize=(12, 7))
|
| 51 |
unique_ids = forecast_df['unique_id'].unique()
|
| 52 |
-
forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff']]
|
| 53 |
|
| 54 |
colors = plt.cm.tab10.colors
|
|
|
|
|
|
|
| 55 |
min_cutoff = None
|
| 56 |
|
| 57 |
for i, unique_id in enumerate(unique_ids):
|
|
@@ -59,20 +83,62 @@ def create_forecast_plot(forecast_df, original_df, title="Forecasting Results"):
|
|
| 59 |
plt.plot(original_data['ds'], original_data['y'], 'k-', linewidth=2, label=f'{unique_id} (Actual)')
|
| 60 |
|
| 61 |
forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
for j, col in enumerate(forecast_cols):
|
| 63 |
if col in forecast_data.columns:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
plt.plot(forecast_data['ds'], forecast_data[col],
|
| 65 |
color=colors[j % len(colors)],
|
| 66 |
linestyle='--',
|
| 67 |
linewidth=1.5,
|
| 68 |
-
label=f'{
|
| 69 |
|
| 70 |
plt.title(title, fontsize=16)
|
| 71 |
plt.xlabel('Date', fontsize=12)
|
| 72 |
plt.ylabel('Value', fontsize=12)
|
| 73 |
plt.grid(True, alpha=0.3)
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
# Format date labels better
|
| 78 |
fig = plt.gcf()
|
|
@@ -82,13 +148,18 @@ def create_forecast_plot(forecast_df, original_df, title="Forecasting Results"):
|
|
| 82 |
return fig
|
| 83 |
|
| 84 |
# Function to create a plot for future forecasts
|
| 85 |
-
def create_future_forecast_plot(forecast_df, original_df):
|
| 86 |
plt.figure(figsize=(12, 7))
|
| 87 |
unique_ids = forecast_df['unique_id'].unique()
|
| 88 |
forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds']]
|
| 89 |
|
| 90 |
colors = plt.cm.tab10.colors
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
for i, unique_id in enumerate(unique_ids):
|
| 93 |
# Plot historical data
|
| 94 |
original_data = original_df[original_df['unique_id'] == unique_id]
|
|
@@ -102,20 +173,61 @@ def create_future_forecast_plot(forecast_df, original_df):
|
|
| 102 |
forecast_start = forecast_data['ds'].min()
|
| 103 |
plt.axvline(x=forecast_start, color='gray', linestyle='--', alpha=0.5)
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
for j, col in enumerate(forecast_cols):
|
| 106 |
if col in forecast_data.columns:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
plt.plot(forecast_data['ds'], forecast_data[col],
|
| 108 |
color=colors[j % len(colors)],
|
| 109 |
linestyle='--',
|
| 110 |
linewidth=1.5,
|
| 111 |
-
label=f'{
|
| 112 |
|
| 113 |
plt.title('Future Forecast', fontsize=16)
|
| 114 |
plt.xlabel('Date', fontsize=12)
|
| 115 |
plt.ylabel('Value', fontsize=12)
|
| 116 |
plt.grid(True, alpha=0.3)
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
# Format date labels better
|
| 121 |
fig = plt.gcf()
|
|
|
|
| 45 |
except Exception as e:
|
| 46 |
return None, f"Error loading data: {str(e)}"
|
| 47 |
|
| 48 |
+
# Helper function to calculate date offset based on frequency and horizon
|
| 49 |
+
def calculate_date_offset(freq, horizon):
    """Return a ``pd.Timedelta`` spanning roughly ``horizon`` periods of ``freq``.

    Args:
        freq: Frequency code. Recognised codes are hourly (``'H'``/``'h'``),
            minutely (``'T'``/``'min'``), daily (``'D'``), business-daily
            (``'B'``), weekly (``'WS'``/``'W'``), monthly
            (``'MS'``/``'M'``/``'ME'``), quarterly (``'QS'``/``'Q'``/``'QE'``)
            and yearly (``'YS'``/``'Y'``/``'YE'``/``'AS'``/``'A'``). An anchor
            suffix such as ``'-MON'`` in ``'W-MON'`` is ignored.
        horizon: Number of periods to span.

    Returns:
        A ``pd.Timedelta``. Month, quarter and year lengths are approximated
        as 30, 90 and 365 days; business days are scaled by 1.4 to account
        for weekends. Unrecognised or non-string codes fall back to
        ``horizon`` days.
    """
    # Anything that is not a string cannot match a code; use the day fallback.
    if not isinstance(freq, str):
        return pd.Timedelta(days=horizon)

    # Drop an anchor suffix ("W-MON" -> "W", "QS-JAN" -> "QS") so anchored
    # aliases resolve to the same unit as their base code.
    base = freq.split('-', 1)[0]

    if base in ('H', 'h'):
        return pd.Timedelta(hours=horizon)
    if base in ('T', 'min'):
        return pd.Timedelta(minutes=horizon)
    if base == 'B':
        # For business days, use approximately 1.4x multiplier to account for weekends
        return pd.Timedelta(days=int(horizon * 1.4))
    if base in ('WS', 'W'):
        return pd.Timedelta(weeks=horizon)

    # Calendar periods have no fixed length, so approximate them in days.
    approx_days_per_period = {
        'MS': 30, 'M': 30, 'ME': 30,
        'QS': 90, 'Q': 90, 'QE': 90,
        'YS': 365, 'Y': 365, 'YE': 365, 'AS': 365, 'A': 365,
    }
    # 'D' and every unrecognised code fall back to one day per period.
    return pd.Timedelta(days=horizon * approx_days_per_period.get(base, 1))
|
| 69 |
+
|
| 70 |
+
# Function to generate and return a plot for validation results
|
| 71 |
+
def create_forecast_plot(forecast_df, original_df, title="Forecasting Results", horizon=None, freq='D'):
|
| 72 |
plt.figure(figsize=(12, 7))
|
| 73 |
unique_ids = forecast_df['unique_id'].unique()
|
| 74 |
+
forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff', 'y']]
|
| 75 |
|
| 76 |
colors = plt.cm.tab10.colors
|
| 77 |
+
|
| 78 |
+
# Track min and max dates for x-axis limits
|
| 79 |
min_cutoff = None
|
| 80 |
|
| 81 |
for i, unique_id in enumerate(unique_ids):
|
|
|
|
| 83 |
plt.plot(original_data['ds'], original_data['y'], 'k-', linewidth=2, label=f'{unique_id} (Actual)')
|
| 84 |
|
| 85 |
forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
|
| 86 |
+
|
| 87 |
+
# Find the earliest cutoff date if available
|
| 88 |
+
if 'cutoff' in forecast_data.columns:
|
| 89 |
+
cutoffs = pd.to_datetime(forecast_data['cutoff'].unique())
|
| 90 |
+
if len(cutoffs) > 0:
|
| 91 |
+
earliest_cutoff = cutoffs.min()
|
| 92 |
+
if min_cutoff is None or earliest_cutoff < min_cutoff:
|
| 93 |
+
min_cutoff = earliest_cutoff
|
| 94 |
+
|
| 95 |
+
# Add vertical line at each cutoff
|
| 96 |
+
for cutoff in cutoffs:
|
| 97 |
+
plt.axvline(x=cutoff, color='gray', linestyle='--', alpha=0.4)
|
| 98 |
+
|
| 99 |
+
# Plot main prediction lines
|
| 100 |
for j, col in enumerate(forecast_cols):
|
| 101 |
if col in forecast_data.columns:
|
| 102 |
+
# Clean up model name for legend
|
| 103 |
+
model_name = col.replace('_', ' ').title()
|
| 104 |
+
if model_name == 'Timegpt':
|
| 105 |
+
model_name = 'TimeGPT'
|
| 106 |
+
|
| 107 |
plt.plot(forecast_data['ds'], forecast_data[col],
|
| 108 |
color=colors[j % len(colors)],
|
| 109 |
linestyle='--',
|
| 110 |
linewidth=1.5,
|
| 111 |
+
label=f'{model_name}')
|
| 112 |
|
| 113 |
plt.title(title, fontsize=16)
|
| 114 |
plt.xlabel('Date', fontsize=12)
|
| 115 |
plt.ylabel('Value', fontsize=12)
|
| 116 |
plt.grid(True, alpha=0.3)
|
| 117 |
+
|
| 118 |
+
# Better legend with smaller font and outside placement
|
| 119 |
+
plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=3, fontsize=10)
|
| 120 |
+
plt.tight_layout(rect=[0, 0.05, 1, 0.95]) # Adjust layout to make room for legend
|
| 121 |
+
|
| 122 |
+
# Set x-axis limits based on cutoff and horizon
|
| 123 |
+
if min_cutoff is not None and horizon is not None:
|
| 124 |
+
# Calculate date offset based on frequency and horizon
|
| 125 |
+
date_offset = calculate_date_offset(freq, horizon)
|
| 126 |
+
|
| 127 |
+
# Calculate start date as 'horizon' units before the first cutoff
|
| 128 |
+
start_date = min_cutoff - date_offset
|
| 129 |
+
|
| 130 |
+
# Find max date from forecast
|
| 131 |
+
max_date = forecast_df['ds'].max()
|
| 132 |
+
|
| 133 |
+
plt.xlim(start_date, max_date)
|
| 134 |
+
|
| 135 |
+
# Add an annotation for the training/test split
|
| 136 |
+
plt.annotate('Training | Test',
|
| 137 |
+
xy=(min_cutoff, plt.ylim()[0]),
|
| 138 |
+
xytext=(0, -40),
|
| 139 |
+
textcoords='offset points',
|
| 140 |
+
horizontalalignment='center',
|
| 141 |
+
fontsize=10)
|
| 142 |
|
| 143 |
# Format date labels better
|
| 144 |
fig = plt.gcf()
|
|
|
|
| 148 |
return fig
|
| 149 |
|
| 150 |
# Function to create a plot for future forecasts
|
| 151 |
+
def create_future_forecast_plot(forecast_df, original_df, horizon=None, freq='D'):
|
| 152 |
plt.figure(figsize=(12, 7))
|
| 153 |
unique_ids = forecast_df['unique_id'].unique()
|
| 154 |
forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds']]
|
| 155 |
|
| 156 |
colors = plt.cm.tab10.colors
|
| 157 |
|
| 158 |
+
# Track the forecast start date (min of forecast data)
|
| 159 |
+
forecast_start = None
|
| 160 |
+
if not forecast_df.empty:
|
| 161 |
+
forecast_start = pd.to_datetime(forecast_df['ds'].min())
|
| 162 |
+
|
| 163 |
for i, unique_id in enumerate(unique_ids):
|
| 164 |
# Plot historical data
|
| 165 |
original_data = original_df[original_df['unique_id'] == unique_id]
|
|
|
|
| 173 |
forecast_start = forecast_data['ds'].min()
|
| 174 |
plt.axvline(x=forecast_start, color='gray', linestyle='--', alpha=0.5)
|
| 175 |
|
| 176 |
+
# Add a shaded area for the forecast period
|
| 177 |
+
plt.axvspan(forecast_start, forecast_data['ds'].max(), alpha=0.1, color='blue')
|
| 178 |
+
|
| 179 |
+
# Annotate the split point
|
| 180 |
+
plt.annotate('Historical | Forecast',
|
| 181 |
+
xy=(forecast_start, plt.ylim()[0]),
|
| 182 |
+
xytext=(0, -40),
|
| 183 |
+
textcoords='offset points',
|
| 184 |
+
horizontalalignment='center',
|
| 185 |
+
fontsize=10)
|
| 186 |
+
|
| 187 |
+
# Plot main prediction lines
|
| 188 |
for j, col in enumerate(forecast_cols):
|
| 189 |
if col in forecast_data.columns:
|
| 190 |
+
# Clean up model name for legend
|
| 191 |
+
model_name = col.replace('_', ' ').title()
|
| 192 |
+
if model_name == 'Timegpt':
|
| 193 |
+
model_name = 'TimeGPT'
|
| 194 |
+
|
| 195 |
plt.plot(forecast_data['ds'], forecast_data[col],
|
| 196 |
color=colors[j % len(colors)],
|
| 197 |
linestyle='--',
|
| 198 |
linewidth=1.5,
|
| 199 |
+
label=f'{model_name}')
|
| 200 |
|
| 201 |
plt.title('Future Forecast', fontsize=16)
|
| 202 |
plt.xlabel('Date', fontsize=12)
|
| 203 |
plt.ylabel('Value', fontsize=12)
|
| 204 |
plt.grid(True, alpha=0.3)
|
| 205 |
+
|
| 206 |
+
# Better legend with smaller font and outside placement
|
| 207 |
+
plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=3, fontsize=10)
|
| 208 |
+
plt.tight_layout(rect=[0, 0.05, 1, 0.95]) # Adjust layout to make room for legend
|
| 209 |
+
|
| 210 |
+
# Set x-axis limits based on forecast start and horizon
|
| 211 |
+
if forecast_start is not None and horizon is not None:
|
| 212 |
+
# Calculate date offset based on frequency and horizon
|
| 213 |
+
date_offset = calculate_date_offset(freq, horizon)
|
| 214 |
+
|
| 215 |
+
# Calculate start date as 'horizon' units before the forecast start
|
| 216 |
+
start_date = forecast_start - date_offset
|
| 217 |
+
|
| 218 |
+
# Get the last date from historical data that's before or at the start_date
|
| 219 |
+
historical_dates = pd.to_datetime(original_df['ds'])
|
| 220 |
+
historical_dates_before_start = historical_dates[historical_dates <= start_date]
|
| 221 |
+
|
| 222 |
+
if not historical_dates_before_start.empty:
|
| 223 |
+
# Use the last available date in the historical data that's before our calculated start_date
|
| 224 |
+
adjusted_start_date = historical_dates_before_start.max()
|
| 225 |
+
else:
|
| 226 |
+
# Fallback to using the original start_date
|
| 227 |
+
adjusted_start_date = start_date
|
| 228 |
+
|
| 229 |
+
# Set the x-axis limits
|
| 230 |
+
plt.xlim(adjusted_start_date, forecast_df['ds'].max())
|
| 231 |
|
| 232 |
# Format date labels better
|
| 233 |
fig = plt.gcf()
|