"""Gradio dashboard that forecasts one numeric CSV column with the TiRex model."""
import io
import warnings
from datetime import timedelta

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tirex import load_model

warnings.filterwarnings('ignore')

# Load model (once) at import time; every request below reuses this instance.
model = load_model("NX-AI/TiRex")

# Quantile levels of the model output, one column each; column 4 (0.5) is the median.
# NOTE(review): assumes TiRex keeps its 9 default levels 0.1..0.9 — confirm on model upgrades.
QUANTILE_LEVELS = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
MEDIAN_INDEX = 4
# The inner fan-chart band is always the central 50% interval.
INNER_LOWER_ALPHA = 0.25
INNER_UPPER_ALPHA = 0.75
# Longest context passed to the model; older points are dropped.
MAX_CONTEXT_LEN = 2048
MIN_CONTEXT_POINTS = 10
# A skew ratio above +threshold is "Upside", below -threshold "Downside", otherwise "Neutral".
SKEW_THRESHOLD = 0.1
SKEW_EPSILON = 1e-8


def _read_csv(file):
    """Read the uploaded CSV given a path string or a tempfile-like object with ``.name``."""
    # Gradio 4 passes a path string; older versions pass a tempfile wrapper.
    path = getattr(file, "name", file)
    with open(path, 'rb') as f:
        return pd.read_csv(io.BytesIO(f.read()))


def _index_sliders(n_rows):
    """Build start/end index sliders spanning rows 1..n_rows (never fewer than one row)."""
    # Guard against header-only CSVs, which would otherwise give maximum < minimum.
    n_rows = max(1, int(n_rows))
    return (
        gr.Slider(minimum=1, maximum=n_rows, value=1, step=1, label="Historical Start Index"),
        gr.Slider(minimum=1, maximum=n_rows, value=n_rows, step=1, label="Historical End Index"),
    )


def load_columns(file):
    """Populate the column pickers and index sliders from an uploaded CSV.

    Args:
        file: Value of the ``gr.File`` component. This is a path string, an object with a
            ``.name`` path attribute, or ``None`` when nothing is uploaded.

    Returns:
        Tuple ``(time_dropdown, value_dropdown, start_slider, end_slider)`` of Gradio
        component updates. If the CSV cannot be read, both dropdowns are disabled and
        their labels carry the error message.
    """
    if file is None:
        return (gr.Dropdown(choices=[], label="Select Time Column", interactive=True),
                gr.Dropdown(choices=[], label="Select Value Column", interactive=True),
                *_index_sliders(1))

    try:
        df_preview = _read_csv(file)
    except Exception as e:  # show any read/parse failure in the UI instead of crashing
        error_label = f"Error loading CSV: {str(e)}"
        return (gr.Dropdown(choices=[], value=None, label=error_label, interactive=False),
                gr.Dropdown(choices=[], value=None, label=error_label, interactive=False),
                *_index_sliders(1))

    # Any column may hold timestamps; only numeric columns can be forecast.
    all_cols = df_preview.columns.tolist()
    numeric_cols = df_preview.select_dtypes(include=['number']).columns.tolist()
    time_value = all_cols[0] if all_cols else None

    time_dropdown = gr.Dropdown(
        choices=[(col, col) for col in all_cols],
        value=time_value,
        label="Select Time Column",
        interactive=True,
    )

    if numeric_cols:
        # Prefer a value column that differs from the default time column (e.g. a numeric
        # "year" first column), so the default selection can be run as-is.
        value_value = next((col for col in numeric_cols if col != time_value), numeric_cols[0])
        value_dropdown = gr.Dropdown(
            choices=[(col, col) for col in numeric_cols],
            value=value_value,
            label="Select Value Column",
            interactive=True,
        )
    else:
        value_dropdown = gr.Dropdown(
            choices=[], value=None, label="No numeric columns found", interactive=False
        )

    return (time_dropdown, value_dropdown, *_index_sliders(len(df_preview)))


def update_ma_visibility(add_ma):
    """Show the moving-average window slider only when the moving average is enabled."""
    return gr.Slider(visible=add_ma)


def _interp_quantile(q, level):
    """Interpolate each row of ``q`` (one row per horizon step) at quantile ``level``.

    ``np.interp`` clamps levels outside ``QUANTILE_LEVELS`` to the outermost column.
    """
    return np.array([np.interp(level, QUANTILE_LEVELS, row) for row in q])


