michaellupo74 commited on
Commit
4694769
·
verified ·
1 Parent(s): 63a424f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -48
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import io
2
  import numpy as np
3
  import pandas as pd
4
  import streamlit as st
@@ -7,102 +6,131 @@ from chronos import BaseChronosPipeline
7
 
8
  st.set_page_config(page_title="Chronos-Bolt Zero-Shot Forecast", layout="centered")
9
  st.title("Chronos-Bolt Zero-Shot Forecast")
10
- st.caption("Paste a series or pull a ticker and get a probabilistic forecast (q10 / q50 / q90).")
11
 
12
- # --------- SETTINGS ---------
13
  MODEL_CHOICES = {
14
- "Bolt Mini (fast, CPU-friendly)": "amazon/chronos-bolt-mini",
15
- "Bolt Small (better, may need GPU)": "amazon/chronos-bolt-small",
16
  }
17
- DEFAULT_MODEL = "Bolt Mini (fast, CPU-friendly)"
18
 
19
- # --------- HELPERS ---------
20
  @st.cache_resource(show_spinner=True)
21
  def load_pipeline(model_id: str):
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
24
- return BaseChronosPipeline.from_pretrained(
25
- model_id, device_map=device, torch_dtype=dtype
26
- )
 
 
 
27
 
28
  @st.cache_data(show_spinner=False)
29
  def load_ticker_series(ticker: str, period: str = "2y"):
30
  import yfinance as yf
31
- df = yf.download(ticker, period=period, interval="1d", progress=False)
32
- close = df["Close"].dropna().astype("float32").values
33
- return close
 
 
 
 
34
 
35
  def parse_pasted_series(txt: str):
 
 
36
  vals = []
37
- for line in txt.splitlines():
38
- line = line.strip().replace(",", "")
39
- if line:
40
- try:
41
- vals.append(float(line))
42
- except:
43
- pass
44
- return np.asarray(vals, dtype="float32")
45
-
46
- # --------- UI ---------
47
- col1, col2 = st.columns(2)
48
- with col1:
 
 
 
 
 
 
 
 
 
49
  model_label = st.selectbox("Model", list(MODEL_CHOICES.keys()), index=0)
50
- with col2:
51
- pred_len = st.number_input("Prediction length (steps)", 1, 200, 30)
52
 
53
- src = st.radio("Data source", ["Ticker (yfinance)", "Paste numbers"], horizontal=True)
54
 
55
  series = None
56
  if src == "Ticker (yfinance)":
57
- t1, t2 = st.columns([2,1])
58
  with t1:
59
  ticker = st.text_input("Ticker (e.g., AAPL, SPY, BTC-USD)", "AAPL")
60
  with t2:
61
  period = st.selectbox("History window", ["6mo", "1y", "2y", "5y"], index=2)
62
  if st.button("Load data"):
63
- with st.spinner("Downloading..."):
64
- series = load_ticker_series(ticker.strip(), period)
65
- if series.size == 0:
66
- st.error("No series returned. Try a different ticker or window.")
67
- series = None
 
 
68
  else:
69
- txt = st.text_area("One value per line", "1\n2\n3\n4\n5\n6\n7\n8\n9\n10")
70
- series = parse_pasted_series(txt)
 
 
 
 
 
 
 
71
 
 
72
  if series is not None and series.size > 5:
73
  st.write(f"Loaded {series.size} points.")
74
- st.line_chart(pd.DataFrame({"value": series}))
75
 
76
  if st.button("Forecast"):
77
  with st.spinner("Running Chronos-Bolt..."):
78
- model_id = MODEL_CHOICES[model_label]
79
- pipe = load_pipeline(model_id)
80
-
81
  ctx = torch.tensor(series, dtype=torch.float32)
82
  q_levels = [0.10, 0.50, 0.90]
83
- # predict_quantiles returns (quantiles, mean)
84
  quantiles, mean = pipe.predict_quantiles(
85
  context=ctx,
86
  prediction_length=int(pred_len),
87
  quantile_levels=q_levels,
88
  )
89
 
90
- q = quantiles[0].cpu().numpy() # shape [pred_len, 3]
91
- lo, med, hi = q[:, 0], q[:, 1], q[:, 2]
92
 
93
- # Plot
94
  import matplotlib.pyplot as plt
95
  hist_x = np.arange(len(series))
96
- fut_x = np.arange(len(series), len(series) + pred_len)
97
 
98
  fig = plt.figure(figsize=(9, 4.5))
99
  plt.plot(hist_x, series, label="history")
100
  plt.plot(fut_x, med, label="median forecast")
101
- plt.fill_between(fut_x, lo, hi, alpha=0.3, label="q10–q90 interval")
102
  plt.legend()
103
  plt.grid(True, alpha=0.3)
104
  st.pyplot(fig)
105
 
106
- st.success("Done.")
 
 
 
 
 
 
107
  else:
108
- st.info("Load a ticker or paste at least ~20 numbers to begin.")
 
 
1
  import numpy as np
2
  import pandas as pd
3
  import streamlit as st
 
6
 
7
  st.set_page_config(page_title="Chronos-Bolt Zero-Shot Forecast", layout="centered")
8
  st.title("Chronos-Bolt Zero-Shot Forecast")
9
+ st.caption("Zero-shot probabilistic forecasting (q10/q50/q90) using amazon/chronos-bolt-* models.")
10
 
11
+ # -------------------- Model options --------------------
12
  MODEL_CHOICES = {
13
+ "Bolt Mini (CPU-friendly)": "amazon/chronos-bolt-mini",
14
+ "Bolt Small (better; GPU if available)": "amazon/chronos-bolt-small",
15
  }
 
16
 
 
17
  @st.cache_resource(show_spinner=True)
