# space/tools/ts_forecast_tool.py
from typing import Optional

import pandas as pd
import torch

from utils.tracing import Tracer
from utils.config import AppConfig

# We avoid unavailable task-specific heads: load a generic AutoModel and
# attempt capability-based calls at inference time.
from transformers import AutoModel, AutoConfig


class TimeseriesForecastTool:
    """
    Lightweight wrapper around 'ibm-granite/granite-timeseries-ttm-r1' for
    zero-shot forecasting.

    This wrapper:
      - loads the model with `AutoModel.from_pretrained`
      - checks for a `.predict(...)` method first
      - else tries calling the model forward with `prediction_length=horizon`
      - returns a pandas DataFrame with a single 'forecast' column

    Expected input:
      - series: pd.Series with a DatetimeIndex (regular frequency recommended)
      - horizon: int, number of future steps

    NOTE:
        Different library versions expose different APIs. If your
        environment/model lacks a compatible inference method, we raise a
        clear RuntimeError with guidance rather than failing at import time.
        A minimal usage sketch lives under the `__main__` guard at the bottom
        of this module.
    """

    def __init__(
        self,
        cfg: Optional[AppConfig],
        tracer: Optional[Tracer],
        model_id: str = "ibm-granite/granite-timeseries-ttm-r1",
        device: Optional[str] = None,
    ):
        self.cfg = cfg
        self.tracer = tracer
        self.model_id = model_id
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Load config + model generically
        self.config = AutoConfig.from_pretrained(self.model_id)
        self.model = AutoModel.from_pretrained(self.model_id)
        self.model.to(self.device)
        self.model.eval()

    def zeroshot_forecast(self, series: pd.Series, horizon: int = 96) -> pd.DataFrame:
        if not isinstance(series, pd.Series):
            raise ValueError("series must be a pandas Series")
        if series.empty:
            return pd.DataFrame(columns=["forecast"])

        # Ensure a numeric (1, seq_len) float tensor on the target device
        values = series.astype("float32").to_numpy()
        x = torch.tensor(values, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            # 1) Preferred: explicit .predict API
            if hasattr(self.model, "predict"):
                try:
                    preds = self.model.predict(x, prediction_length=horizon)
                    yhat = preds if isinstance(preds, torch.Tensor) else torch.tensor(preds)
                    out = yhat.squeeze().detach().cpu().numpy()
                    return pd.DataFrame({"forecast": out})
                except Exception as e:
                    raise RuntimeError(
                        f"Granite model has a 'predict' method but it failed at runtime: {e}"
                    ) from e
            # 2) Fallback: call forward with a 'prediction_length' kwarg if supported
            def _to_frame(tensor: torch.Tensor) -> pd.DataFrame:
                out = tensor.squeeze().detach().cpu().numpy()
                # If multi-dim, reduce to a 1-D forecast of length `horizon`
                if out.ndim > 1:
                    if out.shape[-1] == horizon:
                        out = out[-1]          # (channels, horizon): take last channel
                    elif out.shape[0] == horizon:
                        out = out[:, -1]       # (horizon, channels): take last channel
                    else:
                        out = out.reshape(-1)  # unknown layout: flatten
                return pd.DataFrame({"forecast": out})

            try:
                outputs = self.model(x, prediction_length=horizon)
                # Try common attribute names on the output object
                for k in ("predictions", "prediction", "logits", "output"):
                    if hasattr(outputs, k):
                        tensor = getattr(outputs, k)
                        if isinstance(tensor, (tuple, list)):
                            tensor = tensor[0]
                        if not isinstance(tensor, torch.Tensor):
                            tensor = torch.tensor(tensor)
                        return _to_frame(tensor)
                # If outputs is already a raw tensor
                if isinstance(outputs, torch.Tensor):
                    return _to_frame(outputs)
            except TypeError:
                # Some builds may not accept 'prediction_length' at all
                pass
            except Exception as e:
                raise RuntimeError(
                    f"Calling the model forward for forecasting failed: {e}"
                ) from e

        # If we get here, the installed combo doesn't expose an inference
        # entrypoint we can use.
        raise RuntimeError(
            "The installed transformers/model combo does not expose a usable zero-shot "
            "forecasting interface (no `.predict`, and forward(...) didn't accept "
            "`prediction_length`). Consider:\n"
            "  • Upgrading your transformers/torch versions\n"
            "  • Installing the 'granite-tsfm-public' package from PyPI, if available\n"
            "  • Switching to a classic forecaster for now (e.g., ARIMA/XGBoost); "
            "see the hedged fallback sketch below\n"
        )
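

# ---------------------------------------------------------------------------
# Hedged fallback sketch (NOT part of the original tool). The error message
# above suggests a classic forecaster when no Granite inference entrypoint is
# available; this helper illustrates one way to do that, assuming the optional
# 'statsmodels' dependency is installed. The (1, 1, 1) ARIMA order is an
# illustrative assumption, not a tuned choice.
# ---------------------------------------------------------------------------
def classic_fallback_forecast(series: pd.Series, horizon: int = 96) -> pd.DataFrame:
    from statsmodels.tsa.arima.model import ARIMA  # local import: optional dependency

    fitted = ARIMA(series.astype("float32"), order=(1, 1, 1)).fit()
    yhat = fitted.forecast(steps=horizon)
    return pd.DataFrame({"forecast": pd.Series(yhat).to_numpy()})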
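

# Minimal usage sketch under the __main__ guard. The synthetic hourly series
# below is an illustrative assumption; any pd.Series with a DatetimeIndex and
# a regular frequency should work.
if __name__ == "__main__":
    import numpy as np

    idx = pd.date_range("2024-01-01", periods=512, freq="h")
    demo = pd.Series(np.sin(np.arange(512) / 24.0), index=idx, name="load")

    tool = TimeseriesForecastTool(cfg=None, tracer=None)
    try:
        fcst = tool.zeroshot_forecast(demo, horizon=96)
        print(fcst.head())
    except RuntimeError as err:
        # Fall back to the classic forecaster sketched above
        print(f"Granite forecasting unavailable: {err}")
        print(classic_fallback_forecast(demo, horizon=96).head())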