Spaces:
Sleeping
Sleeping
update models 1.5-2
Browse files
models.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
import torch
|
| 2 |
import numpy as np
|
| 3 |
import pandas as pd
|
| 4 |
-
from
|
| 5 |
import warnings
|
| 6 |
warnings.filterwarnings('ignore')
|
| 7 |
|
|
@@ -39,15 +39,14 @@ def predict_stock_prices(model_pipeline, data, forecast_horizon):
|
|
| 39 |
# Konversi data numpy mentah ke tensor float32 (standar PyTorch)
|
| 40 |
context_tensor = torch.tensor(data, dtype=torch.float32)
|
| 41 |
|
| 42 |
-
# Chronos
|
| 43 |
raw_forecasts = model_pipeline.predict(
|
| 44 |
context=[context_tensor],
|
| 45 |
prediction_length=forecast_horizon,
|
| 46 |
-
num_samples=20
|
| 47 |
)
|
| 48 |
|
| 49 |
# Ambil median dari semua sampel (axis=0) untuk mendapatkan point forecast
|
| 50 |
-
# raw_forecasts[0] is of shape (num_samples, prediction_length)
|
| 51 |
point_forecast = np.median(raw_forecasts[0].cpu().numpy(), axis=0)
|
| 52 |
|
| 53 |
return point_forecast
|
|
|
|
| 1 |
import torch
|
| 2 |
import numpy as np
|
| 3 |
import pandas as pd
|
| 4 |
+
from chronos import BaseChronosPipeline # <--- PERBAIKAN: Impor langsung dari 'chronos'
|
| 5 |
import warnings
|
| 6 |
warnings.filterwarnings('ignore')
|
| 7 |
|
|
|
|
| 39 |
# Konversi data numpy mentah ke tensor float32 (standar PyTorch)
|
| 40 |
context_tensor = torch.tensor(data, dtype=torch.float32)
|
| 41 |
|
| 42 |
+
# Chronos Pipeline expects context as a list of tensors
|
| 43 |
raw_forecasts = model_pipeline.predict(
|
| 44 |
context=[context_tensor],
|
| 45 |
prediction_length=forecast_horizon,
|
| 46 |
+
num_samples=20
|
| 47 |
)
|
| 48 |
|
| 49 |
# Ambil median dari semua sampel (axis=0) untuk mendapatkan point forecast
|
|
|
|
| 50 |
point_forecast = np.median(raw_forecasts[0].cpu().numpy(), axis=0)
|
| 51 |
|
| 52 |
return point_forecast
|