# space/tools/ts_forecast_tool.py
from typing import Optional

import pandas as pd
import torch

from utils.tracing import Tracer
from utils.config import AppConfig

# Task-specific heads may be unavailable for this model, so we load a generic
# AutoModel and attempt capability-based calls at inference time.
from transformers import AutoConfig, AutoModel


class TimeseriesForecastTool:
    """
    Lightweight wrapper around 'ibm-granite/granite-timeseries-ttm-r1'
    for zero-shot forecasting.

    This wrapper:
      - loads the model with `AutoModel.from_pretrained`
      - checks for a `.predict(...)` method first
      - else tries calling the model with `prediction_length=horizon`
      - returns a pandas DataFrame with a single 'forecast' column

    Expected input:
      - series: pd.Series with a DatetimeIndex (regular frequency recommended)
      - horizon: int, number of future steps

    NOTE: Different library versions expose different APIs. If your
    environment/model lacks a compatible inference method, we raise a clear
    RuntimeError with guidance rather than failing at import time.
    """

    def __init__(
        self,
        cfg: Optional[AppConfig],
        tracer: Optional[Tracer],
        model_id: str = "ibm-granite/granite-timeseries-ttm-r1",
        device: Optional[str] = None,
    ):
        self.cfg = cfg
        self.tracer = tracer
        self.model_id = model_id
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Load config + model generically; task-specific classes are avoided
        # so import never fails on an incompatible transformers build.
        self.config = AutoConfig.from_pretrained(self.model_id)
        self.model = AutoModel.from_pretrained(self.model_id)
        self.model.to(self.device)
        self.model.eval()

    def zeroshot_forecast(self, series: pd.Series, horizon: int = 96) -> pd.DataFrame:
        if not isinstance(series, pd.Series):
            raise ValueError("series must be a pandas Series")
        if series.empty:
            return pd.DataFrame(columns=["forecast"])

        # Ensure a numeric tensor of shape (1, context_length). Some backbones
        # expect an extra channel dimension; adjust here if yours does.
        values = series.astype("float32").to_numpy()
        x = torch.tensor(values, dtype=torch.float32, device=self.device).unsqueeze(0)

        def _to_frame(t: torch.Tensor) -> pd.DataFrame:
            out = t.squeeze().detach().cpu().numpy()
            # If still multi-dim, reduce to a 1-D series: keep the axis whose
            # length matches the horizon, taking the first channel of the other.
            if out.ndim > 1:
                if out.shape[0] == horizon:
                    out = out[:, 0]
                elif out.shape[-1] == horizon:
                    out = out[0]
                else:
                    out = out.reshape(-1)
            return pd.DataFrame({"forecast": out})

        with torch.no_grad():
            # 1) Preferred: explicit .predict API.
            if hasattr(self.model, "predict"):
                try:
                    preds = self.model.predict(x, prediction_length=horizon)
                    yhat = preds if isinstance(preds, torch.Tensor) else torch.as_tensor(preds)
                    return _to_frame(yhat)
                except Exception as e:
                    raise RuntimeError(
                        f"Granite model has a 'predict' method but it failed at runtime: {e}"
                    ) from e

            # 2) Fallback: call forward with a 'prediction_length' kwarg if supported.
            try:
                outputs = self.model(x, prediction_length=horizon)
                # Try common output attribute names; skip fields set to None.
                for k in ("predictions", "prediction", "logits", "output"):
                    tensor = getattr(outputs, k, None)
                    if tensor is None:
                        continue
                    if isinstance(tensor, (tuple, list)):
                        tensor = tensor[0]
                    if not isinstance(tensor, torch.Tensor):
                        tensor = torch.as_tensor(tensor)
                    return _to_frame(tensor)

                # If outputs is a raw tensor.
                if isinstance(outputs, torch.Tensor):
                    return _to_frame(outputs)
            except TypeError:
                # Some builds may not accept prediction_length at all.
                pass
            except Exception as e:
                raise RuntimeError(
                    f"Calling the model forward for forecasting failed: {e}"
                ) from e

        # If we get here, the installed combo doesn't expose an inference
        # entrypoint we can use.
        raise RuntimeError(
            "The installed transformers/model combo does not expose a usable zero-shot "
            "forecasting interface (no `.predict` and forward(...) didn't accept "
            "`prediction_length`). Consider:\n"
            "  • Upgrading transformers/torch versions\n"
            "  • Using the 'granite-tsfm-public' PyPI if available in your region\n"
            "  • Switching to a classic forecaster for now (e.g., ARIMA/XGBoost)\n"
        )
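

# ---------------------------------------------------------------------------
# The two sketches below are illustrative additions, not part of the tool's
# required API. They assume: (a) `statsmodels` is installed, for the classic
# ARIMA fallback the error message above suggests; and (b) the Granite weights
# are downloadable, for the smoke test. The name `arima_fallback_forecast` and
# the (1, 1, 1) order are hypothetical placeholders, not tuned choices.
# ---------------------------------------------------------------------------

def arima_fallback_forecast(series: pd.Series, horizon: int = 96) -> pd.DataFrame:
    """Classic-forecaster stand-in for when the Granite path raises."""
    from statsmodels.tsa.arima.model import ARIMA

    fitted = ARIMA(series.astype("float64"), order=(1, 1, 1)).fit()
    preds = fitted.forecast(steps=horizon)
    return pd.DataFrame({"forecast": preds.to_numpy()})


if __name__ == "__main__":
    # Minimal smoke test on a synthetic sine wave; passing None for cfg/tracer
    # relies on both being Optional in __init__.
    import numpy as np

    idx = pd.date_range("2024-01-01", periods=512, freq="h")
    demo = pd.Series(np.sin(np.arange(512) / 24.0), index=idx, name="demo")

    tool = TimeseriesForecastTool(cfg=None, tracer=None)
    try:
        print(tool.zeroshot_forecast(demo, horizon=24).head())
    except RuntimeError as err:
        print(f"Granite path unavailable ({err}); falling back to ARIMA.")
        print(arima_fallback_forecast(demo, horizon=24).head())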