# ALM_LLM / utils/config.py
# Provenance: uploaded by AshenH ("Update utils/config.py", commit 3d0e99a, verified; 7.57 kB)
# space/utils/config.py
import os
import logging
from typing import Optional
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
class ConfigError(Exception):
    """Raised when application configuration is invalid or fails to load."""
@dataclass
class AppConfig:
    """
    Application configuration loaded from environment variables.

    Values are validated in ``__post_init__``; invalid settings raise
    :class:`ConfigError`, while merely-suspicious settings (e.g. a missing
    backend credential) only emit a ``logger.warning`` so the app can still
    start in degraded mode.
    """

    # --- SQL backend configuration ---
    sql_backend: str = "motherduck"          # "bigquery" or "motherduck"
    gcp_project: Optional[str] = None        # needed when sql_backend == "bigquery"
    motherduck_token: Optional[str] = None   # needed when sql_backend == "motherduck"
    motherduck_db: str = "workspace"

    # --- Model configuration ---
    hf_model_repo: str = "your-org/your-model"  # placeholder; must be overridden for predict/explain
    hf_token: Optional[str] = None              # only needed for private HF repos

    # --- Tracing configuration ---
    trace_enabled: bool = True
    trace_url: Optional[str] = None

    # --- Feature flags ---
    enable_forecasting: bool = True
    enable_explanations: bool = True

    # --- Performance settings ---
    max_workers: int = 4
    timeout_seconds: int = 300

    # --- Additional settings ---
    log_level: str = "INFO"

    # Sentinel repo value that means "not configured" (matches the field default).
    _PLACEHOLDER_REPO = "your-org/your-model"

    def __post_init__(self) -> None:
        """Validate configuration immediately after dataclass initialization."""
        self._validate()

    def _validate(self) -> None:
        """
        Validate configuration values.

        Raises:
            ConfigError: for an unknown backend, non-positive numeric
                settings, or an unknown log level.
        """
        # Validate SQL backend choice.
        valid_backends = ["bigquery", "motherduck"]
        if self.sql_backend not in valid_backends:
            raise ConfigError(
                f"Invalid sql_backend: {self.sql_backend}. "
                f"Must be one of: {valid_backends}"
            )

        # Backend-specific requirements: warn only, so the app can boot
        # and surface the problem later via validate_for_features().
        if self.sql_backend == "bigquery" and not self.gcp_project:
            logger.warning("BigQuery selected but gcp_project not set")
        if self.sql_backend == "motherduck" and not self.motherduck_token:
            logger.warning("MotherDuck selected but motherduck_token not set")

        # Model configuration (warn only).
        if not self.hf_model_repo:
            logger.warning("hf_model_repo not set - predictions/explanations will fail")

        # Numeric settings must be positive.
        if self.max_workers < 1:
            raise ConfigError(f"max_workers must be >= 1, got {self.max_workers}")
        if self.timeout_seconds < 1:
            raise ConfigError(f"timeout_seconds must be >= 1, got {self.timeout_seconds}")

        # Log level must be a standard logging level name (case-insensitive).
        valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        if self.log_level.upper() not in valid_levels:
            raise ConfigError(
                f"Invalid log_level: {self.log_level}. "
                f"Must be one of: {valid_levels}"
            )

    @classmethod
    def from_env(cls) -> "AppConfig":
        """
        Create configuration from environment variables.

        Environment variables:
            SQL_BACKEND: "bigquery" or "motherduck" (default: "motherduck")
            GCP_PROJECT: GCP project ID for BigQuery
            GCP_SERVICE_ACCOUNT_JSON: Service account credentials for BigQuery
            MOTHERDUCK_TOKEN: MotherDuck authentication token
            MOTHERDUCK_DB: MotherDuck database name (default: "workspace")
            HF_MODEL_REPO: HuggingFace model repository (required)
            HF_TOKEN: HuggingFace API token (optional, for private repos)
            TRACE_ENABLED: Enable tracing (default: "true")
            TRACE_URL: Custom trace URL
            ENABLE_FORECASTING: Enable forecasting features (default: "true")
            ENABLE_EXPLANATIONS: Enable SHAP explanations (default: "true")
            MAX_WORKERS: Max parallel workers (default: 4)
            TIMEOUT_SECONDS: Request timeout (default: 300)
            LOG_LEVEL: Logging level (default: "INFO")

        Returns:
            A validated AppConfig instance.

        Raises:
            ConfigError: if a numeric variable is non-numeric or any
                value fails validation.
        """
        try:
            config = cls(
                sql_backend=os.getenv("SQL_BACKEND", "motherduck").lower(),
                gcp_project=os.getenv("GCP_PROJECT"),
                motherduck_token=os.getenv("MOTHERDUCK_TOKEN"),
                motherduck_db=os.getenv("MOTHERDUCK_DB", "workspace"),
                hf_model_repo=os.getenv("HF_MODEL_REPO", "your-org/your-model"),
                hf_token=os.getenv("HF_TOKEN"),
                trace_enabled=os.getenv("TRACE_ENABLED", "true").lower() == "true",
                trace_url=os.getenv("TRACE_URL"),
                enable_forecasting=os.getenv("ENABLE_FORECASTING", "true").lower() == "true",
                enable_explanations=os.getenv("ENABLE_EXPLANATIONS", "true").lower() == "true",
                max_workers=int(os.getenv("MAX_WORKERS", "4")),
                timeout_seconds=int(os.getenv("TIMEOUT_SECONDS", "300")),
                log_level=os.getenv("LOG_LEVEL", "INFO").upper(),
            )
            logger.info("Configuration loaded successfully")
            logger.info("SQL Backend: %s", config.sql_backend)
            logger.info("Model Repo: %s", config.hf_model_repo)
            logger.info("Forecasting: %s", "enabled" if config.enable_forecasting else "disabled")
            logger.info("Explanations: %s", "enabled" if config.enable_explanations else "disabled")
            return config
        except ValueError as e:
            # int() on a malformed MAX_WORKERS / TIMEOUT_SECONDS lands here.
            raise ConfigError(f"Invalid numeric configuration value: {e}") from e
        except Exception as e:
            raise ConfigError(f"Configuration loading failed: {e}") from e

    def to_dict(self) -> dict:
        """Convert configuration to a dictionary (for logging/debugging).

        Secrets are redacted: tokens are reported only as set/unset booleans.
        """
        return {
            "sql_backend": self.sql_backend,
            "gcp_project": self.gcp_project or "not set",
            "motherduck_db": self.motherduck_db,
            "hf_model_repo": self.hf_model_repo,
            "hf_token_set": bool(self.hf_token),
            "trace_enabled": self.trace_enabled,
            "enable_forecasting": self.enable_forecasting,
            "enable_explanations": self.enable_explanations,
            "max_workers": self.max_workers,
            "timeout_seconds": self.timeout_seconds,
            "log_level": self.log_level,
        }

    def validate_for_features(self, features: list) -> tuple[bool, list]:
        """
        Validate that the configuration supports the requested features.

        Args:
            features: List of feature names to check
                ("predict", "explain", "forecast", "sql").

        Returns:
            Tuple of (all_valid, list_of_errors).
        """
        errors: list = []
        # Hoisted loop-invariant: is the model repo still the unset placeholder?
        repo_missing = not self.hf_model_repo or self.hf_model_repo == self._PLACEHOLDER_REPO
        for feature in features:
            if feature in ("predict", "explain"):
                if repo_missing:
                    errors.append(f"{feature} requires valid HF_MODEL_REPO")
                # BUG FIX: previously this check lived in an `elif feature == "explain"`
                # branch that was unreachable ("explain" was consumed by the first
                # condition), so a disabled-explanations config passed validation.
                if feature == "explain" and not self.enable_explanations:
                    errors.append("explanations are disabled (ENABLE_EXPLANATIONS=false)")
            elif feature == "forecast":
                if not self.enable_forecasting:
                    errors.append("forecasting is disabled (ENABLE_FORECASTING=false)")
            elif feature == "sql":
                if self.sql_backend == "bigquery" and not self.gcp_project:
                    errors.append("BigQuery requires GCP_PROJECT")
                elif self.sql_backend == "motherduck" and not self.motherduck_token:
                    errors.append("MotherDuck requires MOTHERDUCK_TOKEN")
        return len(errors) == 0, errors