File size: 12,589 Bytes
63a424f
 
 
 
 
 
 
 
4694769
63a424f
0048c91
9542081
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4694769
63a424f
4694769
 
63a424f
 
 
 
 
 
4694769
 
 
 
 
 
63a424f
 
 
 
4694769
 
 
 
 
 
 
63a424f
 
4694769
 
63a424f
4694769
0048c91
4694769
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63a424f
4694769
 
63a424f
4694769
63a424f
 
 
4694769
63a424f
 
 
 
 
4694769
 
 
 
 
 
 
63a424f
4694769
 
 
 
 
 
 
 
 
63a424f
4694769
63a424f
 
4694769
63a424f
 
 
4694769
63a424f
 
4694769
63a424f
 
 
 
 
 
4694769
 
63a424f
 
 
4694769
63a424f
 
 
 
4694769
63a424f
 
 
 
4694769
 
 
 
 
 
 
63a424f
4694769
0048c91
 
 
321112a
 
 
 
 
 
 
2f5a7e8
f41d71c
2f5a7e8
 
 
 
 
 
 
 
 
321112a
 
 
 
 
 
0048c91
321112a
 
 
 
 
 
 
 
 
 
2f5a7e8
 
 
 
 
 
 
 
 
 
 
321112a
2f5a7e8
f41d71c
 
 
 
 
 
 
733c07f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321112a
0048c91
 
 
 
321112a
 
 
 
 
 
 
 
 
 
 
 
 
 
f41d71c
 
 
 
 
321112a
 
 
 
 
 
 
f41d71c
321112a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0048c91
321112a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f41d71c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
import numpy as np
import pandas as pd
import streamlit as st
import torch
from chronos import BaseChronosPipeline

# Page chrome: the browser tab title and the on-page heading share one string.
_PAGE_TITLE = "Chronos-Bolt Zero-Shot Forecast"
st.set_page_config(page_title=_PAGE_TITLE, layout="centered")
st.title(_PAGE_TITLE)
st.caption(
    "Zero-shot probabilistic forecasting (q10/q50/q90) using amazon/chronos-bolt-* models."
)

# -------------------- Indicator helpers (no pandas-ta needed) --------------------
def ema(series, length=20):
    """Return the exponential moving average of ``series``.

    The input is coerced to a float64 ``pd.Series`` and smoothed with
    ``span=length`` using the recursive (``adjust=False``) formulation.
    """
    values = pd.Series(series).astype("float64")
    smoother = values.ewm(span=length, adjust=False)
    return smoother.mean()

def rsi(series, length=14):
    """Return the Relative Strength Index of ``series`` as a float64 Series.

    Average gain and average loss are exponentially smoothed with
    ``alpha = 1 / length`` (``adjust=False``), then combined as
    ``100 - 100 / (1 + gain / loss)``.

    Edge cases:
      * the first element is NaN (there is no previous value to diff against);
      * when the average loss is zero but the average gain is positive the
        result is 100 (the index saturates) rather than NaN;
      * when both averages are zero (a flat series) the result stays NaN,
        because the ratio is undefined.
    """
    s = pd.Series(series).astype("float64")
    delta = s.diff()
    avg_gain = delta.clip(lower=0).ewm(alpha=1 / length, adjust=False).mean()
    avg_loss = (-delta.clip(upper=0)).ewm(alpha=1 / length, adjust=False).mean()

    # Divide by NaN instead of 0 so the ratio never becomes +/-inf.
    rs = avg_gain / avg_loss.replace(0, np.nan)
    out = 100 - (100 / (1 + rs))

    # No losses at all: the ratio is unbounded, so the index is 100. Without
    # this the rows would be NaN and silently dropped by a later dropna().
    only_gains = (avg_loss == 0) & (avg_gain > 0)
    return out.mask(only_gains, 100.0)

def stochastic_kd(high, low, close, k=14, d=3, smooth_k=3):
    """Return the smoothed stochastic %K line and its %D signal line.

    ``k`` is the look-back window for the highest high / lowest low,
    ``smooth_k`` is the rolling-mean window applied to raw %K, and ``d`` is
    the rolling-mean window applied to the smoothed %K to obtain %D.
    All inputs are coerced to float64 Series; positions without a full
    window are NaN.
    """
    highs, lows, closes = (
        pd.Series(values).astype("float64") for values in (high, low, close)
    )
    window_high = highs.rolling(k).max()
    window_low = lows.rolling(k).min()
    price_range = window_high - window_low

    # Position of the close inside the recent range, scaled to 0-100.
    percent_k = (100 * (closes - window_low) / price_range).rolling(smooth_k).mean()
    percent_d = percent_k.rolling(d).mean()
    return percent_k, percent_d


