omniverse1 commited on
Commit
48b8d70
·
verified ·
1 Parent(s): 81179e3

update models 1.5-2

Browse files
Files changed (1) hide show
  1. models.py +3 -4
models.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
  import numpy as np
3
  import pandas as pd
4
- from chronos_forecasting.pipeline import BaseChronosPipeline # Corrected import path
5
  import warnings
6
  warnings.filterwarnings('ignore')
7
 
@@ -39,15 +39,14 @@ def predict_stock_prices(model_pipeline, data, forecast_horizon):
39
  # Konversi data numpy mentah ke tensor float32 (standar PyTorch)
40
  context_tensor = torch.tensor(data, dtype=torch.float32)
41
 
42
- # Chronos memprediksi sebagai sampel kuantil. Kita ambil mediannya (point forecast).
43
  raw_forecasts = model_pipeline.predict(
44
  context=[context_tensor],
45
  prediction_length=forecast_horizon,
46
- num_samples=20 # Mengambil 20 sampel untuk mendapatkan perkiraan median
47
  )
48
 
49
  # Ambil median dari semua sampel (axis=0) untuk mendapatkan point forecast
50
- # raw_forecasts[0] is of shape (num_samples, prediction_length)
51
  point_forecast = np.median(raw_forecasts[0].cpu().numpy(), axis=0)
52
 
53
  return point_forecast
 
1
  import torch
2
  import numpy as np
3
  import pandas as pd
4
+ from chronos import BaseChronosPipeline # <--- PERBAIKAN: Impor langsung dari 'chronos'
5
  import warnings
6
  warnings.filterwarnings('ignore')
7
 
 
39
  # Konversi data numpy mentah ke tensor float32 (standar PyTorch)
40
  context_tensor = torch.tensor(data, dtype=torch.float32)
41
 
42
+ # Chronos Pipeline expects context as a list of tensors
43
  raw_forecasts = model_pipeline.predict(
44
  context=[context_tensor],
45
  prediction_length=forecast_horizon,
46
+ num_samples=20
47
  )
48
 
49
  # Ambil median dari semua sampel (axis=0) untuk mendapatkan point forecast
 
50
  point_forecast = np.median(raw_forecasts[0].cpu().numpy(), axis=0)
51
 
52
  return point_forecast