# Spaces: Running
import io
import os
import warnings
from datetime import timedelta

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tirex import load_model
# Silence every warning category for the whole process.
# NOTE(review): a blanket filter also hides useful diagnostics (for example
# matplotlib's "too many open figures" warning); consider narrowing it.
warnings.simplefilter('ignore')

# The TiRex model is loaded a single time at import; run_forecast reuses it
# for every request instead of reloading the checkpoint.
model = load_model("NX-AI/TiRex")
def _reset_column_outputs(time_label="Select Time Column",
                          value_label="Select Value Column",
                          interactive=True):
    """Build the four column/range components shown when no usable CSV is loaded.

    Dropdown values are explicitly cleared (``value=None``) so a column picked
    from a previous upload cannot linger after the file is removed or fails to
    parse; both sliders collapse to the range 1..1.
    """
    return (
        gr.Dropdown(choices=[], value=None, label=time_label, interactive=interactive),
        gr.Dropdown(choices=[], value=None, label=value_label, interactive=interactive),
        gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical Start Index"),
        gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical End Index"),
    )


def load_columns(file):
    """Populate the column dropdowns and index sliders from an uploaded CSV.

    Args:
        file: Value of the ``gr.File`` input: a filesystem path, a path-like
            object, or (older Gradio versions) a tempfile wrapper exposing the
            path as ``.name``. ``None`` when nothing is uploaded.

    Returns:
        Tuple ``(time_dropdown, value_dropdown, start_slider, end_slider)``.
        The time dropdown lists every column, the value dropdown lists only
        numeric columns, and both sliders span ``1..number of data rows``.
        Read/parse errors are reported in the (disabled) dropdown labels
        instead of being raised.
    """
    if file is None:
        return _reset_column_outputs()

    try:
        path = file if isinstance(file, (str, os.PathLike)) else file.name
        df_preview = pd.read_csv(path)
    except Exception as e:
        message = f"Error loading CSV: {str(e)}"
        return _reset_column_outputs(message, message, interactive=False)

    all_cols = df_preview.columns.tolist()
    numeric_cols = df_preview.select_dtypes(include=['number']).columns.tolist()
    # A slider needs maximum >= minimum; a header-only CSV has zero data rows.
    n_rows = max(len(df_preview), 1)

    # Any column may hold timestamps, so every column is offered for time.
    time_dropdown = gr.Dropdown(
        choices=[(col, col) for col in all_cols],
        value=all_cols[0] if all_cols else None,
        label="Select Time Column",
        interactive=True,
    )
    if numeric_cols:
        value_dropdown = gr.Dropdown(
            choices=[(col, col) for col in numeric_cols],
            value=numeric_cols[0],
            label="Select Value Column",
            interactive=True,
        )
    else:
        value_dropdown = gr.Dropdown(
            choices=[],
            value=None,
            label="No numeric columns found",
            interactive=False,
        )

    start_slider = gr.Slider(
        minimum=1, maximum=n_rows, value=1, step=1,
        label="Historical Start Index",
    )
    end_slider = gr.Slider(
        minimum=1, maximum=n_rows, value=n_rows, step=1,
        label="Historical End Index",
    )
    return time_dropdown, value_dropdown, start_slider, end_slider
def update_ma_visibility(add_ma):
    """Show the moving-average window slider only while its checkbox is ticked.

    Args:
        add_ma: Current state of the "Add Moving Average" checkbox.

    Returns:
        A ``gr.Slider`` update carrying only the new visibility.
    """
    slider_update = gr.Slider(visible=add_ma)
    return slider_update
# Quantile levels of the nine columns returned by ``model.forecast``; index 4 is the median.
_TIREX_ALPHAS = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])


def _interval_bounds(q, alpha):
    """Interpolate the forecast value at quantile level ``alpha`` for every step.

    ``q`` holds one row per forecast step with columns matching ``_TIREX_ALPHAS``.
    ``np.interp`` clamps levels outside [0.1, 0.9] to the outermost modelled quantile.
    """
    return np.array([np.interp(alpha, _TIREX_ALPHAS, row) for row in q])