# -------------------- Model options --------------------
# Maps the label shown in the "Model" selectbox to the model id that is
# handed to load_pipeline().
MODEL_CHOICES = {
    "Bolt Mini (CPU-friendly)": "amazon/chronos-bolt-mini",
    "Bolt Small (better; GPU if available)": "amazon/chronos-bolt-small",
}

@st.cache_resource(show_spinner=True)
def load_pipeline(model_id: str):
    """Load the Chronos pipeline for ``model_id``, cached by Streamlit.

    CUDA with bfloat16 is used when a GPU is available; otherwise the model
    is loaded on the CPU in float32.
    """
    if torch.cuda.is_available():
        device, dtype = "cuda", torch.bfloat16
    else:
        device, dtype = "cpu", torch.float32
    return BaseChronosPipeline.from_pretrained(model_id, device_map=device, torch_dtype=dtype)

# -------------------- Data loaders (always return 1-D) --------------------
def _force_1d(a):
    """Return ``a`` as a flat float32 NumPy array with NaN and +/-inf removed."""
    as_float = pd.Series(a, dtype="float32")
    # Turn infinities into NaN so that a single dropna() removes every
    # non-finite entry.
    finite_only = as_float.replace([np.inf, -np.inf], np.nan).dropna()
    flat = finite_only.to_numpy()
    return flat.reshape(-1)

@st.cache_data(show_spinner=False)
def load_ticker_series(ticker: str, period: str = "2y"):
    """Download daily auto-adjusted closes for ``ticker`` via yfinance.

    Returns a 1-D float32 array of closing prices, or an empty float32 array
    when the download yields no rows. Results are cached by Streamlit.
    """
    import yfinance as yf

    frame = yf.download(
        ticker,
        period=period,
        interval="1d",
        auto_adjust=True,
        progress=False,
    )
    if frame.empty:
        return np.asarray([], dtype="float32")

    closes = frame["Close"]
    # With multi-level columns, "Close" selects a DataFrame; keep its first column.
    closes = closes.iloc[:, 0] if isinstance(closes, pd.DataFrame) else closes
    return _force_1d(closes)

def parse_pasted_series(txt: str):
    """Parse free-form pasted text into a clean 1-D float32 array.

    Values may be separated by commas and/or any whitespace (including
    newlines). Tokens that are not valid floats are skipped; non-finite
    values are removed by ``_force_1d``. Empty input yields an empty array.
    """
    import re

    vals = []
    for tok in re.split(r"[,\s]+", txt.strip()):
        if not tok:
            continue
        try:
            vals.append(float(tok))
        except ValueError:
            # Not a number (header word, stray text): skip it. Only ValueError
            # is caught so KeyboardInterrupt/SystemExit are never swallowed.
            continue
    return _force_1d(vals)

def load_csv_series(file, column=None):
    """Read a CSV and return one of its columns as a clean 1-D float32 array.

    Args:
        file: Anything ``pd.read_csv`` accepts (path, buffer, uploaded file).
            Seekable file objects are rewound first, so a handle that was
            already read (e.g. for a preview) can be passed again.
        column: Column to extract. When ``None`` the first numeric
            (non-boolean) column is used.

    Returns:
        ``(values, df, column)``: the cleaned array, the full DataFrame and
        the column name actually used. When no numeric column exists,
        ``values`` is empty and ``column`` is ``None``.
    """
    # A consumed handle sits at EOF and would make read_csv raise
    # EmptyDataError; rewind whenever the object supports it.
    seekable = getattr(file, "seekable", None)
    if callable(seekable) and seekable():
        file.seek(0)

    df = pd.read_csv(file)
    if column is None:
        # pandas' own dtype predicates also understand extension dtypes, on
        # which np.issubdtype raises TypeError. Booleans are excluded to keep
        # the original np.number semantics.
        is_numeric = pd.api.types.is_numeric_dtype
        is_bool = pd.api.types.is_bool_dtype
        column = next(
            (c for c in df.columns if is_numeric(df[c]) and not is_bool(df[c])),
            None,
        )
    if column is None:
        return np.asarray([], dtype="float32"), df, None
    return _force_1d(df[column]), df, column

