Merge pull request #21 from dexhunter/fix/beta_models
Browse files- aide/backend/__init__.py +8 -0
- aide/backend/backend_anthropic.py +17 -10
- aide/backend/backend_openai.py +17 -11
- aide/backend/utils.py +21 -0
- requirements.txt +2 -1
aide/backend/__init__.py
CHANGED
|
@@ -33,6 +33,14 @@ def query(
|
|
| 33 |
"max_tokens": max_tokens,
|
| 34 |
}
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
query_func = backend_anthropic.query if "claude-" in model else backend_openai.query
|
| 37 |
output, req_time, in_tok_count, out_tok_count, info = query_func(
|
| 38 |
system_message=compile_prompt_to_md(system_message) if system_message else None,
|
|
|
|
| 33 |
"max_tokens": max_tokens,
|
| 34 |
}
|
| 35 |
|
| 36 |
+
# Handle models with beta limitations
|
| 37 |
+
# ref: https://platform.openai.com/docs/guides/reasoning/beta-limitations
|
| 38 |
+
if model.startswith("o1-"):
|
| 39 |
+
if system_message:
|
| 40 |
+
user_message = system_message
|
| 41 |
+
system_message = None
|
| 42 |
+
model_kwargs["temperature"] = 1
|
| 43 |
+
|
| 44 |
query_func = backend_anthropic.query if "claude-" in model else backend_openai.query
|
| 45 |
output, req_time, in_tok_count, out_tok_count, info = query_func(
|
| 46 |
system_message=compile_prompt_to_md(system_message) if system_message else None,
|
aide/backend/backend_anthropic.py
CHANGED
|
@@ -2,23 +2,25 @@
|
|
| 2 |
|
| 3 |
import time
|
| 4 |
|
| 5 |
-
from
|
| 6 |
-
from
|
| 7 |
-
|
| 8 |
|
| 9 |
-
_client: Anthropic = None # type: ignore
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
@once
|
| 16 |
def _setup_anthropic_client():
|
| 17 |
global _client
|
| 18 |
-
_client = Anthropic()
|
| 19 |
|
| 20 |
-
|
| 21 |
-
@retry_exp
|
| 22 |
def query(
|
| 23 |
system_message: str | None,
|
| 24 |
user_message: str | None,
|
|
@@ -48,7 +50,12 @@ def query(
|
|
| 48 |
messages = opt_messages_to_list(None, user_message)
|
| 49 |
|
| 50 |
t0 = time.time()
|
| 51 |
-
message =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
req_time = time.time() - t0
|
| 53 |
|
| 54 |
assert len(message.content) == 1 and message.content[0].type == "text"
|
|
|
|
| 2 |
|
| 3 |
import time
|
| 4 |
|
| 5 |
+
from .utils import FunctionSpec, OutputType, opt_messages_to_list, backoff_create
|
| 6 |
+
from funcy import notnone, once, select_values
|
| 7 |
+
import anthropic
|
| 8 |
|
| 9 |
+
_client: anthropic.Anthropic = None # type: ignore
|
| 10 |
|
| 11 |
+
ANTHROPIC_TIMEOUT_EXCEPTIONS = (
|
| 12 |
+
anthropic.RateLimitError,
|
| 13 |
+
anthropic.APIConnectionError,
|
| 14 |
+
anthropic.APITimeoutError,
|
| 15 |
+
anthropic.InternalServerError,
|
| 16 |
+
)
|
| 17 |
|
| 18 |
|
| 19 |
@once
def _setup_anthropic_client():
    """Build the module-level Anthropic client; runs at most once per process.

    SDK-internal retries are disabled (``max_retries=0``) — transient
    failures are presumably retried by the caller via ``backoff_create``
    instead (see the query path in this module).
    """
    global _client
    _client = anthropic.Anthropic(max_retries=0)
|
| 23 |
|
|
|
|
|
|
|
| 24 |
def query(
|
| 25 |
system_message: str | None,
|
| 26 |
user_message: str | None,
|
|
|
|
| 50 |
messages = opt_messages_to_list(None, user_message)
|
| 51 |
|
| 52 |
t0 = time.time()
|
| 53 |
+
message = backoff_create(
|
| 54 |
+
_client.messages.create,
|
| 55 |
+
ANTHROPIC_TIMEOUT_EXCEPTIONS,
|
| 56 |
+
messages=messages,
|
| 57 |
+
**filtered_kwargs,
|
| 58 |
+
)
|
| 59 |
req_time = time.time() - t0
|
| 60 |
|
| 61 |
assert len(message.content) == 1 and message.content[0].type == "text"
|
aide/backend/backend_openai.py
CHANGED
|
@@ -4,25 +4,26 @@ import json
|
|
| 4 |
import logging
|
| 5 |
import time
|
| 6 |
|
| 7 |
-
from .utils import FunctionSpec, OutputType, opt_messages_to_list
|
| 8 |
-
from funcy import notnone, once,
|
| 9 |
-
|
| 10 |
|
| 11 |
logger = logging.getLogger("aide")
|
| 12 |
|
| 13 |
-
_client: OpenAI = None # type: ignore
|
| 14 |
-
|
| 15 |
-
RATELIMIT_RETRIES = 5
|
| 16 |
-
retry_exp = retry(RATELIMIT_RETRIES, errors=RateLimitError, timeout=lambda a: 2 ** (a + 1)) # type: ignore
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
@once
|
| 20 |
def _setup_openai_client():
|
| 21 |
global _client
|
| 22 |
-
_client = OpenAI(max_retries=
|
| 23 |
-
|
| 24 |
|
| 25 |
-
@retry_exp
|
| 26 |
def query(
|
| 27 |
system_message: str | None,
|
| 28 |
user_message: str | None,
|
|
@@ -40,7 +41,12 @@ def query(
|
|
| 40 |
filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict
|
| 41 |
|
| 42 |
t0 = time.time()
|
| 43 |
-
completion =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
req_time = time.time() - t0
|
| 45 |
|
| 46 |
choice = completion.choices[0]
|
|
|
|
| 4 |
import logging
|
| 5 |
import time
|
| 6 |
|
| 7 |
+
from .utils import FunctionSpec, OutputType, opt_messages_to_list, backoff_create
|
| 8 |
+
from funcy import notnone, once, select_values
|
| 9 |
+
import openai
|
| 10 |
|
| 11 |
logger = logging.getLogger("aide")
|
| 12 |
|
| 13 |
+
_client: openai.OpenAI = None # type: ignore
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
+
OPENAI_TIMEOUT_EXCEPTIONS = (
|
| 16 |
+
openai.RateLimitError,
|
| 17 |
+
openai.APIConnectionError,
|
| 18 |
+
openai.APITimeoutError,
|
| 19 |
+
openai.InternalServerError,
|
| 20 |
+
)
|
| 21 |
|
| 22 |
@once
def _setup_openai_client():
    """Initialize the module-level OpenAI client exactly once.

    Retries are turned off at the SDK level (``max_retries=0``) because
    transient errors are presumably handled by the caller through
    ``backoff_create`` (see the query path in this module).
    """
    global _client
    _client = openai.OpenAI(max_retries=0)
|
|
|
|
| 26 |
|
|
|
|
| 27 |
def query(
|
| 28 |
system_message: str | None,
|
| 29 |
user_message: str | None,
|
|
|
|
| 41 |
filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict
|
| 42 |
|
| 43 |
t0 = time.time()
|
| 44 |
+
completion = backoff_create(
|
| 45 |
+
_client.chat.completions.create,
|
| 46 |
+
OPENAI_TIMEOUT_EXCEPTIONS,
|
| 47 |
+
messages=messages,
|
| 48 |
+
**filtered_kwargs,
|
| 49 |
+
)
|
| 50 |
req_time = time.time() - t0
|
| 51 |
|
| 52 |
choice = completion.choices[0]
|
aide/backend/utils.py
CHANGED
|
@@ -8,6 +8,27 @@ FunctionCallType = dict
|
|
| 8 |
OutputType = str | FunctionCallType
|
| 9 |
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
def opt_messages_to_list(
|
| 12 |
system_message: str | None, user_message: str | None
|
| 13 |
) -> list[dict[str, str]]:
|
|
|
|
| 8 |
OutputType = str | FunctionCallType
|
| 9 |
|
| 10 |
|
| 11 |
+
import backoff
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Callable
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger("aide")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@backoff.on_predicate(
    wait_gen=backoff.expo,
    max_value=60,
    factor=1.5,
)
def backoff_create(
    create_fn: Callable,
    retry_exceptions: tuple[type[BaseException], ...],
    *args,
    **kwargs,
):
    """Call ``create_fn(*args, **kwargs)``, retrying on transient API errors.

    ``backoff.on_predicate`` re-invokes the function whenever it returns a
    falsy value; returning ``False`` from the ``except`` branch is what
    triggers the retry, with exponentially growing waits (factor 1.5,
    capped at 60 seconds).

    NOTE(review): no ``max_tries``/``max_time`` is configured, so a
    persistently failing call retries indefinitely — confirm this is
    intended.

    Args:
        create_fn: The API call to attempt, e.g. ``client.messages.create``.
        retry_exceptions: Exception classes that should trigger a retry.
            Must be a tuple of exception types — Python's ``except`` clause
            does not accept a list (the previous ``list[Exception]``
            annotation was wrong; both call sites already pass tuples).
        *args: Positional arguments forwarded to ``create_fn``.
        **kwargs: Keyword arguments forwarded to ``create_fn``.

    Returns:
        Whatever ``create_fn`` returns on success, or ``False`` when the
        attempt raised one of ``retry_exceptions`` (which schedules a retry).
    """
    try:
        return create_fn(*args, **kwargs)
    except retry_exceptions as e:
        # Lazy %-style args keep formatting off the hot path when INFO is off.
        logger.info("Backoff exception: %s", e)
        return False
|
| 31 |
+
|
| 32 |
def opt_messages_to_list(
|
| 33 |
system_message: str | None, user_message: str | None
|
| 34 |
) -> list[dict[str, str]]:
|
requirements.txt
CHANGED
|
@@ -88,4 +88,5 @@ pdf2image
|
|
| 88 |
PyPDF
|
| 89 |
pyocr
|
| 90 |
pyarrow
|
| 91 |
-
xlrd
|
|
|
|
|
|
| 88 |
PyPDF
|
| 89 |
pyocr
|
| 90 |
pyarrow
|
| 91 |
+
xlrd
|
| 92 |
+
backoff
|