|
|
|
|
|
import os |
|
|
from typing import Optional, Dict |
|
|
|
|
|
import torch |
|
|
import pandas as pd |
|
|
|
|
|
from utils.tracing import Tracer |
|
|
from utils.config import AppConfig |
|
|
|
|
|
|
|
|
|
|
|
from transformers import AutoModel, AutoConfig |
|
|
|
|
|
|
|
|
class TimeseriesForecastTool:
    """
    Lightweight wrapper around 'ibm-granite/granite-timeseries-ttm-r1' for zero-shot forecasting.

    This wrapper:
    - loads the model with `AutoModel.from_pretrained`
    - checks for a `.predict(...)` method first
    - else tries calling the model with `prediction_length=horizon`
    - returns a Pandas DataFrame with a single 'forecast' column

    Expected input:
    - series: pd.Series with a DatetimeIndex (regular frequency recommended)
    - horizon: int, number of future steps

    NOTE:
    Different library versions expose different APIs. If your environment/model
    lacks a compatible inference method, a RuntimeError with guidance is raised
    when forecasting is attempted, rather than failing at import time.
    """

    # Attribute (or mapping key) names probed, in order, on the object returned
    # by forward(). "prediction_outputs" is presumably the field name used by
    # granite-tsfm's TTM prediction output -- verify against the installed version.
    _OUTPUT_KEYS = ("predictions", "prediction", "logits", "output", "prediction_outputs")

    # Raised when forward() rejects the `prediction_length` keyword argument.
    _NO_INTERFACE_MSG = (
        "The installed transformers/model combo does not expose a usable zero-shot "
        "forecasting interface (no `.predict` and forward(...) didn't accept "
        "`prediction_length`). Consider:\n"
        " • Upgrading transformers/torch versions\n"
        " • Using the 'granite-tsfm-public' PyPI if available in your region\n"
        " • Switching to a classic forecaster for now (e.g., ARIMA/XGBoost)\n"
    )

    def __init__(
        self,
        cfg: "Optional[AppConfig]",
        tracer: "Optional[Tracer]",
        model_id: str = "ibm-granite/granite-timeseries-ttm-r1",
        device: Optional[str] = None,
    ):
        """Load the model and its config, move the model to `device`, and set eval mode.

        Args:
            cfg: Application config; stored as `self.cfg`, not read by this class.
            tracer: Tracer; stored as `self.tracer`, not read by this class.
            model_id: Model id or path passed to `from_pretrained`.
            device: Torch device string; defaults to "cuda" if available, else "cpu".
        """
        self.cfg = cfg
        self.tracer = tracer
        self.model_id = model_id
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Loaded and stored, but not read anywhere in this class.
        self.config = AutoConfig.from_pretrained(self.model_id)
        self.model = AutoModel.from_pretrained(self.model_id)
        self.model.to(self.device)
        # Inference only: switch off training-mode behaviour (dropout etc.).
        self.model.eval()

    def zeroshot_forecast(self, series: pd.Series, horizon: int = 96) -> pd.DataFrame:
        """Forecast `horizon` future steps of `series` without fine-tuning.

        Tries `model.predict(x, prediction_length=horizon)` first, then
        `model(x, prediction_length=horizon)`. The history is passed as a
        float32 tensor of shape (1, len(series)) on `self.device`.

        Args:
            series: Univariate history; only its values are used (the index is ignored).
            horizon: Number of future steps to forecast; must be positive.

        Returns:
            DataFrame with a single 'forecast' column and a default RangeIndex
            (future timestamps are not generated). Empty if `series` is empty.

        Raises:
            ValueError: If `series` is not a pd.Series or `horizon` is not positive.
            RuntimeError: If inference fails, the model output is not recognized,
                or no usable inference interface exists.
        """
        if not isinstance(series, pd.Series):
            raise ValueError("series must be a pandas Series")
        if horizon <= 0:
            raise ValueError("horizon must be a positive integer")
        if series.empty:
            return pd.DataFrame(columns=["forecast"])

        values = series.astype("float32").to_numpy()
        # Prepend a batch dimension: (len(series),) -> (1, len(series)).
        x = torch.tensor(values, dtype=torch.float32, device=self.device).unsqueeze(0)

        with torch.no_grad():
            if hasattr(self.model, "predict"):
                try:
                    preds = self.model.predict(x, prediction_length=horizon)
                    return self._to_forecast_frame(preds, horizon)
                except Exception as e:
                    raise RuntimeError(
                        f"Granite model has a 'predict' method but it failed at runtime: {e}"
                    ) from e

            # Keep the try narrow: only a TypeError from the call itself is
            # interpreted as "forward() does not accept prediction_length".
            try:
                outputs = self.model(x, prediction_length=horizon)
            except TypeError as e:
                raise RuntimeError(self._NO_INTERFACE_MSG) from e
            except Exception as e:
                raise RuntimeError(
                    f"Calling the model forward for forecasting failed: {e}"
                ) from e

            raw = self._select_forward_tensor(outputs)
            if raw is None:
                raise RuntimeError(
                    "Model forward returned an unrecognized output of type "
                    f"{type(outputs).__name__!r}; expected a torch.Tensor or an "
                    f"object exposing one of {self._OUTPUT_KEYS}."
                )
            try:
                return self._to_forecast_frame(raw, horizon)
            except Exception as e:
                raise RuntimeError(
                    f"Calling the model forward for forecasting failed: {e}"
                ) from e

    @classmethod
    def _select_forward_tensor(cls, outputs) -> Optional[object]:
        """Return the first non-None prediction payload in `outputs`, or None if none is found."""
        for key in cls._OUTPUT_KEYS:
            # Mappings (including dict-based model outputs) are read by key,
            # everything else by attribute.
            if isinstance(outputs, dict):
                value = outputs.get(key)
            else:
                value = getattr(outputs, key, None)
            if isinstance(value, (tuple, list)):
                value = value[0] if value else None
            # Model-output objects can expose a field that is None; skip it
            # rather than trying to convert None into a tensor.
            if value is not None:
                return value
        if isinstance(outputs, torch.Tensor):
            return outputs
        return None

    @staticmethod
    def _to_forecast_frame(raw, horizon: int) -> pd.DataFrame:
        """Convert a raw prediction (tensor or array-like) into a 1-D 'forecast' DataFrame."""
        tensor = raw if isinstance(raw, torch.Tensor) else torch.tensor(raw)
        out = tensor.squeeze().detach().cpu().numpy()
        if out.ndim == 0:
            # squeeze() turns an all-singleton result (e.g. (1, 1, 1) for
            # horizon=1) into a 0-d array, which pd.DataFrame rejects without an index.
            out = out.reshape(1)
        elif out.ndim > 1:
            # NOTE(review): when the first axis has length `horizon` this keeps
            # only the LAST ROW (one step across the remaining axis); otherwise
            # it flattens everything. Confirm this matches the model's output layout.
            out = out[-1] if out.shape[0] == horizon else out.reshape(-1)
        return pd.DataFrame({"forecast": out})
|
|
|