# -------------------- UI --------------------
# Top row: model picker on the left, forecast horizon on the right.
model_col, horizon_col = st.columns(2)
model_label = model_col.selectbox("Model", list(MODEL_CHOICES), index=0)
pred_len = horizon_col.number_input("Prediction length (steps)", 1, 365, 30)

# Where the input series comes from; drives the loader widgets below.
src = st.radio(
    "Data source",
    ["Ticker (yfinance)", "Paste numbers", "Upload CSV"],
    horizontal=True,
)

# Streamlit re-executes this script from the top on every widget interaction,
# and st.button() returns True only in the rerun triggered by its own click.
# The loaded series is therefore kept in st.session_state, tagged with the
# data source that produced it. Without this, clicking any later button
# (e.g. "Forecast") would rerun the script with `series` back at None and the
# loaded data would be gone.
series = None
if st.session_state.get("series_src") == src:
    series = st.session_state.get("series")

loaded = None  # set only in the run where one of the load buttons was clicked
if src == "Ticker (yfinance)":
    t1, t2 = st.columns([2, 1])
    with t1:
        ticker = st.text_input("Ticker (e.g., AAPL, SPY, BTC-USD)", "AAPL")
    with t2:
        period = st.selectbox("History window", ["6mo", "1y", "2y", "5y"], index=2)
    if st.button("Load data"):
        loaded = load_ticker_series(ticker.strip(), period)
        if loaded.size == 0:
            st.error("No data returned. Try another ticker/window.")
elif src == "Paste numbers":
    txt = st.text_area("One value per line (or comma/space separated)", "1\n2\n3\n4\n5\n6\n7\n8\n9\n10")
    if st.button("Use pasted data"):
        loaded = parse_pasted_series(txt)
else:
    uploaded = st.file_uploader("Upload CSV", type=["csv"])
    if uploaded is not None:
        # First read: only used to offer the numeric columns for selection.
        df = pd.read_csv(uploaded)
        # pandas' dtype predicates also handle extension dtypes, on which
        # np.issubdtype raises TypeError; booleans are excluded to keep the
        # original np.number semantics.
        numeric_cols = [
            c
            for c in df.columns
            if pd.api.types.is_numeric_dtype(df[c]) and not pd.api.types.is_bool_dtype(df[c])
        ]
        col = st.selectbox("Pick numeric column", numeric_cols) if numeric_cols else None
        if st.button("Load CSV column") and col is not None:
            # The read above left the upload positioned at EOF; rewind it so
            # the second read_csv inside load_csv_series sees the data.
            uploaded.seek(0)
            loaded, _, _ = load_csv_series(uploaded, column=col)
        elif uploaded and not numeric_cols:
            st.error("No numeric columns found in CSV.")

if loaded is not None:
    # Persist the fresh data so it survives the reruns caused by later clicks.
    series = loaded
    st.session_state["series"] = loaded
    st.session_state["series_src"] = src

# -------------------- Plot + Forecast --------------------
# Show the loaded series and, on request, a zero-shot quantile forecast.
# Series with five or fewer points are treated as "nothing loaded yet".
if series is not None and series.size > 5:
    st.write(f"Loaded {series.size} points.")
    st.line_chart(pd.DataFrame(series, columns=["value"]))  # always 1-D -> no error

    if st.button("Forecast"):
        with st.spinner("Running Chronos-Bolt..."):
            pipe = load_pipeline(MODEL_CHOICES[model_label])
            ctx = torch.tensor(series, dtype=torch.float32)
            q_levels = [0.10, 0.50, 0.90]
            horizon = int(pred_len)

            # predict_quantiles also returns a mean forecast; only the
            # quantiles are plotted and exported here.
            quantiles, mean = pipe.predict_quantiles(
                context=ctx,
                prediction_length=horizon,
                quantile_levels=q_levels,
            )

            # First (only) series in the batch; columns follow q_levels.
            q_np = quantiles[0].cpu().numpy()  # shape [pred_len, 3]
            lo, med, hi = q_np[:, 0], q_np[:, 1], q_np[:, 2]

            import matplotlib.pyplot as plt

            # The forecast continues the integer x-axis right after the history.
            hist_x = np.arange(len(series))
            fut_x = np.arange(len(series), len(series) + horizon)

            # Draw on an explicit Axes rather than pyplot's global state, and
            # close the figure once rendered: pyplot keeps every figure
            # alive until closed, so reruns would otherwise accumulate them.
            fig, ax = plt.subplots(figsize=(9, 4.5))
            ax.plot(hist_x, series, label="history")
            ax.plot(fut_x, med, label="median forecast")
            ax.fill_between(fut_x, lo, hi, alpha=0.3, label="q10–q90 band")
            ax.legend()
            ax.grid(True, alpha=0.3)
            st.pyplot(fig)
            plt.close(fig)

            out = pd.DataFrame({"t": fut_x, "q10": lo, "q50": med, "q90": hi})
            st.download_button(
                "Download forecast CSV",
                out.to_csv(index=False).encode("utf-8"),
                file_name="chronos_forecast.csv",
                mime="text/csv",
            )
