fmegahed committed on
Commit
2be7c39
·
verified ·
1 Parent(s): d8f172b

Updated plotting function

Browse files
Files changed (1) hide show
  1. app.py +122 -10
app.py CHANGED
@@ -45,13 +45,37 @@ def load_data(file):
45
  except Exception as e:
46
  return None, f"Error loading data: {str(e)}"
47
 
48
- # Function to generate and return a plot
49
- def create_forecast_plot(forecast_df, original_df, title="Forecasting Results"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  plt.figure(figsize=(12, 7))
51
  unique_ids = forecast_df['unique_id'].unique()
52
- forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff']]
53
 
54
  colors = plt.cm.tab10.colors
 
 
55
  min_cutoff = None
56
 
57
  for i, unique_id in enumerate(unique_ids):
@@ -59,20 +83,62 @@ def create_forecast_plot(forecast_df, original_df, title="Forecasting Results"):
59
  plt.plot(original_data['ds'], original_data['y'], 'k-', linewidth=2, label=f'{unique_id} (Actual)')
60
 
61
  forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  for j, col in enumerate(forecast_cols):
63
  if col in forecast_data.columns:
 
 
 
 
 
64
  plt.plot(forecast_data['ds'], forecast_data[col],
65
  color=colors[j % len(colors)],
66
  linestyle='--',
67
  linewidth=1.5,
68
- label=f'{col}')
69
 
70
  plt.title(title, fontsize=16)
71
  plt.xlabel('Date', fontsize=12)
72
  plt.ylabel('Value', fontsize=12)
73
  plt.grid(True, alpha=0.3)
74
- plt.legend(loc='best')
75
- plt.tight_layout()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  # Format date labels better
78
  fig = plt.gcf()
@@ -82,13 +148,18 @@ def create_forecast_plot(forecast_df, original_df, title="Forecasting Results"):
82
  return fig
83
 
84
  # Function to create a plot for future forecasts
85
- def create_future_forecast_plot(forecast_df, original_df):
86
  plt.figure(figsize=(12, 7))
87
  unique_ids = forecast_df['unique_id'].unique()
88
  forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds']]
89
 
90
  colors = plt.cm.tab10.colors
91
 
 
 
 
 
 
92
  for i, unique_id in enumerate(unique_ids):
93
  # Plot historical data
94
  original_data = original_df[original_df['unique_id'] == unique_id]
@@ -102,20 +173,61 @@ def create_future_forecast_plot(forecast_df, original_df):
102
  forecast_start = forecast_data['ds'].min()
103
  plt.axvline(x=forecast_start, color='gray', linestyle='--', alpha=0.5)
104
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  for j, col in enumerate(forecast_cols):
106
  if col in forecast_data.columns:
 
 
 
 
 
107
  plt.plot(forecast_data['ds'], forecast_data[col],
108
  color=colors[j % len(colors)],
109
  linestyle='--',
110
  linewidth=1.5,
111
- label=f'{col}')
112
 
113
  plt.title('Future Forecast', fontsize=16)
114
  plt.xlabel('Date', fontsize=12)
115
  plt.ylabel('Value', fontsize=12)
116
  plt.grid(True, alpha=0.3)
117
- plt.legend(loc='best')
118
- plt.tight_layout()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  # Format date labels better
121
  fig = plt.gcf()
 
45
  except Exception as e:
46
  return None, f"Error loading data: {str(e)}"
47
 
48
# Helper function to calculate date offset based on frequency and horizon
def calculate_date_offset(freq, horizon):
    """Return an approximate ``pd.Timedelta`` spanning ``horizon`` periods of ``freq``.

    The result is used to size a plotting window, so calendar-based
    frequencies are approximated with fixed day counts rather than exact
    calendar arithmetic.

    Args:
        freq: Frequency code. The app codes ``'H'``, ``'D'``, ``'B'``,
            ``'WS'``, ``'MS'``, ``'QS'`` and ``'YS'`` are supported, along
            with the equivalent pandas aliases (``'h'``, ``'W'``,
            ``'W-MON'``-style anchored weeks, ``'M'``/``'ME'``,
            ``'Q'``/``'QE'``, ``'Y'``/``'YE'``/``'A'``/``'AS'``).
            Matching is case-sensitive on purpose: ``'MS'`` (month start)
            and ``'ms'`` (milliseconds) are different pandas aliases.
        horizon: Number of periods to span.

    Returns:
        A ``pd.Timedelta``. Any unrecognised ``freq`` (including ``None``)
        falls back to ``horizon`` days.
    """
    if freq in ('H', 'h'):
        return pd.Timedelta(hours=horizon)
    if freq == 'B':
        # For business days, use approximately 1.4x multiplier (7/5) to
        # account for weekends; int() truncates toward zero.
        return pd.Timedelta(days=int(horizon * 1.4))
    if freq in ('W', 'WS') or (isinstance(freq, str) and freq.startswith('W-')):
        return pd.Timedelta(weeks=horizon)
    if freq in ('MS', 'M', 'ME'):
        return pd.Timedelta(days=horizon * 30)  # Approximate month length
    if freq in ('QS', 'Q', 'QE'):
        return pd.Timedelta(days=horizon * 90)  # Approximate quarter length
    if freq in ('YS', 'Y', 'YE', 'AS', 'A'):
        return pd.Timedelta(days=horizon * 365)  # Approximate; ignores leap years
    # Default fallback (also covers 'D'): treat one period as one day
    return pd.Timedelta(days=horizon)
69
+
70
+ # Function to generate and return a plot for validation results
71
+ def create_forecast_plot(forecast_df, original_df, title="Forecasting Results", horizon=None, freq='D'):
72
  plt.figure(figsize=(12, 7))
73
  unique_ids = forecast_df['unique_id'].unique()
74
+ forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff', 'y']]
75
 
76
  colors = plt.cm.tab10.colors
77
+
78
+ # Track min and max dates for x-axis limits
79
  min_cutoff = None
80
 
81
  for i, unique_id in enumerate(unique_ids):
 
83
  plt.plot(original_data['ds'], original_data['y'], 'k-', linewidth=2, label=f'{unique_id} (Actual)')
84
 
85
  forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
86
+
87
+ # Find the earliest cutoff date if available
88
+ if 'cutoff' in forecast_data.columns:
89
+ cutoffs = pd.to_datetime(forecast_data['cutoff'].unique())
90
+ if len(cutoffs) > 0:
91
+ earliest_cutoff = cutoffs.min()
92
+ if min_cutoff is None or earliest_cutoff < min_cutoff:
93
+ min_cutoff = earliest_cutoff
94
+
95
+ # Add vertical line at each cutoff
96
+ for cutoff in cutoffs:
97
+ plt.axvline(x=cutoff, color='gray', linestyle='--', alpha=0.4)
98
+
99
+ # Plot main prediction lines
100
  for j, col in enumerate(forecast_cols):
101
  if col in forecast_data.columns:
102
+ # Clean up model name for legend
103
+ model_name = col.replace('_', ' ').title()
104
+ if model_name == 'Timegpt':
105
+ model_name = 'TimeGPT'
106
+
107
  plt.plot(forecast_data['ds'], forecast_data[col],
108
  color=colors[j % len(colors)],
109
  linestyle='--',
110
  linewidth=1.5,
111
+ label=f'{model_name}')
112
 
113
  plt.title(title, fontsize=16)
114
  plt.xlabel('Date', fontsize=12)
115
  plt.ylabel('Value', fontsize=12)
116
  plt.grid(True, alpha=0.3)
117
+
118
+ # Better legend with smaller font and outside placement
119
+ plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=3, fontsize=10)
120
+ plt.tight_layout(rect=[0, 0.05, 1, 0.95]) # Adjust layout to make room for legend
121
+
122
+ # Set x-axis limits based on cutoff and horizon
123
+ if min_cutoff is not None and horizon is not None:
124
+ # Calculate date offset based on frequency and horizon
125
+ date_offset = calculate_date_offset(freq, horizon)
126
+
127
+ # Calculate start date as 'horizon' units before the first cutoff
128
+ start_date = min_cutoff - date_offset
129
+
130
+ # Find max date from forecast
131
+ max_date = forecast_df['ds'].max()
132
+
133
+ plt.xlim(start_date, max_date)
134
+
135
+ # Add an annotation for the training/test split
136
+ plt.annotate('Training | Test',
137
+ xy=(min_cutoff, plt.ylim()[0]),
138
+ xytext=(0, -40),
139
+ textcoords='offset points',
140
+ horizontalalignment='center',
141
+ fontsize=10)
142
 
