File size: 4,985 Bytes
852fd6f
94db2b6
 
 
113176c
852fd6f
94db2b6
 
 
 
 
 
 
 
852fd6f
 
 
94db2b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
852fd6f
94db2b6
 
 
 
 
 
 
 
 
 
 
 
113176c
94db2b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113176c
94db2b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113176c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# space/tools/ts_forecast_tool.py
import os
from typing import Optional, Dict

import torch
import pandas as pd

from utils.tracing import Tracer
from utils.config import AppConfig

# We avoid unavailable task-specific heads.
# Use a generic AutoModel and attempt capability-based calls.
from transformers import AutoModel, AutoConfig


class TimeseriesForecastTool:
    """
    Lightweight wrapper around 'ibm-granite/granite-timeseries-ttm-r1' for zero-shot forecasting.

    This wrapper:
      - loads the model with `AutoModel.from_pretrained`
      - checks for a `.predict(...)` method first
      - else tries calling the model with `prediction_length=horizon`
      - returns a Pandas DataFrame with a single 'forecast' column

    Expected input:
      - series: pd.Series with a DatetimeIndex (regular frequency recommended)
      - horizon: int, number of future steps

    NOTE:
      Different library versions expose different APIs. If your environment/model
      lacks a compatible inference method, we raise a clear RuntimeError with
      guidance rather than failing at import time. Model weights are loaded in
      __init__, not at import.
    """

    # Attribute names probed, in this order, on the object returned by forward().
    # Attributes that exist but are None are skipped.
    _OUTPUT_ATTRS = ("predictions", "prediction", "logits", "output", "prediction_outputs")

    def __init__(
        self,
        cfg: Optional[AppConfig],
        tracer: Optional[Tracer],
        model_id: str = "ibm-granite/granite-timeseries-ttm-r1",
        device: Optional[str] = None,
    ):
        """Load the model config and weights and move the model to ``device``.

        Args:
            cfg: Application config; stored on the instance (not read here).
            tracer: Tracer; stored on the instance (not read here).
            model_id: Hugging Face model identifier or local path.
            device: Torch device string. Defaults to "cuda" when available,
                otherwise "cpu".
        """
        self.cfg = cfg
        self.tracer = tracer
        self.model_id = model_id

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Load the config once and hand it to the model so it is not resolved twice.
        self.config = AutoConfig.from_pretrained(self.model_id)
        self.model = AutoModel.from_pretrained(self.model_id, config=self.config)
        self.model.to(self.device)
        # Inference only: put layers such as dropout into evaluation mode.
        self.model.eval()

    @staticmethod
    def _to_forecast_frame(raw, horizon: int) -> pd.DataFrame:
        """Convert a raw model output into a one-column 'forecast' DataFrame.

        Size-1 axes are squeezed away, then:
          - a 0-d result (e.g. ``horizon == 1``) becomes a length-1 vector;
          - a 2-d result with an axis of length ``horizon`` treats that axis
            as time and keeps the last slice along the other axis;
          - any other multi-dimensional result is flattened.
        """
        tensor = raw if isinstance(raw, torch.Tensor) else torch.as_tensor(raw)
        out = tensor.detach().squeeze().cpu().numpy()
        if out.ndim == 0:
            # squeeze() collapses a single-step forecast to a scalar; pandas needs 1-d.
            out = out.reshape(1)
        elif out.ndim == 2 and out.shape[0] == horizon:
            # (horizon, k): time runs along axis 0 -> keep the last column.
            out = out[:, -1]
        elif out.ndim == 2 and out.shape[1] == horizon:
            # (k, horizon): time runs along axis 1 -> keep the last row.
            out = out[-1]
        elif out.ndim > 1:
            out = out.reshape(-1)
        return pd.DataFrame({"forecast": out})

    def zeroshot_forecast(self, series: pd.Series, horizon: int = 96) -> pd.DataFrame:
        """Forecast ``horizon`` future steps of ``series`` without fine-tuning.

        Args:
            series: Observed history; values must be castable to float32.
            horizon: Number of future steps to forecast; must be >= 1.

        Returns:
            DataFrame with a single 'forecast' column (default RangeIndex).
            An empty ``series`` yields an empty DataFrame.

        Raises:
            ValueError: ``series`` is not a pandas Series, or ``horizon`` < 1.
            RuntimeError: the model's inference entry point failed, or no
                usable entry point / output field was found.
        """
        if not isinstance(series, pd.Series):
            raise ValueError("series must be a pandas Series")
        if series.empty:
            return pd.DataFrame(columns=["forecast"])
        if horizon < 1:
            raise ValueError(f"horizon must be a positive integer, got {horizon!r}")

        # Ensure numeric tensor; unsqueeze(0) gives shape (1, len(series)).
        values = series.astype("float32").to_numpy()
        x = torch.tensor(values, dtype=torch.float32, device=self.device).unsqueeze(0)

        # Kept so the final "no usable interface" error can chain the real cause.
        forward_type_error: Optional[TypeError] = None

        with torch.no_grad():
            # 1) Preferred: explicit .predict API
            if hasattr(self.model, "predict"):
                try:
                    preds = self.model.predict(x, prediction_length=horizon)
                    return self._to_forecast_frame(preds, horizon)
                except Exception as e:
                    raise RuntimeError(
                        f"Granite model has a 'predict' method but it failed at runtime: {e}"
                    ) from e

            # 2) Fallback: call forward with a 'prediction_length' kwarg if supported
            try:
                outputs = self.model(x, prediction_length=horizon)
                # Try common attribute names on structured outputs.
                for name in self._OUTPUT_ATTRS:
                    raw = getattr(outputs, name, None)
                    if raw is None:
                        continue
                    if isinstance(raw, (tuple, list)):
                        raw = raw[0]
                    return self._to_forecast_frame(raw, horizon)
                # If outputs is a raw tensor
                if isinstance(outputs, torch.Tensor):
                    return self._to_forecast_frame(outputs, horizon)
            except TypeError as e:
                # Some builds may not accept prediction_length at all.
                forward_type_error = e
            except Exception as e:
                raise RuntimeError(
                    f"Calling the model forward for forecasting failed: {e}"
                ) from e

        # If we get here, the installed combo doesn't expose an inference entrypoint
        # we can use (also reached when forward() returned no recognised field).
        raise RuntimeError(
            "The installed transformers/model combo does not expose a usable zero-shot "
            "forecasting interface (no `.predict` and forward(...) didn't accept "
            "`prediction_length`). Consider:\n"
            "  • Upgrading transformers/torch versions\n"
            "  • Using the 'granite-tsfm-public' PyPI if available in your region\n"
            "  • Switching to a classic forecaster for now (e.g., ARIMA/XGBoost)\n"
        ) from forward_type_error