def _skew_stats(median, lower, upper):
    """Return ``(skew_ratios, delta_skews, directions)`` for interval asymmetry around the median."""
    upside = upper - median
    downside = median - lower
    # Ratio in roughly [-1, 1]: positive when the upper tail is wider than the lower one.
    skew_ratios = (upside - downside) / (upside + downside + SKEW_EPSILON)
    # Step-to-step change ("momentum"). The first step has no predecessor, so it is 0.
    delta_skews = np.diff(skew_ratios, prepend=skew_ratios[:1])
    directions = [
        "Upside" if r > SKEW_THRESHOLD else "Downside" if r < -SKEW_THRESHOLD else "Neutral"
        for r in skew_ratios
    ]
    return skew_ratios, delta_skews, directions


def _future_dates(last_date, freq, periods):
    """Return ``periods`` timestamps following ``last_date`` at frequency ``freq``."""
    # Generate periods + 1 points starting at last_date and drop the first one, so the step
    # is exactly one ``freq`` for every frequency. A fixed "+1 day" start made sub-daily
    # (e.g. hourly) forecasts begin a full day late.
    return pd.date_range(start=last_date, periods=periods + 1, freq=freq)[1:]


def _build_markdown(pred_df, context_df, n_context, n_held_out, confidence,
                    lower_alpha, upper_alpha):
    """Render the summary, the collapsible forecast table and a context sample as markdown."""
    skew = pred_df['skew_ratio'].to_numpy()
    upside_count = int((skew > SKEW_THRESHOLD).sum())
    downside_count = int((skew < -SKEW_THRESHOLD).sum())
    neutral_count = len(skew) - upside_count - downside_count
    avg_skew = skew.mean()
    max_momentum_shift = np.abs(pred_df['delta_skew'].to_numpy()).max()

    parts = [
        "### Summary\n",
        "- **Number of Historical Periods Used:** {} points\n".format(n_context),
        "- **Held Out Periods:** {} points {}\n".format(
            n_held_out, "(Full Context Used)" if n_held_out == 0 else "(For Validation)"),
        "- **Prediction Length:** {} periods\n".format(len(pred_df)),
        "- **Prediction Interval:** {}% (alphas: {:.3f} - {:.3f})\n".format(
            confidence, lower_alpha, upper_alpha),
    ]
    # The tolerance avoids flagging 80%, whose alpha is 0.0999... in floating point.
    if lower_alpha < QUANTILE_LEVELS[0] - 1e-9:
        parts.append(
            "- **Note:** the model's outermost quantiles are {:.0%} and {:.0%}, so wider "
            "intervals are clamped to that range.\n".format(QUANTILE_LEVELS[0], QUANTILE_LEVELS[-1])
        )
    parts += [
        "- **Sum of Median Predicted Values:** {:.2f}\n".format(pred_df['predicted_sales_median'].sum()),
        "- **Sum of Mean Predicted Values:** {:.2f}\n".format(pred_df['predicted_sales_mean'].sum()),
        "- **Skew Distribution:** {} Upside, {} Downside, {} Neutral\n".format(
            upside_count, downside_count, neutral_count),
        "- **Average Skew Ratio:** {:.3f} (momentum: max |Δ| = {:.3f})\n\n".format(
            avg_skew, max_momentum_shift),
    ]

    table_rows = [
        "### TiRex Forecast Results (Median + {}% Prediction Interval)\n\n".format(confidence),
        "| Date | Median | Lower Bound | Upper Bound | Mean | Skew Direction | Skew Ratio | Δ Skew |\n",
        "|------|--------|-------------|-------------|------|----------------|------------|--------|\n",
    ]
    for idx, row in pred_df.iterrows():
        table_rows.append("| {} | {:.2f} | {:.2f} | {:.2f} | {:.2f} | {} | {:.3f} | {:.3f} |\n".format(
            idx.strftime('%Y-%m-%d'),
            row['predicted_sales_median'],
            row['predicted_sales_lower'],
            row['predicted_sales_upper'],
            row['predicted_sales_mean'],
            row['skew_direction'],
            row['skew_ratio'],
            row['delta_skew'],
        ))
    forecast_table = "".join(table_rows)

    sample_data = "### Sample Historical Data (Context)\n"
    sample_data += "```\n" + context_df.head().to_string() + "\n```"

    parts.append(
        "\n<details><summary>Click to expand Forecast Table</summary>\n\n"
        f"{forecast_table}\n</details>\n\n"
    )
    parts.append(
        "<details><summary>Click to expand Sample Historical Data</summary>\n\n"
        f"{sample_data}\n</details>"
    )
    return "".join(parts)