143
  # Format date labels better
144
  fig = plt.gcf()
 
148
  return fig
149
 
150
  # Function to create a plot for future forecasts
151
+ def create_future_forecast_plot(forecast_df, original_df, horizon=None, freq='D'):
152
  plt.figure(figsize=(12, 7))
153
  unique_ids = forecast_df['unique_id'].unique()
154
  forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds']]
155
 
156
  colors = plt.cm.tab10.colors
157
 
158
+ # Track the forecast start date (min of forecast data)
159
+ forecast_start = None
160
+ if not forecast_df.empty:
161
+ forecast_start = pd.to_datetime(forecast_df['ds'].min())
162
+
163
  for i, unique_id in enumerate(unique_ids):
164
  # Plot historical data
165
  original_data = original_df[original_df['unique_id'] == unique_id]
 
173
  forecast_start = forecast_data['ds'].min()
174
  plt.axvline(x=forecast_start, color='gray', linestyle='--', alpha=0.5)
175
 
176
+ # Add a shaded area for the forecast period
177
+ plt.axvspan(forecast_start, forecast_data['ds'].max(), alpha=0.1, color='blue')
178
+
179
+ # Annotate the split point
180
+ plt.annotate('Historical | Forecast',
181
+ xy=(forecast_start, plt.ylim()[0]),
182
+ xytext=(0, -40),
183
+ textcoords='offset points',
184
+ horizontalalignment='center',
185
+ fontsize=10)
186
+
187
+ # Plot main prediction lines
188
  for j, col in enumerate(forecast_cols):