def _skew_stats(median, lower, upper, epsilon=1e-8):
    """Return per-step ``(skew_ratio, delta_skew, direction_labels)``.

    skew_ratio = (upside - downside) / (upside + downside), measured from the median
    to the interval bounds; it lies in [-1, 1] when lower <= median <= upper.
    delta_skew is the change from the previous step (0 for the first step).
    Labels use a +/-0.1 dead band ("Upside" / "Downside" / "Neutral").
    """
    upside = upper - median
    downside = median - lower
    ratios = (upside - downside) / (upside + downside + epsilon)
    deltas = np.diff(ratios, prepend=ratios[:1])
    directions = np.where(ratios > 0.1, "Upside",
                          np.where(ratios < -0.1, "Downside", "Neutral"))
    return ratios, deltas, directions.tolist()


def _future_dates(last_date, periods, freq):
    """Return ``periods`` timestamps that follow ``last_date`` at frequency ``freq``."""
    # Generate one extra stamp starting at last_date and drop it: this steps by the real
    # frequency, whereas adding a fixed one-day offset skipped ahead for sub-daily data.
    return pd.date_range(start=last_date, periods=periods + 1, freq=freq)[1:]


def _summary_markdown(pred_df, n_context, n_held_out, confidence, lower_alpha, upper_alpha, context_df):
    """Build the markdown report: summary bullets, forecast table and a data sample.

    The forecast table and the data sample are wrapped in collapsible ``<details>``.
    """
    ratios = pred_df['skew_ratio'].to_numpy()
    deltas = pred_df['delta_skew'].to_numpy()
    upside_count = int((ratios > 0.1).sum())
    downside_count = int((ratios < -0.1).sum())
    neutral_count = len(ratios) - upside_count - downside_count

    summary = [
        "### Summary",
        f"- **Number of Historical Periods Used:** {n_context} points",
        "- **Held Out Periods:** {} points {}".format(
            n_held_out, "(Full Context Used)" if n_held_out == 0 else "(For Validation)"),
        f"- **Prediction Length:** {len(pred_df)} periods",
        f"- **Prediction Interval:** {confidence}% (alphas: {lower_alpha:.3f} - {upper_alpha:.3f})",
    ]
    if lower_alpha < _TIREX_ALPHAS[0] - 1e-9:
        # np.interp clamps outside the modelled levels, so wider requests silently collapse
        # to the 10%-90% band; report that instead of implying a wider interval.
        summary.append(
            "- **Note:** TiRex only outputs the 10%-90% quantiles, so bounds beyond them "
            "are clamped (the shaded band is at most an 80% interval)."
        )
    summary += [
        f"- **Sum of Median Predicted Values:** {pred_df['predicted_sales_median'].sum():.2f}",
        f"- **Sum of Mean Predicted Values:** {pred_df['predicted_sales_mean'].sum():.2f}",
        f"- **Skew Distribution:** {upside_count} Upside, {downside_count} Downside, {neutral_count} Neutral",
        f"- **Average Skew Ratio:** {ratios.mean():.3f} (momentum: max |Δ| = {np.abs(deltas).max():.3f})",
    ]

    index = pred_df.index
    # Sub-daily forecasts need the time of day, otherwise rows show identical dates.
    date_fmt = '%Y-%m-%d %H:%M' if (index != index.normalize()).any() else '%Y-%m-%d'
    table = [
        f"### TiRex Forecast Results (Median + {confidence}% Prediction Interval)",
        "",
        "| Date | Median | Lower Bound | Upper Bound | Mean | Skew Direction | Skew Ratio | Δ Skew |",
        "|------|--------|-------------|-------------|------|----------------|------------|--------|",
    ]
    table += [
        f"| {row.Index.strftime(date_fmt)} | {row.predicted_sales_median:.2f} | "
        f"{row.predicted_sales_lower:.2f} | {row.predicted_sales_upper:.2f} | "
        f"{row.predicted_sales_mean:.2f} | {row.skew_direction} | "
        f"{row.skew_ratio:.3f} | {row.delta_skew:.3f} |"
        for row in pred_df.itertuples()
    ]
    forecast_table = "\n".join(table) + "\n"

    sample_data = "### Sample Historical Data (Context)\n"
    sample_data += "```\n" + context_df.head().to_string() + "\n```"

    return (
        "\n".join(summary) + "\n\n"
        + f'\n<details><summary>Click to expand Forecast Table</summary>\n\n{forecast_table}\n</details>\n\n'
        + f'<details><summary>Click to expand Sample Historical Data</summary>\n\n{sample_data}\n</details>'
    )


