import gradio as gr
import pandas as pd
import io
import torch
import numpy as np
from tirex import load_model
import matplotlib.pyplot as plt
from datetime import timedelta
import warnings
warnings.filterwarnings('ignore')
# Load the pretrained TiRex forecasting model once at import time so every
# forecast request reuses the same in-memory weights.
model = load_model("NX-AI/TiRex")
def load_columns(file):
    """Rebuild the column dropdowns and range sliders from an uploaded CSV.

    Parameters
    ----------
    file : str | None
        Path to the uploaded file (Gradio passes File components as path
        strings), or None when the upload is cleared.

    Returns
    -------
    tuple
        (time-column Dropdown, value-column Dropdown, start Slider,
        end Slider) Gradio component updates.
    """
    def _fallback(time_label, value_label, interactive):
        # Shared empty-state components for both the "no file yet" and the
        # "CSV failed to parse" paths (previously duplicated inline).
        return (
            gr.Dropdown(choices=[], value=None, label=time_label, interactive=interactive),
            gr.Dropdown(choices=[], value=None, label=value_label, interactive=interactive),
            gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical Start Index"),
            gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical End Index"),
        )

    if file is None:
        return _fallback("Select Time Column", "Select Value Column", True)
    try:
        # Handle file as path string (Gradio convention).
        with open(file, 'rb') as f:
            content = f.read()
        df_preview = pd.read_csv(io.BytesIO(content))

        # Any column may serve as the time axis; default to the first one.
        all_cols = df_preview.columns.tolist()
        time_choices = [(col, col) for col in all_cols]
        time_value = all_cols[0] if all_cols else None

        # Only numeric columns are valid forecast targets.
        numeric_cols = df_preview.select_dtypes(include=['number']).columns.tolist()
        n_rows = len(df_preview)

        time_dropdown = gr.Dropdown(
            choices=time_choices,
            value=time_value,
            label="Select Time Column",
            interactive=True
        )
        if numeric_cols:
            value_dropdown = gr.Dropdown(
                choices=[(col, col) for col in numeric_cols],
                value=numeric_cols[0],
                label="Select Value Column",
                interactive=True
            )
        else:
            value_dropdown = gr.Dropdown(
                choices=[],
                value=None,
                label="No numeric columns found",
                interactive=False
            )
        # Sliders are 1-based to match how users count rows in a spreadsheet.
        start_slider = gr.Slider(
            minimum=1, maximum=n_rows, value=1, step=1,
            label="Historical Start Index"
        )
        end_slider = gr.Slider(
            minimum=1, maximum=n_rows, value=n_rows, step=1,
            label="Historical End Index"
        )
        return time_dropdown, value_dropdown, start_slider, end_slider
    except Exception as e:
        err_label = f"Error loading CSV: {str(e)}"
        return _fallback(err_label, err_label, False)
def update_ma_visibility(add_ma):
    """Show the moving-average window slider only while its checkbox is ticked."""
    slider_update = gr.Slider(visible=add_ma)
    return slider_update