189
  if col in forecast_data.columns:
190
+ # Clean up model name for legend
191
+ model_name = col.replace('_', ' ').title()
192
+ if model_name == 'Timegpt':
193
+ model_name = 'TimeGPT'
194
+
195
  plt.plot(forecast_data['ds'], forecast_data[col],
196
  color=colors[j % len(colors)],
197
  linestyle='--',
198
  linewidth=1.5,
199
+ label=f'{model_name}')
200
 
201
  plt.title('Future Forecast', fontsize=16)
202
  plt.xlabel('Date', fontsize=12)
203
  plt.ylabel('Value', fontsize=12)
204
  plt.grid(True, alpha=0.3)
205
+
206
+ # Better legend with smaller font and outside placement
207
+ plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=3, fontsize=10)
208
+ plt.tight_layout(rect=[0, 0.05, 1, 0.95]) # Adjust layout to make room for legend
209
+
210
+ # Set x-axis limits based on forecast start and horizon
211
+ if forecast_start is not None and horizon is not None:
212
+ # Calculate date offset based on frequency and horizon
213
+ date_offset = calculate_date_offset(freq, horizon)
214
+
215
+ # Calculate start date as 'horizon' units before the forecast start
216
+ start_date = forecast_start - date_offset
217
+
218
+ # Get the last date from historical data that's before or at the start_date
219
+ historical_dates = pd.to_datetime(original_df['ds'])
220
+ historical_dates_before_start = historical_dates[historical_dates <= start_date]
221
+
222
+ if not historical_dates_before_start.empty:
223
+ # Use the last available date in the historical data that's before our calculated start_date
224
+ adjusted_start_date = historical_dates_before_start.max()
225
+ else:
226
+ # Fallback to using the original start_date
227
+ adjusted_start_date = start_date
228
+
229
+ # Set the x-axis limits
230
+ plt.xlim(adjusted_start_date, forecast_df['ds'].max())
231
 
232
  # Format date labels better
233
  fig = plt.gcf()