Maharshi Gor committed
Commit 3a1af80 · 1 Parent(s): 4f5d1cb

Adds support for caching LLM calls to a SQLite DB and an HF dataset. Refactors repo creation logic and fixes the unused temperature param.

Files changed:
- check_repos.py              +14 -15
- src/envs.py                  +3  -0
- src/workflows/executors.py   +1  -0
- src/workflows/llmcache.py  +479  -0
- src/workflows/llms.py      +123 -16
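
At a glance, the cache added here is a read-through cache: look up the request in a local SQLite file first, call the provider on a miss, store the result, and periodically push the accumulated entries to a private HF dataset. A condensed sketch of that flow, using the function names from the diff below (the snippet itself is illustrative, not part of the commit):

    # Illustrative flow, condensed from src/workflows/llms.py below.
    cached = llm_cache.get(model, system, prompt, response_format, temperature)
    if cached is not None:
        result = cached  # served from the local SQLite cache
    else:
        result = _llm_completion(model, system, prompt, response_format, temperature, logprobs)
        llm_cache.set(model, system, prompt, response_format, temperature, result)
        # set() also pushes entries to the HF dataset once cache_sync_interval has elapsed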
check_repos.py
CHANGED

@@ -1,26 +1,25 @@
 from huggingface_hub import HfApi
 
-from src.envs import QUEUE_REPO, RESULTS_REPO, TOKEN
+from src.envs import LLM_CACHE_REPO, QUEUE_REPO, RESULTS_REPO, TOKEN
 
 
-def …
+def check_and_create_dataset_repo(repo_id: str):
     api = HfApi(token=TOKEN)
-
-    # Check and create queue repo
     try:
-        api.repo_info(repo_id=…
-        print(f"…
+        api.repo_info(repo_id=repo_id, repo_type="dataset")
+        print(f"{repo_id} exists")
     except Exception:
-        print(f"Creating …
-        api.create_repo(repo_id=…
+        print(f"Creating {repo_id}")
+        api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, private=True)
 
-… (old lines 17–23 removed; content not captured in the extracted page)
+
+def check_and_create_repos():
+    print("1. QUEUE Repository")
+    check_and_create_dataset_repo(QUEUE_REPO)
+    print("2. RESULTS Repository")
+    check_and_create_dataset_repo(RESULTS_REPO)
+    print("3. LLM Cache Repository")
+    check_and_create_dataset_repo(LLM_CACHE_REPO)
 
 
 if __name__ == "__main__":
src/envs.py
CHANGED

@@ -15,6 +15,7 @@ OWNER = "umdclip"
 REPO_ID = f"{OWNER}/quizbowl-submission"
 QUEUE_REPO = f"{OWNER}/advcal-requests"
 RESULTS_REPO = f"{OWNER}/model-results"  # TODO: change to advcal-results after testing is done
+LLM_CACHE_REPO = f"{OWNER}/advcal-llm-cache"
 
 EXAMPLES_PATH = "examples"
 
@@ -29,12 +30,14 @@ PLAYGROUND_DATASET_NAMES = {
 CACHE_PATH = os.getenv("HF_HOME", ".")
 
 # Local caches
+LLM_CACHE_PATH = os.path.join(CACHE_PATH, "llm-cache")
 EVAL_REQUESTS_PATH = os.path.join(CACHE_PATH, "eval-queue")
 EVAL_RESULTS_PATH = os.path.join(CACHE_PATH, "eval-results")
 EVAL_REQUESTS_PATH_BACKEND = os.path.join(CACHE_PATH, "eval-queue-bk")
 EVAL_RESULTS_PATH_BACKEND = os.path.join(CACHE_PATH, "eval-results-bk")
 
 
+LLM_CACHE_REFRESH_INTERVAL = 600  # seconds (10 minutes)
 SERVER_REFRESH_INTERVAL = 86400  # seconds (one day)
 LEADERBOARD_REFRESH_INTERVAL = 600  # seconds (10 minutes)
 
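
The constants added above are not yet consumed by the new cache code: src/workflows/llms.py below instantiates LLMCache with a hard-coded cache_dir="." and repo string. A possible wiring, shown only as a sketch and not part of this commit:

    # Sketch only: reuse the new env constants instead of literals in llms.py.
    from src.envs import LLM_CACHE_PATH, LLM_CACHE_REPO, LLM_CACHE_REFRESH_INTERVAL
    from src.workflows.llmcache import LLMCache

    llm_cache = LLMCache(
        cache_dir=LLM_CACHE_PATH,
        hf_repo=LLM_CACHE_REPO,
        cache_sync_interval=LLM_CACHE_REFRESH_INTERVAL,
    )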
src/workflows/executors.py
CHANGED

@@ -221,6 +221,7 @@ def execute_model_step(
         system=model_step.system_prompt,
         prompt=step_result,
         response_format=ModelResponse,
+        temperature=model_step.temperature,
         logprobs=logprobs,
     )
 
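
Passing temperature through matters for the new cache as well, assuming the call shown here is the cached completion() defined in src/workflows/llms.py: the key built in src/workflows/llmcache.py (added below) includes the temperature, so two otherwise identical calls at different temperatures get separate cache entries. A self-contained illustration that mirrors that key construction (the helper name and values here are made up for the example):

    # Illustrative only: mirrors the key construction added in llmcache.py below.
    import hashlib
    import json

    def cache_key(model, system, prompt, response_format_str, temperature=None):
        key_content = f"{model}:{system}:{prompt}:{response_format_str}"
        if temperature is not None:
            key_content += f":{temperature:.2f}"
        return hashlib.md5(key_content.encode()).hexdigest()

    fmt = json.dumps({"title": "ModelResponse"}, sort_keys=True)
    k1 = cache_key("OpenAI/gpt-4", "sys", "hello", fmt, temperature=0.2)
    k2 = cache_key("OpenAI/gpt-4", "sys", "hello", fmt, temperature=0.7)
    assert k1 != k2  # different temperatures no longer share a cache entry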
src/workflows/llmcache.py
ADDED

@@ -0,0 +1,479 @@
import hashlib
import json
import os
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any, Optional

from datasets import Dataset, Features, Value
from huggingface_hub import snapshot_download
from loguru import logger


def load_dataset_from_hf(repo_id, local_dir):
    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        repo_type="dataset",
        tqdm_class=None,
        etag_timeout=30,
        token=os.environ["HF_TOKEN"],
    )


class CacheDB:
    """Handles database operations for storing and retrieving cache entries."""

    def __init__(self, db_path: Path):
        """Initialize database connection.

        Args:
            db_path: Path to SQLite database file
        """
        self.db_path = db_path
        self.lock = threading.Lock()

        # Initialize the database
        try:
            self.initialize_db()
        except Exception as e:
            logger.exception(f"Failed to initialize database: {e}")
            logger.warning(f"Please provide a different filepath or remove the file at {self.db_path}")
            raise

    def initialize_db(self) -> None:
        """Initialize SQLite database with the required table."""
        # Check if database file already exists
        if self.db_path.exists():
            self._verify_existing_db()
        else:
            self._create_new_db()

    def _verify_existing_db(self) -> None:
        """Verify and repair an existing database if needed."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                self._ensure_table_exists(cursor)
                self._verify_schema(cursor)
                self._ensure_index_exists(cursor)
                conn.commit()
            logger.info(f"Using existing SQLite database at {self.db_path}")
        except Exception as e:
            logger.exception(f"Database corruption detected: {e}")
            raise ValueError(f"Corrupted database at {self.db_path}: {str(e)}")

    def _create_new_db(self) -> None:
        """Create a new database with the required schema."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                self._create_table(cursor)
                self._ensure_index_exists(cursor)
                conn.commit()
            logger.info(f"Initialized new SQLite database at {self.db_path}")
        except Exception as e:
            logger.exception(f"Failed to initialize SQLite database: {e}")
            raise

    def _ensure_table_exists(self, cursor) -> None:
        """Check if the llm_cache table exists and create it if not."""
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='llm_cache'")
        if not cursor.fetchone():
            self._create_table(cursor)
            logger.info("Created missing llm_cache table")

    def _create_table(self, cursor) -> None:
        """Create the llm_cache table with the required schema."""
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS llm_cache (
                key TEXT PRIMARY KEY,
                request_json TEXT,
                response_json TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

    def _verify_schema(self, cursor) -> None:
        """Verify that the table schema has all required columns."""
        cursor.execute("PRAGMA table_info(llm_cache)")
        columns = {row[1] for row in cursor.fetchall()}
        required_columns = {"key", "request_json", "response_json", "created_at"}

        if not required_columns.issubset(columns):
            missing = required_columns - columns
            raise ValueError(f"Database schema is corrupted. Missing columns: {missing}")

    def _ensure_index_exists(self, cursor) -> None:
        """Create an index on the key column for faster lookups."""
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_llm_cache_key ON llm_cache (key)")

    def get(self, key: str) -> Optional[dict[str, Any]]:
        """Get cached entry by key.

        Args:
            key: Cache key to look up

        Returns:
            Dict containing the request and response or None if not found
        """
        try:
            with sqlite3.connect(self.db_path) as conn:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                cursor.execute("SELECT request_json, response_json FROM llm_cache WHERE key = ?", (key,))
                result = cursor.fetchone()

                if result:
                    logger.debug(f"Cache hit for key: {key}. Response: {result['response_json']}")
                    return {
                        "request": result["request_json"],
                        "response": result["response_json"],
                    }

                logger.debug(f"Cache miss for key: {key}")
                return None
        except Exception as e:
            logger.error(f"Error retrieving from cache: {e}")
            return None

    def set(self, key: str, request_json: str, response_json: str) -> bool:
        """Set entry in cache.

        Args:
            key: Cache key
            request_json: JSON string of request parameters
            response_json: JSON string of response

        Returns:
            True if successful, False otherwise
        """
        with self.lock:
            try:
                with sqlite3.connect(self.db_path) as conn:
                    cursor = conn.cursor()
                    cursor.execute(
                        "INSERT OR REPLACE INTO llm_cache (key, request_json, response_json) VALUES (?, ?, ?)",
                        (key, request_json, response_json),
                    )
                    conn.commit()
                    logger.debug(f"Saved response to cache with key: {key}, response: {response_json}")
                    return True
            except Exception as e:
                logger.error(f"Failed to save to SQLite cache: {e}")
                return False

    def get_all_entries(self) -> dict[str, dict[str, Any]]:
        """Get all cache entries from the database."""
        cache = {}
        try:
            with sqlite3.connect(self.db_path) as conn:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                cursor.execute("SELECT key, request_json, response_json FROM llm_cache ORDER BY created_at")

                for row in cursor.fetchall():
                    cache[row["key"]] = {
                        "request": row["request_json"],
                        "response": row["response_json"],
                    }

            logger.debug(f"Retrieved {len(cache)} entries from cache database")
            return cache
        except Exception as e:
            logger.error(f"Error retrieving all cache entries: {e}")
            return {}

    def clear(self) -> bool:
        """Clear all cache entries.

        Returns:
            True if successful, False otherwise
        """
        with self.lock:
            try:
                with sqlite3.connect(self.db_path) as conn:
                    cursor = conn.cursor()
                    cursor.execute("DELETE FROM llm_cache")
                    conn.commit()
                    logger.info("Cache cleared")
                    return True
            except Exception as e:
                logger.error(f"Failed to clear cache: {e}")
                return False

    def get_existing_keys(self) -> set:
        """Get all existing keys in the database.

        Returns:
            Set of keys
        """
        existing_keys = set()
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT key FROM llm_cache")
                for row in cursor.fetchall():
                    existing_keys.add(row[0])
            return existing_keys
        except Exception as e:
            logger.error(f"Error retrieving existing keys: {e}")
            return set()

    def bulk_insert(self, items: list, update: bool = False) -> int:
        """Insert multiple items into the cache.

        Args:
            items: List of (key, request_json, response_json) tuples
            update: Whether to update existing entries

        Returns:
            Number of items inserted
        """
        count = 0
        UPDATE_OR_IGNORE = "INSERT OR REPLACE" if update else "INSERT OR IGNORE"
        with self.lock:
            try:
                with sqlite3.connect(self.db_path) as conn:
                    cursor = conn.cursor()
                    cursor.executemany(
                        f"{UPDATE_OR_IGNORE} INTO llm_cache (key, request_json, response_json) VALUES (?, ?, ?)",
                        items,
                    )
                    count = cursor.rowcount
                    conn.commit()
                    return count
            except Exception as e:
                logger.error(f"Error during bulk insert: {e}")
                return 0


class LLMCache:
    def __init__(
        self, cache_dir: str = ".", hf_repo: str | None = None, cache_sync_interval: int = 3600, reset: bool = False
    ):
        self.cache_dir = Path(cache_dir)
        self.db_path = self.cache_dir / "llm_cache.db"
        self.hf_repo_id = hf_repo
        self.cache_sync_interval = cache_sync_interval
        self.last_sync_time = time.time()

        # Create cache directory if it doesn't exist
        self.cache_dir.mkdir(exist_ok=True, parents=True)

        # Initialize CacheDB
        self.db = CacheDB(self.db_path)
        if reset:
            self.db.clear()

        # Try to load from HF dataset if available
        try:
            self._load_cache_from_hf()
        except Exception as e:
            logger.warning(f"Failed to load cache from HF dataset: {e}")

    def response_format_to_dict(self, response_format: Any) -> dict[str, Any]:
        """Convert a response format to a dict."""
        # If it's a Pydantic model, use its schema
        if hasattr(response_format, "model_json_schema"):
            response_format = response_format.model_json_schema()

        # If it's a Pydantic model, use its dump
        elif hasattr(response_format, "model_dump"):
            response_format = response_format.model_dump()

        if not isinstance(response_format, dict):
            response_format = {"value": str(response_format)}

        return response_format

    def _generate_key(
        self, model: str, system: str, prompt: str, response_format: Any, temperature: float | None = None
    ) -> str:
        """Generate a unique key for caching based on inputs."""
        response_format_dict = self.response_format_to_dict(response_format)
        response_format_str = json.dumps(response_format_dict, sort_keys=True)
        # Include temperature in the key
        key_content = f"{model}:{system}:{prompt}:{response_format_str}"
        if temperature is not None:
            key_content += f":{temperature:.2f}"
        return hashlib.md5(key_content.encode()).hexdigest()

    def _create_request_json(
        self, model: str, system: str, prompt: str, response_format: Any, temperature: float | None
    ) -> str:
        """Create JSON string from request parameters."""
        logger.info(f"Creating request JSON with temperature: {temperature}")
        request_data = {
            "model": model,
            "system": system,
            "prompt": prompt,
            "response_format": self.response_format_to_dict(response_format),
            "temperature": temperature,
        }
        return json.dumps(request_data)

    def _check_request_match(
        self,
        cached_request: dict[str, Any],
        model: str,
        system: str,
        prompt: str,
        response_format: Any,
        temperature: float | None,
    ) -> bool:
        """Check if the cached request matches the new request."""
        # Check each field and log any mismatches
        if cached_request["model"] != model:
            logger.debug(f"Cache mismatch: model - cached: {cached_request['model']}, new: {model}")
            return False
        if cached_request["system"] != system:
            logger.debug(f"Cache mismatch: system - cached: {cached_request['system']}, new: {system}")
            return False
        if cached_request["prompt"] != prompt:
            logger.debug(f"Cache mismatch: prompt - cached: {cached_request['prompt']}, new: {prompt}")
            return False
        response_format_dict = self.response_format_to_dict(response_format)
        if cached_request["response_format"] != response_format_dict:
            logger.debug(
                f"Cache mismatch: response_format - cached: {cached_request['response_format']}, new: {response_format_dict}"
            )
            return False
        if cached_request["temperature"] != temperature:
            logger.debug(f"Cache mismatch: temperature - cached: {cached_request['temperature']}, new: {temperature}")
            return False

        return True

    def get(
        self, model: str, system: str, prompt: str, response_format: dict[str, Any], temperature: float | None = None
    ) -> Optional[dict[str, Any]]:
        """Get cached response if it exists."""
        key = self._generate_key(model, system, prompt, response_format, temperature)
        result = self.db.get(key)

        if not result:
            return None
        request_dict = json.loads(result["request"])
        if not self._check_request_match(request_dict, model, system, prompt, response_format, temperature):
            logger.warning(f"Cached request does not match new request for key: {key}")
            return None

        return json.loads(result["response"])

    def set(
        self,
        model: str,
        system: str,
        prompt: str,
        response_format: dict[str, Any],
        temperature: float | None,
        response: dict[str, Any],
    ) -> None:
        """Set response in cache and sync if needed."""
        key = self._generate_key(model, system, prompt, response_format, temperature)
        request_json = self._create_request_json(model, system, prompt, response_format, temperature)
        response_json = json.dumps(response)

        success = self.db.set(key, request_json, response_json)

        # Check if we should sync to HF
        if success and self.hf_repo_id and (time.time() - self.last_sync_time > self.cache_sync_interval):
            try:
                self.sync_to_hf()
                self.last_sync_time = time.time()
            except Exception as e:
                logger.error(f"Failed to sync cache to HF dataset: {e}")

    def _load_cache_from_hf(self) -> None:
        """Load cache from HF dataset if it exists and merge with local cache."""
        if not self.hf_repo_id:
            return

        try:
            # Check for new commits before loading the dataset
            dataset = load_dataset_from_hf(self.hf_repo_id, self.cache_dir / "hf_cache")
            if dataset:
                existing_keys = self.db.get_existing_keys()

                # Prepare batch items for insertion
                items_to_insert = []
                for item in dataset:
                    key = item["key"]
                    # Only update if not in local cache to prioritize local changes
                    if key in existing_keys:
                        continue
                    # Create request JSON
                    request_data = {
                        "model": item["model"],
                        "system": item["system"],
                        "prompt": item["prompt"],
                        "temperature": item["temperature"],
                        "response_format": None,  # We can't fully reconstruct this
                    }

                    items_to_insert.append(
                        (
                            key,
                            json.dumps(request_data),
                            item["response"],  # This is already a JSON string
                        )
                    )
                    logger.info(
                        f"Inserting item: {key} with temperature: {item['temperature']} and response: {item['response']}"
                    )

                # Bulk insert new items
                if items_to_insert:
                    inserted_count = self.db.bulk_insert(items_to_insert)
                    logger.info(f"Merged {inserted_count} items from HF dataset into SQLite cache")
                else:
                    logger.info("No new items to merge from HF dataset")
        except Exception as e:
            logger.warning(f"Could not load cache from HF dataset: {e}")

    def get_all_entries(self) -> dict[str, dict[str, Any]]:
        """Get all cache entries from the database."""
        cache = self.db.get_all_entries()
        entries = {}
        for key, entry in cache.items():
            request = json.loads(entry["request"])
            response = json.loads(entry["response"])
            entries[key] = {"request": request, "response": response}
        return entries

    def sync_to_hf(self) -> None:
        """Sync cache to HF dataset."""
        if not self.hf_repo_id:
            return

        # Get all entries from the database
        cache = self.db.get_all_entries()

        # Convert cache to dataset format
        entries = []
        for key, entry in cache.items():
            request = json.loads(entry["request"])
            response_str = entry["response"]
            entries.append(
                {
                    "key": key,
                    "model": request["model"],
                    "system": request["system"],
                    "prompt": request["prompt"],
                    "response_format": request["response_format"],
                    "temperature": request["temperature"],
                    "response": response_str,
                }
            )

        # Create and push dataset
        dataset = Dataset.from_list(entries)
        dataset.push_to_hub(self.hf_repo_id, private=True)
        logger.info(f"Synced {len(cache)} cached items to HF dataset {self.hf_repo_id}")

    def clear(self) -> None:
        """Clear all cache entries."""
        self.db.clear()
src/workflows/llms.py
CHANGED

@@ -1,12 +1,14 @@
 # %%
+
 import json
 import os
-from typing import Optional
+from typing import Any, Optional
 
 import cohere
 import numpy as np
 from langchain_anthropic import ChatAnthropic
 from langchain_cohere import ChatCohere
+from langchain_core.language_models import BaseChatModel
 from langchain_openai import ChatOpenAI
 from loguru import logger
 from openai import OpenAI
@@ -14,6 +16,10 @@ from pydantic import BaseModel, Field
 from rich import print as rprint
 
 from .configs import AVAILABLE_MODELS
+from .llmcache import LLMCache
+
+# Initialize global cache
+llm_cache = LLMCache(cache_dir=".", hf_repo="umdclip/advcal-llm-cache")
 
 
 def _openai_is_json_mode_supported(model_name: str) -> bool:
@@ -30,7 +36,7 @@ class LLMOutput(BaseModel):
     logprob: Optional[float] = Field(None, description="The log probability of the response")
 
 
-def _get_langchain_chat_output(llm, system: str, prompt: str) -> str:
+def _get_langchain_chat_output(llm: BaseChatModel, system: str, prompt: str) -> str:
     output = llm.invoke([("system", system), ("human", prompt)])
     ai_message = output["raw"]
     content = {"content": ai_message.content, "tool_calls": ai_message.tool_calls}
@@ -38,7 +44,9 @@ def _get_langchain_chat_output(llm, system: str, prompt: str) -> str:
     return {"content": content_str, "output": output["parsed"].model_dump()}
 
 
-def _cohere_completion(…
+def _cohere_completion(
+    model: str, system: str, prompt: str, response_model, temperature: float | None = None, logprobs: bool = True
+) -> str:
     messages = [
         {"role": "system", "content": system},
         {"role": "user", "content": prompt},
@@ -49,6 +57,7 @@ def _cohere_completion(model: str, system: str, prompt: str, response_model, log
         messages=messages,
         response_format={"type": "json_schema", "json_schema": response_model.model_json_schema()},
        logprobs=logprobs,
+        temperature=temperature,
     )
     output = {}
     output["content"] = response.message.content[0].text
@@ -59,12 +68,16 @@ def _cohere_completion(model: str, system: str, prompt: str, response_model, log
     return output
 
 
-def _openai_langchain_completion(…
-    …
+def _openai_langchain_completion(
+    model: str, system: str, prompt: str, response_model, temperature: float | None = None
+) -> str:
+    llm = ChatOpenAI(model=model, temperature=temperature).with_structured_output(response_model, include_raw=True)
     return _get_langchain_chat_output(llm, system, prompt)
 
 
-def _openai_completion(…
+def _openai_completion(
+    model: str, system: str, prompt: str, response_model, temperature: float | None = None, logprobs: bool = True
+) -> str:
     messages = [
         {"role": "system", "content": system},
         {"role": "user", "content": prompt},
@@ -75,6 +88,7 @@ def _openai_completion(model: str, system: str, prompt: str, response_model, log
         messages=messages,
         response_format=response_model,
         logprobs=logprobs,
+        temperature=temperature,
     )
     output = {}
     output["content"] = response.choices[0].message.content
@@ -85,14 +99,18 @@ def _openai_completion(model: str, system: str, prompt: str, response_model, log
     return output
 
 
-def _anthropic_completion(…
-    …
+def _anthropic_completion(
+    model: str, system: str, prompt: str, response_model, temperature: float | None = None
+) -> str:
+    llm = ChatAnthropic(model=model, temperature=temperature).with_structured_output(response_model, include_raw=True)
    return _get_langchain_chat_output(llm, system, prompt)
 
 
-def …
+def _llm_completion(
+    model: str, system: str, prompt: str, response_format, temperature: float | None = None, logprobs: bool = False
+) -> dict[str, Any]:
     """
-    Generate a completion from an LLM provider with structured output.
+    Generate a completion from an LLM provider with structured output without caching.
 
     Args:
         model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
@@ -116,20 +134,69 @@ def completion(model: str, system: str, prompt: str, response_format, logprobs:
     model_name = AVAILABLE_MODELS[model]["model"]
     provider = model.split("/")[0]
     if provider == "Cohere":
-        return _cohere_completion(model_name, system, prompt, response_format, logprobs)
+        return _cohere_completion(model_name, system, prompt, response_format, temperature, logprobs)
     elif provider == "OpenAI":
         if _openai_is_json_mode_supported(model_name):
-            return _openai_completion(model_name, system, prompt, response_format, logprobs)
+            return _openai_completion(model_name, system, prompt, response_format, temperature, logprobs)
+        elif logprobs:
+            raise ValueError(f"{model} does not support logprobs feature.")
         else:
-            return _openai_langchain_completion(model_name, system, prompt, response_format, …
+            return _openai_langchain_completion(model_name, system, prompt, response_format, temperature)
     elif provider == "Anthropic":
         if logprobs:
-            raise ValueError("Anthropic …
-        return _anthropic_completion(model_name, system, prompt, response_format)
+            raise ValueError("Anthropic models do not support logprobs")
+        return _anthropic_completion(model_name, system, prompt, response_format, temperature)
     else:
         raise ValueError(f"Provider {provider} not supported")
 
 
+def completion(
+    model: str, system: str, prompt: str, response_format, temperature: float | None = None, logprobs: bool = False
+) -> dict[str, Any]:
+    """
+    Generate a completion from an LLM provider with structured output with caching.
+
+    Args:
+        model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
+        system (str): System prompt/instructions for the model
+        prompt (str): User prompt/input
+        response_format: Pydantic model defining the expected response structure
+        logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
+            Note: Not supported by Anthropic models.
+
+    Returns:
+        dict: Contains:
+            - output: The structured response matching response_format
+            - logprob: (optional) Sum of log probabilities if logprobs=True
+            - prob: (optional) Exponential of logprob if logprobs=True
+
+    Raises:
+        ValueError: If logprobs=True with Anthropic models
+    """
+    # Check cache first
+    cached_response = llm_cache.get(model, system, prompt, response_format, temperature)
+    if cached_response is not None:
+        logger.info(f"Cache hit for model {model}")
+        return cached_response
+
+    logger.info(f"Cache miss for model {model}, calling API")
+
+    # Continue with the original implementation for cache miss
+    response = _llm_completion(model, system, prompt, response_format, temperature, logprobs)
+
+    # Update cache with the new response
+    llm_cache.set(
+        model,
+        system,
+        prompt,
+        response_format,
+        temperature,
+        response,
+    )
+
+    return response
+
+
 # %%
 if __name__ == "__main__":
     from tqdm import tqdm
@@ -142,12 +209,52 @@ if __name__ == "__main__":
         answer: str = Field(description="The short answer to the question")
         explanation: str = Field(description="5 words terse best explanation of the answer.")
 
-    models = AVAILABLE_MODELS.keys()
+    models = list(AVAILABLE_MODELS.keys())[:1]  # Just use the first model for testing
     system = "You are an accurate and concise explainer of scientific concepts."
     prompt = "Which planet is closest to the sun in the Milky Way galaxy? Answer directly, no explanation needed."
 
+    llm_cache = LLMCache(cache_dir=".", hf_repo="umdclip/advcal-llm-cache", reset=True)
+
+    # First call - should be a cache miss
+    logger.info("First call - should be a cache miss")
+    for model in tqdm(models):
+        response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
+        rprint(response)
+
+    # Second call - should be a cache hit
+    logger.info("Second call - should be a cache hit")
     for model in tqdm(models):
         response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
         rprint(response)
 
+    # Slightly different prompt - should be a cache miss
+    logger.info("Different prompt - should be a cache miss")
+    prompt2 = "Which planet is closest to the sun? Answer directly."
+    for model in tqdm(models):
+        response = completion(model, system, prompt2, ExplainedAnswer, logprobs=False)
+        rprint(response)
+
+    # Get cache entries count from SQLite
+    try:
+        cache_entries = llm_cache.get_all_entries()
+        logger.info(f"Cache now has {len(cache_entries)} items")
+    except Exception as e:
+        logger.error(f"Failed to get cache entries: {e}")
+
+    # Test adding entry with temperature parameter
+    logger.info("Testing with temperature parameter")
+    response = completion(models[0], system, "What is Mars?", ExplainedAnswer, temperature=0.7, logprobs=False)
+    rprint(response)
+
+    # Demonstrate forced sync to HF if repo is configured
+    if llm_cache.hf_repo_id:
+        logger.info("Forcing sync to HF dataset")
+        try:
+            llm_cache.sync_to_hf()
+            logger.info("Successfully synced to HF dataset")
+        except Exception as e:
+            logger.exception(f"Failed to sync to HF: {e}")
+    else:
+        logger.info("HF repo not configured, skipping sync test")
+
 # %%