def run_forecast(file, time_col, selected_col, start_idx, end_idx, prediction_length, confidence, add_trendline, add_moving_average, ma_window, add_skew_viz):
    """Forecast one numeric CSV column with TiRex and render a fan chart.

    Parameters mirror the Gradio inputs: the uploaded CSV path, the chosen
    time/value columns, the 1-based historical slice [start_idx, end_idx],
    the forecast horizon, the prediction-interval width in percent, and the
    plot-overlay toggles (trendline, moving average + window, skew lines).

    Returns
    -------
    tuple
        (matplotlib Figure or None, markdown summary string). On any error
        the figure is None and the markdown describes the problem.

    Bug fixed: the two collapsible markdown sections were emitted via
    unterminated multi-line f-strings (their HTML <details>/<summary> tags
    had been stripped), which is a SyntaxError; the markup is restored as
    valid single-line f-strings.
    """
    if file is None or time_col is None or selected_col is None:
        return None, "### Error\nPlease upload a CSV and select time and value columns!"
    try:
        # Handle file as path string (Gradio convention).
        with open(file, 'rb') as f:
            content = f.read()
        df = pd.read_csv(io.BytesIO(content))

        # Validate the selected columns before renaming them.
        if time_col not in df.columns or selected_col not in df.columns:
            return None, f"### Error\nSelected columns '{time_col}' or '{selected_col}' not found in CSV."
        df = df.rename(columns={time_col: 'date', selected_col: 'sales'})
        required_cols = ['date', 'sales']
        if not all(col in df.columns for col in required_cols):
            return None, f"### Error\nMissing renamed columns."

        # Parse dates, sort, and split into context vs held-out validation.
        df['date'] = pd.to_datetime(df['date'])
        df = df.set_index('date').sort_index()
        full_len = len(df)
        context_start = max(0, int(start_idx) - 1)  # sliders are 1-based
        context_end = min(full_len, int(end_idx))
        context_df = df.iloc[context_start:context_end]
        held_out_df = df.iloc[context_end:] if context_end < full_len else pd.DataFrame(index=pd.DatetimeIndex([]), columns=df.columns)
        if len(context_df) < 10:
            return None, "### Error\nNeed at least 10 data points in the selected historical range."
        context_series = context_df['sales'].dropna().values
        print(f"Loaded context: {len(context_series)} points from {context_df.index.min().date()} to {context_df.index.max().date()} (Column: {selected_col})")  # For logs

        # Infer sampling frequency; fall back to daily when ambiguous.
        freq = pd.infer_freq(context_df.index)
        if freq is None:
            freq = 'D'
        print(f"Frequency: '{freq}'.")

        # TiRex context is capped at 2048 steps; tensor shape (1, T).
        context_len = min(len(context_series), 2048)
        context = torch.tensor(context_series[-context_len:]).unsqueeze(0).float()
        pred_len = prediction_length
        conf_level = confidence / 100.0
        lower_alpha_slider = (1 - conf_level) / 2
        upper_alpha_slider = 1 - lower_alpha_slider
        # Fixed inner band: 50% interval.
        lower_alpha_50 = 0.25
        upper_alpha_50 = 0.75

        quantiles, mean = model.forecast(context=context, prediction_length=pred_len)
        # Median is always the 50th percentile (index 4 of the 9 quantiles).
        median = quantiles[0, :, 4].numpy()
        q = quantiles[0].detach().numpy()  # (pred_len, 9)
        alphas = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])

        # Interpolate the requested interval bounds and derive skew metrics.
        lower50 = np.zeros(pred_len)
        upper50 = np.zeros(pred_len)
        lower_slider = np.zeros(pred_len)
        upper_slider = np.zeros(pred_len)
        skew_ratios = np.zeros(pred_len)
        delta_skews = np.zeros(pred_len)
        skew_directions = []
        epsilon = 1e-8  # guards against division by a zero-width interval
        for t in range(pred_len):
            q_t = q[t]
            lower50[t] = np.interp(lower_alpha_50, alphas, q_t)
            upper50[t] = np.interp(upper_alpha_50, alphas, q_t)
            lower_slider[t] = np.interp(lower_alpha_slider, alphas, q_t)
            upper_slider[t] = np.interp(upper_alpha_slider, alphas, q_t)
            # Skew: asymmetry of the interval around the median, in [-1, 1].
            med = median[t]
            upside_dist = upper_slider[t] - med
            downside_dist = med - lower_slider[t]
            total_dist = upside_dist + downside_dist + epsilon
            skew_ratios[t] = (upside_dist - downside_dist) / total_dist
            # Momentum: step-over-step shift of the skew ratio.
            if t == 0:
                delta_skews[t] = 0.0
            else:
                delta_skews[t] = skew_ratios[t] - skew_ratios[t - 1]
            # Categorical label derived from skew_ratio for table display.
            if skew_ratios[t] > 0.1:
                skew_directions.append("Upside")
            elif skew_ratios[t] < -0.1:
                skew_directions.append("Downside")
            else:
                skew_directions.append("Neutral")

        # Mean forecast.
        mean_forecast = mean[0].detach().numpy()

        # Future timestamps continuing from the last context date.
        last_date = context_df.index[-1]
        if freq == 'D':
            future_dates = pd.date_range(start=last_date + timedelta(days=1), periods=pred_len, freq='D')
        else:
            future_dates = pd.date_range(start=last_date + pd.DateOffset(1), periods=pred_len, freq=freq)
        pred_df = pd.DataFrame({
            'date': future_dates,
            'predicted_sales_median': median,
            'predicted_sales_lower': lower_slider,
            'predicted_sales_upper': upper_slider,
            'predicted_sales_mean': mean_forecast,
            'skew_direction': skew_directions,
            'skew_ratio': skew_ratios,
            'delta_skew': delta_skews
        }).set_index('date')

        # Count skews for the summary section.
        upside_count = sum(1 for r in skew_ratios if r > 0.1)
        downside_count = sum(1 for r in skew_ratios if r < -0.1)
        neutral_count = pred_len - upside_count - downside_count
        # Summary stats for skew momentum.
        avg_skew = skew_ratios.mean()
        max_momentum_shift = abs(delta_skews).max()

        # Markdown output (built in small pieces to keep lines readable).
        markdown_text = "### Summary\n"
        markdown_text += "- **Number of Historical Periods Used:** {} points\n".format(len(context_series))
        markdown_text += "- **Held Out Periods:** {} points {}\n".format(len(held_out_df), "(Full Context Used)" if len(held_out_df) == 0 else "(For Validation)")
        markdown_text += "- **Prediction Length:** {} periods\n".format(pred_len)
        markdown_text += "- **Prediction Interval:** {}% (alphas: {:.3f} - {:.3f})\n".format(confidence, lower_alpha_slider, upper_alpha_slider)
        markdown_text += "- **Sum of Median Predicted Values:** {:.2f}\n".format(pred_df['predicted_sales_median'].sum())
        markdown_text += "- **Sum of Mean Predicted Values:** {:.2f}\n".format(pred_df['predicted_sales_mean'].sum())
        markdown_text += "- **Skew Distribution:** {} Upside, {} Downside, {} Neutral\n".format(upside_count, downside_count, neutral_count)
        markdown_text += "- **Average Skew Ratio:** {:.3f} (momentum: max |Δ| = {:.3f})\n\n".format(avg_skew, max_momentum_shift)

        forecast_table = "### TiRex Forecast Results (Median + {}% Prediction Interval)\n\n".format(confidence)
        forecast_table += "| Date | Median | Lower Bound | Upper Bound | Mean | Skew Direction | Skew Ratio | Δ Skew |\n"
        forecast_table += "|------|--------|-------------|-------------|------|----------------|------------|--------|\n"
        for idx, row in pred_df.iterrows():
            forecast_table += "| {} | {:.2f} | {:.2f} | {:.2f} | {:.2f} | {} | {:.3f} | {:.3f} |\n".format(
                idx.strftime('%Y-%m-%d'),
                row['predicted_sales_median'],
                row['predicted_sales_lower'],
                row['predicted_sales_upper'],
                row['predicted_sales_mean'],
                row['skew_direction'],
                row['skew_ratio'],
                row['delta_skew']
            )
        sample_data = "### Sample Historical Data (Context)\n"
        sample_data += "```\n" + context_df.head().to_string() + "\n```"
        # FIX: restore valid collapsible <details> sections (the original
        # f-strings were unterminated multi-line literals — a SyntaxError).
        markdown_text += f'<details>\n<summary>Click to expand Forecast Table</summary>\n\n{forecast_table}\n</details>\n\n'
        markdown_text += f'<details>\n<summary>Click to expand Sample Historical Data</summary>\n\n{sample_data}\n</details>'

        # Create plot (single subplot).
        fig, ax = plt.subplots(figsize=(14, 7))
        fig.set_dpi(300)  # High resolution for PNG zoom
        # Historical and held-out series.
        ax.plot(context_df.index, context_df['sales'], label='Historical Data', color='#1f77b4', linewidth=1.5, alpha=0.8)
        if not held_out_df.empty:
            ax.plot(held_out_df.index, held_out_df['sales'], label='Held Out Actual (Validation)', color='#2ca02c', linestyle=':', linewidth=2)
        if add_trendline:
            # Least-squares linear trend over the context window.
            x = np.arange(len(context_df))
            y = context_df['sales'].values
            if len(x) > 1:
                coeffs = np.polyfit(x, y, 1)
                trend = np.polyval(coeffs, x)
                ax.plot(context_df.index, trend, label='Trendline', color='black', linestyle='-', linewidth=1.5)
        if add_moving_average:
            window = int(ma_window)
            ma = context_df['sales'].rolling(window=window, min_periods=1).mean()
            ax.plot(context_df.index, ma, label=f'Moving Average ({window} periods)', color='purple', linewidth=2)
        # Median forecast: regular green line.
        ax.plot(pred_df.index, median, label='Median Forecast', color='green', linewidth=2, alpha=0.9)
        # Fan chart: non-overlapping bands — inner 50% (lightest, center)...
        ax.fill_between(pred_df.index, lower50, upper50,
                        color='#d62728', alpha=0.1, label='50% Prediction Interval')
        # ...and wings between the 50% band and the slider level (medium).
        ax.fill_between(pred_df.index, lower_slider, lower50,
                        color='#d62728', alpha=0.3)
        ax.fill_between(pred_df.index, upper50, upper_slider,
                        color='#d62728', alpha=0.3, label=f'{confidence}% Prediction Interval')
        # Optional skew visualization on a twin axis (light lines).
        skew_handles = []
        if add_skew_viz:
            ax2 = ax.twinx()
            line1, = ax2.plot(pred_df.index, skew_ratios, label='Skew Ratio', color='lightblue', linewidth=1, alpha=0.6)
            skew_handles.append(line1)
            line2, = ax2.plot(pred_df.index, delta_skews, label='Skew Momentum', color='lightgray', linewidth=1, alpha=0.6)
            skew_handles.append(line2)
            ax2.set_ylabel('Skew (-1 to 1)', color='lightblue')
            ax2.tick_params(colors='lightblue')
            # Slightly beyond [-1, 1] so extreme values stay visible.
            ax2.set_ylim(-1.2, 1.2)
        ax.set_title(f'{selected_col} Forecast with TiRex (Context: {context_start+1}-{context_end}, Horizon: {pred_len})', fontsize=16, fontweight='bold')
        ax.set_xlabel('Date', fontsize=12)
        ax.set_ylabel(selected_col, fontsize=12)
        # Combined legend to avoid overlap between the two axes.
        if add_skew_viz:
            handles1, labels1 = ax.get_legend_handles_labels()
            handles2, labels2 = ax2.get_legend_handles_labels()
            ax.legend(handles1 + handles2, labels1 + labels2, fontsize=10, loc='upper left')
        else:
            ax.legend(fontsize=10)
        ax.tick_params(axis='x', rotation=45)
        plt.tight_layout()
        return fig, markdown_text
    except Exception as e:
        return None, f"### Error\n{str(e)}\n\nTips: Ensure the time column can be parsed as dates; check NaNs/zeros; ensure data is valid."
