# app/core/inference/client.py
from __future__ import annotations

"""
Unified chat client module.

- Exposes a production-ready MultiProvider cascade client (GROQ → Gemini → HF Router),
  via ChatClient / chat(...).
- Keeps the legacy RouterRequestsClient for direct access to the HF Router compatible
  /v1/chat/completions endpoint, preserving backward compatibility.

This file assumes:
- app/bootstrap.py exists and loads configs/.env + sets up logging.
- app/core/config.py provides Settings (with provider_order, etc.).
- app/core/inference/providers.py implements MultiProviderChat orchestrator.
"""

import os
import json
import time
import logging
from typing import Dict, List, Optional, Iterator, Tuple, Iterable, Union, Generator

# Ensure .env & logging before we load settings/providers
import app.bootstrap  # noqa: F401

import requests

from app.core.config import Settings
from app.core.inference.providers import MultiProviderChat

logger = logging.getLogger(__name__)

# -----------------------------
# Multi-provider cascade client
# -----------------------------

Message = Dict[str, str]


class ChatClient:
    """
    Unified chat client that executes the configured provider cascade.
    Providers are tried in order (settings.provider_order). First success wins.
    """

    def __init__(self, settings: Settings | None = None):
        self._settings = settings or Settings.load()
        self._chain = MultiProviderChat(self._settings)

    def chat(
        self,
        messages: Iterable[Message],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        stream: Optional[bool] = None,
    ) -> Union[str, Generator[str, None, None]]:
        """
        Execute a chat completion using the provider cascade.

        Args:
            messages: Iterable of {"role": "system|user|assistant", "content": "..."}
            temperature: Optional override for sampling temperature.
            max_new_tokens: Optional override for max tokens.
            stream: If None, uses settings.chat_stream. If True, returns a generator
                of text chunks.

        Returns:
            str (non-stream) or generator[str] (stream)
        """
        use_stream = self._settings.chat_stream if stream is None else bool(stream)
        return self._chain.chat(
            messages,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            stream=use_stream,
        )


# Backward-compatible helpers
_default_client: ChatClient | None = None


def _get_default() -> ChatClient:
    global _default_client
    if _default_client is None:
        _default_client = ChatClient()
    return _default_client


def chat(
    messages: Iterable[Message],
    temperature: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    stream: Optional[bool] = None,
) -> Union[str, Generator[str, None, None]]:
    """
    Convenience function using a process-wide default ChatClient.
    """
    return _get_default().chat(
        messages, temperature=temperature, max_new_tokens=max_new_tokens, stream=stream
    )


def get_client(settings: Settings | None = None) -> ChatClient:
    """
    Factory for an explicit ChatClient bound to provided settings.
    """
    return ChatClient(settings)
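# Usage sketch (illustrative only; never called by this module). It assumes the
# providers listed in settings.provider_order have their credentials available via
# configs/.env, as loaded by app.bootstrap.
def _example_cascade_usage() -> None:
    """Minimal sketch of how callers are expected to use the cascade client."""
    # Non-streaming call: returns the full completion as a single string.
    reply = chat(
        [
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "Say hello in one sentence."},
        ],
        temperature=0.2,
        max_new_tokens=64,
        stream=False,
    )
    print(reply)

    # Streaming call: returns a generator of text chunks.
    for chunk in chat([{"role": "user", "content": "Stream a short greeting."}], stream=True):
        print(chunk, end="", flush=True)
    print()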
# ------------------------------------------------------
# Legacy HF Router client (kept for backward compatibility)
# ------------------------------------------------------

ROUTER_URL = "https://router.huggingface.co/v1/chat/completions"


def _require_token() -> str:
    tok = os.getenv("HF_TOKEN")
    if not tok:
        raise ValueError("HF_TOKEN is not set. Put it in .env or export it before starting.")
    return tok


def _model_with_provider(model: str, provider: Optional[str]) -> str:
    if provider and ":" not in model:
        return f"{model}:{provider}"
    return model


def _mk_messages(system_prompt: Optional[str], user_text: str) -> List[Dict[str, str]]:
    msgs: List[Dict[str, str]] = []
    if system_prompt:
        msgs.append({"role": "system", "content": system_prompt})
    msgs.append({"role": "user", "content": user_text})
    return msgs


def _timeout_tuple(connect: float = 10.0, read: float = 60.0) -> Tuple[float, float]:
    return (connect, read)
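# Illustrative sketch (never called by this module): how the helpers above compose a
# minimal non-streaming Chat Completions payload. The model and provider names are
# hypothetical placeholders, not values required by this module.
def _example_router_payload() -> Dict[str, object]:
    """Minimal payload sketch built from _model_with_provider / _mk_messages."""
    return {
        # -> "some-org/some-model:some-provider" (suffix added only when missing)
        "model": _model_with_provider("some-org/some-model", "some-provider"),
        # -> [{"role": "system", ...}, {"role": "user", ...}]
        "messages": _mk_messages("You are terse.", "ping"),
        "temperature": 0.2,
        "max_tokens": 32,
        "stream": False,
    }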
class RouterRequestsClient:
    """
    Simple requests-only client for HF Router Chat Completions.
    Supports non-streaming (returns str) and streaming (yields token strings).

    NOTE: New code should prefer ChatClient above. This class is preserved for any
    legacy call sites that rely on direct HF Router access.
    """

    def __init__(
        self,
        model: str,
        fallback: Optional[str] = None,
        provider: Optional[str] = None,
        max_retries: int = 2,
        connect_timeout: float = 10.0,
        read_timeout: float = 60.0,
    ):
        self.model = model
        self.fallback = fallback if fallback != model else None
        self.provider = provider
        self.headers = {"Authorization": f"Bearer {_require_token()}"}
        self.max_retries = max(0, int(max_retries))
        self.timeout = _timeout_tuple(connect_timeout, read_timeout)

    # -------- Non-stream (single text) --------
    def chat_nonstream(
        self,
        system_prompt: Optional[str],
        user_text: str,
        max_tokens: int,
        temperature: float,
        stop: Optional[List[str]] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
    ) -> str:
        payload = {
            "model": _model_with_provider(self.model, self.provider),
            "messages": _mk_messages(system_prompt, user_text),
            "temperature": float(max(0.0, temperature)),
            "max_tokens": int(max_tokens),
            "stream": False,
        }
        if stop:
            payload["stop"] = stop
        if frequency_penalty is not None:
            payload["frequency_penalty"] = float(frequency_penalty)
        if presence_penalty is not None:
            payload["presence_penalty"] = float(presence_penalty)

        text, ok = self._try_once(payload)
        if ok:
            return text

        # fallback (if configured)
        if self.fallback:
            payload["model"] = _model_with_provider(self.fallback, self.provider)
            text, ok = self._try_once(payload)
            if ok:
                return text

        raise RuntimeError(f"Chat non-stream failed: model={self.model} fallback={self.fallback}")

    def _try_once(self, payload: dict) -> Tuple[str, bool]:
        last_err: Optional[Exception] = None
        for attempt in range(self.max_retries + 1):
            try:
                r = requests.post(ROUTER_URL, headers=self.headers, json=payload, timeout=self.timeout)
                if r.status_code >= 400:
                    logger.error("Router error %s: %s", r.status_code, r.text)
                    last_err = RuntimeError(f"{r.status_code}: {r.text}")
                    time.sleep(min(1.5 * (attempt + 1), 3.0))
                    continue
                data = r.json()
                return data["choices"][0]["message"]["content"], True
            except Exception as e:
                logger.error("Router request failure: %s", e)
                last_err = e
                time.sleep(min(1.5 * (attempt + 1), 3.0))
        if last_err:
            logger.error("Router exhausted retries: %s", last_err)
        return "", False

    # -------- Streaming (yield token deltas) --------
    def chat_stream(
        self,
        system_prompt: Optional[str],
        user_text: str,
        max_tokens: int,
        temperature: float,
        stop: Optional[List[str]] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
    ) -> Iterator[str]:
        payload = {
            "model": _model_with_provider(self.model, self.provider),
            "messages": _mk_messages(system_prompt, user_text),
            "temperature": float(max(0.0, temperature)),
            "max_tokens": int(max_tokens),
            "stream": True,
        }
        if stop:
            payload["stop"] = stop
        if frequency_penalty is not None:
            payload["frequency_penalty"] = float(frequency_penalty)
        if presence_penalty is not None:
            payload["presence_penalty"] = float(presence_penalty)

        # primary
        ok = False
        for token in self._stream_once(payload):
            ok = True
            yield token
        if ok:
            return

        # fallback stream if primary produced nothing (or died immediately)
        if self.fallback:
            payload["model"] = _model_with_provider(self.fallback, self.provider)
            for token in self._stream_once(payload):
                yield token

    def _stream_once(self, payload: dict) -> Iterator[str]:
        try:
            with requests.post(
                ROUTER_URL, headers=self.headers, json=payload, stream=True, timeout=self.timeout
            ) as r:
                if r.status_code >= 400:
                    logger.error("Router stream error %s: %s", r.status_code, r.text)
                    return
                for line in r.iter_lines(decode_unicode=True):
                    if not line:
                        continue
                    if not line.startswith("data:"):
                        continue
                    data = line[len("data:"):].strip()
                    if data == "[DONE]":
                        return
                    try:
                        obj = json.loads(data)
                        delta = obj["choices"][0]["delta"].get("content", "")
                        if delta:
                            yield delta
                    except Exception as e:
                        logger.warning("Stream JSON parse issue: %s | line=%r", e, line)
                        continue
        except Exception as e:
            logger.error("Stream request failure: %s", e)
            return

    # -------- Planning (non-stream) --------
    def plan_nonstream(self, system_prompt: str, user_text: str, max_tokens: int, temperature: float) -> str:
        return self.chat_nonstream(system_prompt, user_text, max_tokens, temperature)


__all__ = [
    "ChatClient",
    "chat",
    "get_client",
    "RouterRequestsClient",
]
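# Minimal manual smoke test (illustrative sketch). It exercises the legacy Router
# client directly, requires HF_TOKEN in the environment, and uses a hypothetical
# model id; new code should go through ChatClient / chat(...) instead.
if __name__ == "__main__":
    _demo = RouterRequestsClient(model="some-org/some-model")
    # Non-streaming: a single string comes back (or RuntimeError after retries/fallback).
    print(_demo.chat_nonstream(
        system_prompt="You are terse.",
        user_text="Reply with a single word.",
        max_tokens=16,
        temperature=0.0,
    ))
    # Streaming: token deltas are printed as they arrive.
    for _tok in _demo.chat_stream(
        system_prompt=None,
        user_text="Stream a short greeting.",
        max_tokens=32,
        temperature=0.2,
    ):
        print(_tok, end="", flush=True)
    print()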