else:
    st.info("Load a ticker, paste values, or upload a CSV to begin.")

# ================================
# Train with RSI / EMA / Stochastic (AutoGluon) — no pandas-ta
# ================================
# Fine-tuning section: download OHLC data for one ticker, derive RSI / EMA /
# stochastic columns, fine-tune Chronos-Bolt through AutoGluon and plot the
# resulting quantile forecast.
with st.expander("Train with Indicators (RSI, EMA, Stochastic)"):
    st.write("Fine-tune Chronos-Bolt on one ticker using indicator covariates (past-only).")
    tcol1, tcol2, tcol3 = st.columns([2, 1, 1])
    with tcol1:
        ft_ticker = st.text_input("Ticker", "SPY")
    with tcol3:
        ft_interval = st.selectbox("Interval", ["1d", "60m", "30m", "15m"], index=0)

    # Allowed lookbacks depend on the interval, which is why the interval
    # widget (third column) is created before the lookback widget (second).
    if ft_interval == "1d":
        allowed_periods = ["6mo", "1y", "2y", "5y"]
        default_idx = 2
    else:
        allowed_periods = ["5d", "30d", "60d"]
        default_idx = 1
    with tcol2:
        ft_period = st.selectbox("Lookback", allowed_periods, index=default_idx)

    ft_steps = st.slider("Fine-tune steps", 100, 1500, 300, step=50)
    run_ft = st.button("Train fine-tuned model")

    if run_ft:
        # Normalise the symbol once so the AutoGluon item id, the lookups on
        # the result and the CSV file name all agree (no stray whitespace).
        item = ft_ticker.strip().upper()

        with st.spinner("Downloading & computing indicators…"):
            import yfinance as yf
            from autogluon.timeseries import TimeSeriesPredictor, TimeSeriesDataFrame

            # 1) Load OHLC so we can compute Stochastic (needs High/Low/Close)
            df = yf.download(
                ft_ticker.strip(),
                period=ft_period,
                interval=ft_interval,
                auto_adjust=True,
                progress=False,
            )
            # Fallback: if the chosen combo returned nothing, retry once with
            # a default lookback for that interval.
            if df.empty:
                alt_period = "60d" if ft_interval != "1d" else "1y"
                if alt_period != ft_period:
                    df = yf.download(
                        ft_ticker.strip(),
                        period=alt_period,
                        interval=ft_interval,
                        auto_adjust=True,
                        progress=False,
                    )
            if df.empty:
                st.error("No data returned. Try a shorter lookback for intraday (e.g., 30d/60d) or use Interval=1d.")
                st.stop()

            # Determine frequency alias for AutoGluon and ensure tz-naive index
            freq_alias = {"1d": "B", "60m": "60min", "30m": "30min", "15m": "15min"}.get(ft_interval, "B")
            df.index = pd.DatetimeIndex(df.index).tz_localize(None)

            # Handle MultiIndex columns (yfinance can return 2-level columns)
            if isinstance(df.columns, pd.MultiIndex):
                try:
                    sym = df.columns.get_level_values(1).unique()[0]
                    df = df.xs(sym, axis=1, level=1)
                except Exception:
                    # Deliberate best effort: flatten by taking the top-level
                    # name (Close/High/Low). A wrong layout surfaces on the
                    # column selection right below.
                    df.columns = [c[0] for c in df.columns.to_flat_index()]

            # Keep only needed cols
            df = df[["Close", "High", "Low"]].copy()

            # Ensure each column is 1-D (avoid (N,1) arrays)
            for _c in ["Close", "High", "Low"]:
                if isinstance(df[_c], pd.DataFrame):
                    df[_c] = df[_c].iloc[:, 0]
                df[_c] = pd.Series(np.asarray(df[_c]).reshape(-1), index=df.index)

            df = df.dropna()

            # 2) Indicators (helpers above)
            df["rsi14"] = rsi(df["Close"], 14)
            df["ema20"] = ema(df["Close"], 20)
            df["stoch_k"], df["stoch_d"] = stochastic_kd(df["High"], df["Low"], df["Close"], 14, 3, 3)

            # Rows inside the indicator warm-up windows are NaN and are dropped here.
            df = df.dropna().astype("float32")
            if df.empty:
                # Nothing survived the warm-up; stop with a readable message
                # instead of failing further down.
                st.error("Not enough history to compute the indicators. Try a longer lookback.")
                st.stop()
            if df.shape[0] < 200:
                st.warning("Very short history after indicators; results may be noisy.")

            # 3) Build TimeSeriesDataFrame (target + indicator columns)
            ts = df[["Close", "rsi14", "ema20", "stoch_k", "stoch_d"]].copy()
            ts["item_id"] = item
            ts["timestamp"] = ts.index
            ts = ts.rename(columns={"Close": "target"})

            tsdf = TimeSeriesDataFrame.from_data_frame(
                ts, id_column="item_id", timestamp_column="timestamp"
            )
            # Ensure a regular time grid for AutoGluon
            try:
                tsdf = tsdf.convert_frequency(freq=freq_alias)
            except Exception as exc:
                # Still best effort, but no longer silent: carry on with the
                # original timestamps and tell the user why.
                st.warning(
                    f"Could not resample to a regular '{freq_alias}' grid ({exc}); "
                    "continuing with the original timestamps."
                )

        with st.spinner("Fine-tuning Chronos-Bolt (small demo)…"):
            # Chronos-Bolt preset via hyperparameters; fine_tune on CPU is OK for small steps
            predictor = TimeSeriesPredictor(
                prediction_length=int(pred_len),              # reuse the main UI's pred_len
                eval_metric="WQL",
                quantile_levels=[0.1, 0.5, 0.9],
                freq=freq_alias,
            ).fit(
                train_data=tsdf,
                enable_ensemble=False,
                time_limit=300,  # small demo budget; increase offline/GPU
                hyperparameters={
                    "Chronos": {
                        "model_path": "bolt_mini",       # CPU-friendly; try 'bolt_small' on GPU
                        "fine_tune": True,
                        "fine_tune_steps": int(ft_steps),
                        "fine_tune_lr": 1e-5,
                    }
                },
            )

        # 4) Forecast with the fine-tuned model
        preds = predictor.predict(tsdf)  # AG starts at series end
        yhist = tsdf.loc[item]["target"].to_numpy()
        ypred = preds.loc[item]          # MultiIndex -> rows for horizon
        lo = ypred["0.1"].to_numpy()
        med = ypred["0.5"].to_numpy()
        hi = ypred["0.9"].to_numpy()

        import matplotlib.pyplot as plt

        # The forecast continues the integer x-axis right after the history.
        hx = np.arange(len(yhist))
        fx = np.arange(len(yhist), len(yhist) + len(med))

        # Explicit Axes instead of pyplot's global state; the figure is
        # closed after rendering so reruns do not accumulate open figures.
        fig, ax = plt.subplots(figsize=(9, 4.5))
        ax.plot(hx, yhist, label="history")
        ax.plot(fx, med, label="median (fine-tuned)")
        ax.fill_between(fx, lo, hi, alpha=0.3, label="q10–q90")
        ax.legend()
        ax.grid(True, alpha=0.3)
        st.pyplot(fig)
        plt.close(fig)

        out = pd.DataFrame({"t": fx, "q10": lo, "q50": med, "q90": hi})
        st.download_button(
            "Download fine-tuned forecast CSV",
            out.to_csv(index=False).encode("utf-8"),
            file_name=f"{item}_chronos_finetuned.csv",
            mime="text/csv",
        )