# Create the Gradio interface: controls on the left, plot + summary on the right.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="red"), title="TiRex Forecaster") as demo:
    # Placeholder for injected HTML/CSS (currently empty).
    gr.HTML("""
""")
    gr.Markdown("""
# TiRex Forecaster Dashboard
Upload a CSV file with a time column and numeric columns. Select the time column and one numeric column to forecast future values using the TiRex model.
""")
    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            # Input controls column.
            csv_file = gr.File(
                file_types=[".csv"],
                label="Upload CSV File",
                elem_id="file_upload"
            )
            gr.Markdown("The minimum effective input is around 128 time steps per series. Use a full context of 2048 steps for optimal performance.")
            # Choices are populated by load_columns once a file is uploaded.
            time_dropdown = gr.Dropdown(
                choices=[],
                label="Select Time Column",
                interactive=True,
                elem_id="time_select"
            )
            column_dropdown = gr.Dropdown(
                choices=[],
                label="Select Value Column",
                interactive=True,
                elem_id="column_select"
            )
            # 1-based row range selecting the historical context slice.
            start_slider = gr.Slider(
                minimum=1, maximum=1, value=1, step=1,
                label="Historical Start Index",
                elem_id="start_idx"
            )
            end_slider = gr.Slider(
                minimum=1, maximum=1, value=1, step=1,
                label="Historical End Index",
                elem_id="end_idx"
            )
            prediction_length = gr.Slider(
                minimum=1, maximum=720, value=100, step=1,
                label="Prediction Length",
                elem_id="pred_length"
            )
            confidence = gr.Slider(
                minimum=50, maximum=95, value=80, step=5,
                label="Prediction Interval (%)",
                elem_id="confidence"
            )
            trend_checkbox = gr.Checkbox(
                label="Add Trendline",
                value=False
            )
            ma_checkbox = gr.Checkbox(
                label="Add Moving Average",
                value=False
            )
            # Hidden until the moving-average checkbox is ticked.
            ma_slider = gr.Slider(
                minimum=3, maximum=30, value=7, step=1,
                label="Moving Average Window (Periods)",
                elem_id="ma_window",
                visible=False
            )
            skew_checkbox = gr.Checkbox(
                label="Add Skew Ratio & Momentum",
                value=False
            )
            run_button = gr.Button(
                "Run forecast",
                variant="primary",
                size="lg",
                elem_id="run_btn"
            )
        with gr.Column(scale=2):
            # Output column: the forecast chart and the markdown report.
            forecast_plot = gr.Plot(
                label="Forecast Visualization",
                elem_id="plot"
            )
            output_text = gr.Markdown(
                "### Welcome!\nUpload your CSV to get started.",
                elem_id="output"
            )
    gr.Markdown("**Built by** [next one gmbh](https://nextone.at/?utm_source=dashboard&utm_medium=referrer&utm_campaign=tirex)")
    # Event for updating dropdowns/sliders on file upload.
    csv_file.change(
        load_columns,
        inputs=csv_file,
        outputs=[time_dropdown, column_dropdown, start_slider, end_slider]
    )
    # Event for toggling MA-window slider visibility.
    ma_checkbox.change(
        update_ma_visibility,
        inputs=[ma_checkbox],
        outputs=[ma_slider]
    )
    # Event for running the forecast.
    run_button.click(
        run_forecast,
        inputs=[csv_file, time_dropdown, column_dropdown, start_slider, end_slider, prediction_length, confidence, trend_checkbox, ma_checkbox, ma_slider, skew_checkbox],
        outputs=[forecast_plot, output_text]
    )

if __name__ == "__main__":
    demo.launch()