def _build_figure(context_df, held_out_df, pred_df, bands, *, confidence, selected_col,
                  context_start, context_end, add_trendline, add_moving_average,
                  ma_window, add_skew_viz):
    """Draw history, optional overlays, the median forecast and the fan chart; return the Figure.

    ``bands`` is ``(lower50, upper50, lower_slider, upper_slider)``, one value per forecast step.
    """
    lower50, upper50, lower_slider, upper_slider = bands
    fig, ax = plt.subplots(figsize=(14, 7))
    try:
        fig.set_dpi(300)  # High resolution for PNG zoom

        # Historical and held-out
        ax.plot(context_df.index, context_df['sales'], label='Historical Data',
                color='#1f77b4', linewidth=1.5, alpha=0.8)
        if not held_out_df.empty:
            ax.plot(held_out_df.index, held_out_df['sales'], label='Held Out Actual (Validation)',
                    color='#2ca02c', linestyle=':', linewidth=2)

        if add_trendline:
            y = context_df['sales'].to_numpy(dtype=float)
            x = np.arange(len(y))
            # Fit only on observed points; NaNs would make np.polyfit fail.
            valid = ~np.isnan(y)
            if valid.sum() > 1:
                coeffs = np.polyfit(x[valid], y[valid], 1)
                ax.plot(context_df.index, np.polyval(coeffs, x), label='Trendline',
                        color='black', linestyle='-', linewidth=1.5)

        if add_moving_average:
            window = int(ma_window)
            ma = context_df['sales'].rolling(window=window, min_periods=1).mean()
            ax.plot(context_df.index, ma, label=f'Moving Average ({window} periods)',
                    color='purple', linewidth=2)

        # Median forecast: regular green line
        ax.plot(pred_df.index, pred_df['predicted_sales_median'], label='Median Forecast',
                color='green', linewidth=2, alpha=0.9)

        # Fan chart with non-overlapping bands: inner 50% (lightest), then wings out to the
        # user-selected level (darker).
        ax.fill_between(pred_df.index, lower50, upper50, color='#d62728', alpha=0.1,
                        label='50% Prediction Interval')
        ax.fill_between(pred_df.index, lower_slider, lower50, color='#d62728', alpha=0.3)
        ax.fill_between(pred_df.index, upper50, upper_slider, color='#d62728', alpha=0.3,
                        label=f'{confidence}% Prediction Interval')

        ax.set_title(f'{selected_col} Forecast with TiRex (Context: {context_start+1}-{context_end}, '
                     f'Horizon: {len(pred_df)})', fontsize=16, fontweight='bold')
        ax.set_xlabel('Date', fontsize=12)
        ax.set_ylabel(selected_col, fontsize=12)

        if add_skew_viz:
            # Optional skew visualization on a twin axis (light lines)
            ax2 = ax.twinx()
            ax2.plot(pred_df.index, pred_df['skew_ratio'], label='Skew Ratio',
                     color='lightblue', linewidth=1, alpha=0.6)
            ax2.plot(pred_df.index, pred_df['delta_skew'], label='Skew Momentum',
                     color='lightgray', linewidth=1, alpha=0.6)
            ax2.set_ylabel('Skew (-1 to 1)', color='lightblue')
            ax2.tick_params(colors='lightblue')
            ax2.set_ylim(-1.2, 1.2)
            # Combine both axes into one legend to avoid overlapping legends.
            handles1, labels1 = ax.get_legend_handles_labels()
            handles2, labels2 = ax2.get_legend_handles_labels()
            ax.legend(handles1 + handles2, labels1 + labels2, fontsize=10, loc='upper left')
        else:
            ax.legend(fontsize=10)

        ax.tick_params(axis='x', rotation=45)
        fig.tight_layout()
    except Exception:
        plt.close(fig)
        raise
    # Unregister from pyplot's global figure list. Gradio only needs the Figure object to
    # render it, and leaving it registered leaks one figure per request.
    plt.close(fig)
    return fig