18
  def load_pipeline(model_id: str):
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
21
+ return BaseChronosPipeline.from_pretrained(model_id, device_map=device, torch_dtype=dtype)
22
+
23
+ # -------------------- Data loaders (always return 1-D) --------------------
24
+ def _force_1d(a):
25
+ a = pd.Series(a, dtype="float32").replace([np.inf, -np.inf], np.nan).dropna()
26
+ return a.to_numpy().reshape(-1)
27
 
28
  @st.cache_data(show_spinner=False)
29
  def load_ticker_series(ticker: str, period: str = "2y"):
30
  import yfinance as yf
31
+ df = yf.download(ticker, period=period, interval="1d", auto_adjust=True, progress=False)
32
+ if df.empty:
33
+ return np.asarray([], dtype="float32")
34
+ close = df["Close"]
35
+ if isinstance(close, pd.DataFrame): # handle rare multi-index cases
36
+ close = close.iloc[:, 0]
37
+ return _force_1d(close)
38
 
39
  def parse_pasted_series(txt: str):
40
+ import re
41
+ toks = re.split(r"[,\s]+", txt.strip())
42
  vals = []
43
+ for t in toks:
44
+ if not t:
45
+ continue
46
+ try:
47
+ vals.append(float(t))
48
+ except:
49
+ pass
50
+ return _force_1d(vals)
51
+
52
+ def load_csv_series(file, column=None):
53
+ df = pd.read_csv(file)
54
+ if column is None:
55
+ num_cols = [c for c in df.columns if np.issubdtype(df[c].dtype, np.number)]
56
+ column = num_cols[0] if num_cols else None
57
+ if column is None:
58
+ return np.asarray([], dtype="float32"), df, None
59
+ return _force_1d(df[column]), df, column
60
+
61
+ # -------------------- UI --------------------
62
+ c1, c2 = st.columns(2)
63
+ with c1:
64
  model_label = st.selectbox("Model", list(MODEL_CHOICES.keys()), index=0)
65
+ with c2:
66
+ pred_len = st.number_input("Prediction length (steps)", 1, 365, 30)
67
 
68
+ src = st.radio("Data source", ["Ticker (yfinance)", "Paste numbers", "Upload CSV"], horizontal=True)
69
 
70
  series = None
71
  if src == "Ticker (yfinance)":
72
+ t1, t2 = st.columns([2, 1])
73
  with t1:
74
  ticker = st.text_input("Ticker (e.g., AAPL, SPY, BTC-USD)", "AAPL")
75
  with t2:
76
  period = st.selectbox("History window", ["6mo", "1y", "2y", "5y"], index=2)
77
  if st.button("Load data"):
78
+ series = load_ticker_series(ticker.strip(), period)
79
+ if series.size == 0:
80
+ st.error("No data returned. Try another ticker/window.")
81
+ elif src == "Paste numbers":
82
+ txt = st.text_area("One value per line (or comma/space separated)", "1\n2\n3\n4\n5\n6\n7\n8\n9\n10")
83
+ if st.button("Use pasted data"):
84
+ series = parse_pasted_series(txt)
85
  else:
86
+ uploaded = st.file_uploader("Upload CSV", type=["csv"])
87
+ if uploaded is not None:
88
+ df = pd.read_csv(uploaded)
89
+ numeric_cols = [c for c in df.columns if np.issubdtype(df[c].dtype, np.number)]
90
+ col = st.selectbox("Pick numeric column", numeric_cols) if numeric_cols else None
91
+ if st.button("Load CSV column") and col:
92
+ series, _, _ = load_csv_series(uploaded, column=col)
93
+ elif uploaded and not numeric_cols:
94
+ st.error("No numeric columns found in CSV.")
95
 
96
+ # -------------------- Plot + Forecast --------------------
97
  if series is not None and series.size > 5:
98
  st.write(f"Loaded {series.size} points.")
99
+ st.line_chart(pd.DataFrame(series, columns=["value"])) # always 1-D -> no error
100
 
101
  if st.button("Forecast"):
102
  with st.spinner("Running Chronos-Bolt..."):
103
+ pipe = load_pipeline(MODEL_CHOICES[model_label])
 
 
104
  ctx = torch.tensor(series, dtype=torch.float32)
105
  q_levels = [0.10, 0.50, 0.90]
106
+
107
  quantiles, mean = pipe.predict_quantiles(
108
  context=ctx,
109
  prediction_length=int(pred_len),
110
  quantile_levels=q_levels,
111
  )
112
 
113
+ q_np = quantiles[0].cpu().numpy() # shape [pred_len, 3]
114
+ lo, med, hi = q_np[:, 0], q_np[:, 1], q_np[:, 2]
115
 
 
116
  import matplotlib.pyplot as plt
117
  hist_x = np.arange(len(series))
118
+ fut_x = np.arange(len(series), len(series) + int(pred_len))
119
 
120
  fig = plt.figure(figsize=(9, 4.5))
121
  plt.plot(hist_x, series, label="history")
122
  plt.plot(fut_x, med, label="median forecast")
123
+ plt.fill_between(fut_x, lo, hi, alpha=0.3, label="q10–q90 band")
124
  plt.legend()
125
  plt.grid(True, alpha=0.3)
126
  st.pyplot(fig)
127
 
128
+ out = pd.DataFrame({"t": fut_x, "q10": lo, "q50": med, "q90": hi})
129
+ st.download_button(
130
+ "Download forecast CSV",
131
+ out.to_csv(index=False).encode("utf-8"),
132
+ file_name="chronos_forecast.csv",
133
+ mime="text/csv",
134
+ )
135
  else:
136
+ st.info("Load a ticker, paste values, or upload a CSV to begin.")