def _draw_forecast(ax, context_df, held_out_df, pred_df, lower50, upper50, confidence,
                   add_trendline, add_moving_average, ma_window, add_skew_viz, title, ylabel):
    """Draw history, optional overlays, the median forecast and its fan chart on ``ax``."""
    ax.plot(context_df.index, context_df['sales'], label='Historical Data',
            color='#1f77b4', linewidth=1.5, alpha=0.8)
    if not held_out_df.empty:
        ax.plot(held_out_df.index, held_out_df['sales'], label='Held Out Actual (Validation)',
                color='#2ca02c', linestyle=':', linewidth=2)

    if add_trendline:
        y = context_df['sales'].to_numpy(dtype=float)
        x = np.arange(len(y))
        observed = ~np.isnan(y)
        # Fit only observed points; NaN makes polyfit fail or return NaN coefficients.
        if observed.sum() > 1:
            coeffs = np.polyfit(x[observed], y[observed], 1)
            ax.plot(context_df.index, np.polyval(coeffs, x), label='Trendline',
                    color='black', linestyle='-', linewidth=1.5)

    if add_moving_average:
        window = int(ma_window)
        ma = context_df['sales'].rolling(window=window, min_periods=1).mean()
        ax.plot(context_df.index, ma, label=f'Moving Average ({window} periods)',
                color='purple', linewidth=2)

    dates = pred_df.index
    ax.plot(dates, pred_df['predicted_sales_median'].to_numpy(), label='Median Forecast',
            color='green', linewidth=2, alpha=0.9)
    # Fan chart with non-overlapping bands: light inner 50% band, darker wings out to the
    # user-selected interval.
    ax.fill_between(dates, lower50, upper50,
                    color='#d62728', alpha=0.1, label='50% Prediction Interval')
    ax.fill_between(dates, pred_df['predicted_sales_lower'].to_numpy(), lower50,
                    color='#d62728', alpha=0.3)
    ax.fill_between(dates, upper50, pred_df['predicted_sales_upper'].to_numpy(),
                    color='#d62728', alpha=0.3, label=f'{confidence}% Prediction Interval')

    ax2 = None
    if add_skew_viz:
        # Skew lines live on a twin axis because they are unitless ratios.
        ax2 = ax.twinx()
        ax2.plot(dates, pred_df['skew_ratio'].to_numpy(), label='Skew Ratio',
                 color='lightblue', linewidth=1, alpha=0.6)
        ax2.plot(dates, pred_df['delta_skew'].to_numpy(), label='Skew Momentum',
                 color='lightgray', linewidth=1, alpha=0.6)
        ax2.set_ylabel('Skew (-1 to 1)', color='lightblue')
        ax2.tick_params(colors='lightblue')
        ax2.set_ylim(-1.2, 1.2)

    ax.set_title(title, fontsize=16, fontweight='bold')
    ax.set_xlabel('Date', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    if ax2 is not None:
        # One combined legend so the twin axis entries don't overlap the main legend.
        handles1, labels1 = ax.get_legend_handles_labels()
        handles2, labels2 = ax2.get_legend_handles_labels()
        ax.legend(handles1 + handles2, labels1 + labels2, fontsize=10, loc='upper left')
    else:
        ax.legend(fontsize=10)
    ax.tick_params(axis='x', rotation=45)


def run_forecast(file, time_col, selected_col, start_idx, end_idx, prediction_length, confidence, add_trendline, add_moving_average, ma_window, add_skew_viz):
    """Forecast one CSV column with TiRex and render a fan chart plus a markdown report.

    Args:
        file: ``gr.File`` value: a path, a path-like object, or a tempfile wrapper
            exposing the path as ``.name``.
        time_col: Column parsed with ``pd.to_datetime`` and used as the time index.
        selected_col: Numeric column to forecast.
        start_idx, end_idx: 1-based inclusive row range (after sorting by time) used
            as model context; rows after ``end_idx`` are plotted as held-out data.
        prediction_length: Number of future steps to forecast.
        confidence: Width of the outer prediction interval in percent (e.g. 80).
        add_trendline, add_moving_average, add_skew_viz: Plot overlay toggles.
        ma_window: Moving-average window in periods.

    Returns:
        ``(figure, markdown)``. On validation failure or any exception the figure is
        ``None`` and the markdown starts with ``"### Error"``.
    """
    if file is None or time_col is None or selected_col is None:
        return None, "### Error\nPlease upload a CSV and select time and value columns!"
    if time_col == selected_col:
        return None, "### Error\nTime column and value column must be different."

    fig = None
    try:
        path = file if isinstance(file, (str, os.PathLike)) else file.name
        raw = pd.read_csv(path)
        if time_col not in raw.columns or selected_col not in raw.columns:
            return None, f"### Error\nSelected columns '{time_col}' or '{selected_col}' not found in CSV."

        # Copy just the two selected columns into a fresh frame: renaming in place
        # created duplicate 'date'/'sales' columns whenever the CSV already had them.
        df = pd.DataFrame({
            'date': pd.to_datetime(raw[time_col]),
            'sales': raw[selected_col],
        }).set_index('date').sort_index()

        full_len = len(df)
        context_start = max(0, int(start_idx) - 1)
        context_end = min(full_len, int(end_idx))
        context_df = df.iloc[context_start:context_end]
        held_out_df = df.iloc[context_end:]  # empty when the whole series is context

        # Check the size after dropping NaN so the minimum applies to the points the
        # model actually receives.
        context_series = context_df['sales'].dropna().to_numpy()
        if len(context_series) < 10:
            return None, "### Error\nNeed at least 10 data points in the selected historical range."
        print(f"Loaded context: {len(context_series)} points from {context_df.index.min().date()} to {context_df.index.max().date()} (Column: {selected_col})")  # For logs

        freq = pd.infer_freq(context_df.index) or 'D'
        print(f"Frequency: '{freq}'.")

        pred_len = int(prediction_length)
        # Pass at most the 2048 most recent points (the context length the UI recommends).
        context_len = min(len(context_series), 2048)
        context = torch.tensor(context_series[-context_len:], dtype=torch.float32).unsqueeze(0)

        conf_level = confidence / 100.0
        lower_alpha = (1 - conf_level) / 2
        upper_alpha = 1 - lower_alpha

        quantiles, mean = model.forecast(context=context, prediction_length=pred_len)
        # detach().cpu() first: .numpy() fails on tensors that require grad or live on GPU.
        q = quantiles[0].detach().cpu().numpy()  # (pred_len, 9), columns match _TIREX_ALPHAS
        mean_forecast = mean[0].detach().cpu().numpy()
        median = q[:, 4]

        lower50 = _interval_bounds(q, 0.25)
        upper50 = _interval_bounds(q, 0.75)
        lower_slider = _interval_bounds(q, lower_alpha)
        upper_slider = _interval_bounds(q, upper_alpha)
        skew_ratios, delta_skews, skew_directions = _skew_stats(median, lower_slider, upper_slider)

        pred_df = pd.DataFrame({
            'date': _future_dates(context_df.index[-1], pred_len, freq),
            'predicted_sales_median': median,
            'predicted_sales_lower': lower_slider,
            'predicted_sales_upper': upper_slider,
            'predicted_sales_mean': mean_forecast,
            'skew_direction': skew_directions,
            'skew_ratio': skew_ratios,
            'delta_skew': delta_skews,
        }).set_index('date')

        markdown_text = _summary_markdown(
            pred_df, len(context_series), len(held_out_df),
            confidence, lower_alpha, upper_alpha, context_df,
        )

        fig, ax = plt.subplots(figsize=(14, 7))
        fig.set_dpi(300)  # High resolution for PNG zoom
        title = f'{selected_col} Forecast with TiRex (Context: {context_start+1}-{context_end}, Horizon: {pred_len})'
        _draw_forecast(ax, context_df, held_out_df, pred_df, lower50, upper50, confidence,
                       add_trendline, add_moving_average, ma_window, add_skew_viz,
                       title, selected_col)
        fig.tight_layout()
        # pyplot keeps every open figure alive; deregister it so repeated requests don't
        # leak memory. The Figure object itself can still be saved/rendered afterwards.
        plt.close(fig)
        return fig, markdown_text
    except Exception as e:
        if fig is not None:
            plt.close(fig)
        return None, f"### Error\n{str(e)}\n\nTips: Ensure the time column can be parsed as dates; check NaNs/zeros; ensure data is valid."
# Create the Gradio interface.
# Layout: a control column (upload, column/range pickers, forecast options) on the left,
# the plot and markdown report on the right; the event listeners at the bottom connect
# the controls to load_columns, update_ma_visibility and run_forecast.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="red"), title="TiRex Forecaster") as demo:
    # Load the Inter web font (Google Fonts CSS2 API: family=<name>:wght@<weights>)
    # and force it on every element inside the Gradio container.
    gr.HTML("""
    <link rel="preconnect" href="https://fonts.googleapis.com">
    <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
    <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
    <style>
    :root {
        --font-family: Inter, ui-sans-serif, system-ui, sans-serif;
    }
    .gradio-container * {
        font-family: var(--font-family) !important;
    }
    </style>
    """)
    gr.Markdown("""
# TiRex Forecaster Dashboard
Upload a CSV file with a time column and numeric columns. Select the time column and one numeric column to forecast future values using the TiRex model.
""")
    with gr.Row(variant="panel"):
        # Left: inputs and forecast options.
        with gr.Column(scale=1):
            csv_file = gr.File(
                file_types=[".csv"],
                label="Upload CSV File",
                elem_id="file_upload"
            )
            gr.Markdown("The minimum effective input is around 128 time steps per series. Use a full context of 2048 steps for optimal performance.")
            # Choices and slider ranges start empty; load_columns fills them on upload.
            time_dropdown = gr.Dropdown(
                choices=[],
                label="Select Time Column",
                interactive=True,
                elem_id="time_select"
            )
            column_dropdown = gr.Dropdown(
                choices=[],
                label="Select Value Column",
                interactive=True,
                elem_id="column_select"
            )
            start_slider = gr.Slider(
                minimum=1, maximum=1, value=1, step=1,
                label="Historical Start Index",
                elem_id="start_idx"
            )
            end_slider = gr.Slider(
                minimum=1, maximum=1, value=1, step=1,
                label="Historical End Index",
                elem_id="end_idx"
            )
            prediction_length = gr.Slider(
                minimum=1, maximum=720, value=100, step=1,
                label="Prediction Length",
                elem_id="pred_length"
            )
            confidence = gr.Slider(
                minimum=50, maximum=95, value=80, step=5,
                label="Prediction Interval (%)",
                elem_id="confidence"
            )
            trend_checkbox = gr.Checkbox(
                label="Add Trendline",
                value=False
            )
            ma_checkbox = gr.Checkbox(
                label="Add Moving Average",
                value=False
            )
            # Hidden until "Add Moving Average" is ticked (see ma_checkbox.change below).
            ma_slider = gr.Slider(
                minimum=3, maximum=30, value=7, step=1,
                label="Moving Average Window (Periods)",
                elem_id="ma_window",
                visible=False
            )
            skew_checkbox = gr.Checkbox(
                label="Add Skew Ratio & Momentum",
                value=False
            )
            run_button = gr.Button(
                "Run forecast",
                variant="primary",
                size="lg",
                elem_id="run_btn"
            )
        # Right: outputs.
        with gr.Column(scale=2):
            forecast_plot = gr.Plot(
                label="Forecast Visualization",
                elem_id="plot"
            )
            output_text = gr.Markdown(
                "### Welcome!\nUpload your CSV to get started.",
                elem_id="output"
            )
    gr.Markdown("**Built by** [next one gmbh](https://nextone.at/?utm_source=dashboard&utm_medium=referrer&utm_campaign=tirex)")

    # Event for updating dropdowns on file upload
    csv_file.change(
        load_columns,
        inputs=csv_file,
        outputs=[time_dropdown, column_dropdown, start_slider, end_slider]
    )
    # Event for updating MA slider visibility
    ma_checkbox.change(
        update_ma_visibility,
        inputs=[ma_checkbox],
        outputs=[ma_slider]
    )
    # Event for running forecast; input order must match run_forecast's parameters.
    run_button.click(
        run_forecast,
        inputs=[csv_file, time_dropdown, column_dropdown, start_slider, end_slider, prediction_length, confidence, trend_checkbox, ma_checkbox, ma_slider, skew_checkbox],
        outputs=[forecast_plot, output_text]
    )

if __name__ == "__main__":
    demo.launch()