def run_forecast(file, time_col, selected_col, start_idx, end_idx, prediction_length, confidence,
                 add_trendline, add_moving_average, ma_window, add_skew_viz):
    """Forecast the selected column with TiRex and build the plot plus a markdown report.

    Args:
        file: Uploaded CSV (a path string or an object with a ``.name`` path attribute).
        time_col: Column parsed with ``pd.to_datetime`` and used as the time index.
        selected_col: Numeric column to forecast.
        start_idx, end_idx: 1-based inclusive row range used as the model context. Rows
            after ``end_idx`` are plotted as held-out actuals.
        prediction_length: Number of future steps to forecast.
        confidence: Width of the outer prediction interval, in percent.
        add_trendline, add_moving_average, add_skew_viz: Optional plot overlays.
        ma_window: Moving-average window in periods.

    Returns:
        ``(figure, markdown)`` on success, otherwise ``(None, error_markdown)``.
    """
    if file is None or time_col is None or selected_col is None:
        return None, "### Error\nPlease upload a CSV and select time and value columns!"
    if time_col == selected_col:
        return None, "### Error\nTime and value columns must be different."

    try:
        raw = _read_csv(file)

        # Validate columns exist
        if time_col not in raw.columns or selected_col not in raw.columns:
            return None, f"### Error\nSelected columns '{time_col}' or '{selected_col}' not found in CSV."

        # Build a fresh two-column frame instead of renaming in place. Renaming could collide
        # with existing 'date'/'sales' columns and produce duplicate labels.
        df = pd.DataFrame({'date': pd.to_datetime(raw[time_col]), 'sales': raw[selected_col]})
        df = df.set_index('date').sort_index()

        full_len = len(df)
        context_start = max(0, int(start_idx) - 1)
        context_end = min(full_len, int(end_idx))
        context_df = df.iloc[context_start:context_end]
        held_out_df = df.iloc[context_end:]  # empty when the full series is used as context

        if len(context_df) < MIN_CONTEXT_POINTS:
            return None, "### Error\nNeed at least 10 data points in the selected historical range."

        context_series = context_df['sales'].dropna().to_numpy(dtype=float)
        # The check above counts rows. Missing values are dropped before forecasting, so
        # re-check to avoid handing the model an empty or tiny context.
        if len(context_series) < MIN_CONTEXT_POINTS:
            return None, "### Error\nNeed at least 10 non-missing values in the selected historical range."
        print(f"Loaded context: {len(context_series)} points from {context_df.index.min().date()} to {context_df.index.max().date()} (Column: {selected_col})")  # For logs

        # Infer freq; fall back to daily steps when the index is irregular.
        freq = pd.infer_freq(context_df.index) or 'D'
        print(f"Frequency: '{freq}'.")

        # Only the most recent MAX_CONTEXT_LEN points are passed to the model.
        context_len = min(len(context_series), MAX_CONTEXT_LEN)
        context = torch.tensor(context_series[-context_len:]).unsqueeze(0).float()
        pred_len = int(prediction_length)

        conf_level = confidence / 100.0
        lower_alpha_slider = (1 - conf_level) / 2
        upper_alpha_slider = 1 - lower_alpha_slider

        quantiles, mean = model.forecast(context=context, prediction_length=pred_len)
        q = quantiles[0].detach().cpu().numpy()  # one row per step, one column per quantile level
        median = q[:, MEDIAN_INDEX]
        mean_forecast = mean[0].detach().cpu().numpy()

        lower50 = _interp_quantile(q, INNER_LOWER_ALPHA)
        upper50 = _interp_quantile(q, INNER_UPPER_ALPHA)
        lower_slider = _interp_quantile(q, lower_alpha_slider)
        upper_slider = _interp_quantile(q, upper_alpha_slider)
        skew_ratios, delta_skews, skew_directions = _skew_stats(median, lower_slider, upper_slider)

        pred_df = pd.DataFrame({
            'date': _future_dates(context_df.index[-1], freq, pred_len),
            'predicted_sales_median': median,
            'predicted_sales_lower': lower_slider,
            'predicted_sales_upper': upper_slider,
            'predicted_sales_mean': mean_forecast,
            'skew_direction': skew_directions,
            'skew_ratio': skew_ratios,
            'delta_skew': delta_skews,
        }).set_index('date')

        markdown_text = _build_markdown(
            pred_df, context_df, context_len, len(held_out_df), confidence,
            lower_alpha_slider, upper_alpha_slider,
        )
        fig = _build_figure(
            context_df, held_out_df, pred_df, (lower50, upper50, lower_slider, upper_slider),
            confidence=confidence, selected_col=selected_col,
            context_start=context_start, context_end=context_end,
            add_trendline=add_trendline, add_moving_average=add_moving_average,
            ma_window=ma_window, add_skew_viz=add_skew_viz,
        )
        return fig, markdown_text
    except Exception as e:  # UI boundary: report any failure to the user
        return None, f"### Error\n{str(e)}\n\nTips: Ensure the time column can be parsed as dates; check NaNs/zeros; ensure data is valid."


# Create the Gradio interface
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="red"), title="TiRex Forecaster") as demo:
    # NOTE(review): this HTML snippet's content appears to have been lost from the source; restore if needed.
    gr.HTML(""" """)
    gr.Markdown("""
# TiRex Forecaster Dashboard
Upload a CSV file with a time column and numeric columns. Select the time column and one numeric column to forecast future values using the TiRex model.
""")
    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            csv_file = gr.File(
                file_types=[".csv"],
                label="Upload CSV File",
                elem_id="file_upload"
            )
            gr.Markdown("The minimum effective input is around 128 time steps per series. "
                        "Use a full context of 2048 steps for optimal performance.")
            time_dropdown = gr.Dropdown(
                choices=[],
                label="Select Time Column",
                interactive=True,
                elem_id="time_select"
            )
            column_dropdown = gr.Dropdown(
                choices=[],
                label="Select Value Column",
                interactive=True,
                elem_id="column_select"
            )
            start_slider = gr.Slider(
                minimum=1, maximum=1, value=1, step=1,
                label="Historical Start Index",
                elem_id="start_idx"
            )
            end_slider = gr.Slider(
                minimum=1, maximum=1, value=1, step=1,
                label="Historical End Index",
                elem_id="end_idx"
            )
            prediction_length = gr.Slider(
                minimum=1, maximum=720, value=100, step=1,
                label="Prediction Length",
                elem_id="pred_length"
            )
            confidence = gr.Slider(
                minimum=50, maximum=95, value=80, step=5,
                label="Prediction Interval (%)",
                elem_id="confidence"
            )
            trend_checkbox = gr.Checkbox(
                label="Add Trendline",
                value=False
            )
            ma_checkbox = gr.Checkbox(
                label="Add Moving Average",
                value=False
            )
            ma_slider = gr.Slider(
                minimum=3, maximum=30, value=7, step=1,
                label="Moving Average Window (Periods)",
                elem_id="ma_window",
                visible=False
            )
            skew_checkbox = gr.Checkbox(
                label="Add Skew Ratio & Momentum",
                value=False
            )
            run_button = gr.Button(
                "Run forecast",
                variant="primary",
                size="lg",
                elem_id="run_btn"
            )
        with gr.Column(scale=2):
            forecast_plot = gr.Plot(
                label="Forecast Visualization",
                elem_id="plot"
            )
            output_text = gr.Markdown(
                "### Welcome!\nUpload your CSV to get started.",
                elem_id="output"
            )
    gr.Markdown("**Built by** [next one gmbh](https://nextone.at/?utm_source=dashboard&utm_medium=referrer&utm_campaign=tirex)")

    # Event for updating dropdowns on file upload
    csv_file.change(
        load_columns,
        inputs=csv_file,
        outputs=[time_dropdown, column_dropdown, start_slider, end_slider]
    )

    # Event for updating MA slider visibility
    ma_checkbox.change(
        update_ma_visibility,
        inputs=[ma_checkbox],
        outputs=[ma_slider]
    )

    # Event for running forecast
    run_button.click(
        run_forecast,
        inputs=[csv_file, time_dropdown, column_dropdown, start_slider, end_slider,
                prediction_length, confidence, trend_checkbox, ma_checkbox, ma_slider, skew_checkbox],
        outputs=[forecast_plot, output_text]
    )


if __name__ == "__main__":
    demo.launch()