Dixing (Dex) Xu
committed on
:bug: handle missing function calls for openai (#35) (#38)
Browse files* :bug: handle missing function calls for openai (#35)
* fix: black format
- aide/backend/backend_openai.py +62 -17
aide/backend/backend_openai.py
CHANGED
|
@@ -19,6 +19,23 @@ OPENAI_TIMEOUT_EXCEPTIONS = (
|
|
| 19 |
openai.InternalServerError,
|
| 20 |
)
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
@once
|
| 24 |
def _setup_openai_client():
|
|
@@ -26,21 +43,41 @@ def _setup_openai_client():
|
|
| 26 |
_client = openai.OpenAI(max_retries=0)
|
| 27 |
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
def query(
|
| 30 |
system_message: str | None,
|
| 31 |
user_message: str | None,
|
| 32 |
func_spec: FunctionSpec | None = None,
|
| 33 |
**model_kwargs,
|
| 34 |
) -> tuple[OutputType, float, int, int, dict]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
_setup_openai_client()
|
| 36 |
-
filtered_kwargs: dict = select_values(notnone, model_kwargs)
|
|
|
|
|
|
|
| 37 |
|
| 38 |
messages = opt_messages_to_list(system_message, user_message)
|
| 39 |
|
| 40 |
if func_spec is not None:
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
t0 = time.time()
|
| 46 |
completion = backoff_create(
|
|
@@ -53,22 +90,30 @@ def query(
|
|
| 53 |
|
| 54 |
choice = completion.choices[0]
|
| 55 |
|
| 56 |
-
if func_spec is None:
|
| 57 |
output = choice.message.content
|
| 58 |
else:
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
output =
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
| 70 |
)
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
in_tokens = completion.usage.prompt_tokens
|
| 74 |
out_tokens = completion.usage.completion_tokens
|
|
|
|
| 19 |
openai.InternalServerError,
|
| 20 |
)
|
| 21 |
|
| 22 |
+
# (docs) https://platform.openai.com/docs/guides/function-calling/supported-models
|
| 23 |
+
SUPPORTED_FUNCTION_CALL_MODELS = {
|
| 24 |
+
"gpt-4o",
|
| 25 |
+
"gpt-4o-2024-08-06",
|
| 26 |
+
"gpt-4o-2024-05-13",
|
| 27 |
+
"gpt-4o-mini",
|
| 28 |
+
"gpt-4o-mini-2024-07-18",
|
| 29 |
+
"gpt-4-turbo",
|
| 30 |
+
"gpt-4-turbo-2024-04-09",
|
| 31 |
+
"gpt-4-turbo-preview",
|
| 32 |
+
"gpt-4-0125-preview",
|
| 33 |
+
"gpt-4-1106-preview",
|
| 34 |
+
"gpt-3.5-turbo",
|
| 35 |
+
"gpt-3.5-turbo-0125",
|
| 36 |
+
"gpt-3.5-turbo-1106",
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
|
| 40 |
@once
|
| 41 |
def _setup_openai_client():
|
|
|
|
| 43 |
_client = openai.OpenAI(max_retries=0)
|
| 44 |
|
| 45 |
|
| 46 |
+
def is_function_call_supported(model_name: str) -> bool:
|
| 47 |
+
"""Return True if the model supports function calling."""
|
| 48 |
+
return model_name in SUPPORTED_FUNCTION_CALL_MODELS
|
| 49 |
+
|
| 50 |
+
|
| 51 |
def query(
|
| 52 |
system_message: str | None,
|
| 53 |
user_message: str | None,
|
| 54 |
func_spec: FunctionSpec | None = None,
|
| 55 |
**model_kwargs,
|
| 56 |
) -> tuple[OutputType, float, int, int, dict]:
|
| 57 |
+
"""
|
| 58 |
+
Query the OpenAI API, optionally with function calling.
|
| 59 |
+
Function calling support is only checked for feedback/review operations.
|
| 60 |
+
"""
|
| 61 |
_setup_openai_client()
|
| 62 |
+
filtered_kwargs: dict = select_values(notnone, model_kwargs)
|
| 63 |
+
model_name = filtered_kwargs.get("model", "")
|
| 64 |
+
logger.debug(f"OpenAI query called with model='{model_name}'")
|
| 65 |
|
| 66 |
messages = opt_messages_to_list(system_message, user_message)
|
| 67 |
|
| 68 |
if func_spec is not None:
|
| 69 |
+
# Only check function call support for feedback/search operations
|
| 70 |
+
if func_spec.name == "submit_review":
|
| 71 |
+
if not is_function_call_supported(model_name):
|
| 72 |
+
logger.warning(
|
| 73 |
+
f"Review function calling was requested, but model '{model_name}' "
|
| 74 |
+
"does not support function calling. Falling back to plain text generation."
|
| 75 |
+
)
|
| 76 |
+
filtered_kwargs.pop("tools", None)
|
| 77 |
+
filtered_kwargs.pop("tool_choice", None)
|
| 78 |
+
else:
|
| 79 |
+
filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
|
| 80 |
+
filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict
|
| 81 |
|
| 82 |
t0 = time.time()
|
| 83 |
completion = backoff_create(
|
|
|
|
| 90 |
|
| 91 |
choice = completion.choices[0]
|
| 92 |
|
| 93 |
+
if func_spec is None or "tools" not in filtered_kwargs:
|
| 94 |
output = choice.message.content
|
| 95 |
else:
|
| 96 |
+
tool_calls = getattr(choice.message, "tool_calls", None)
|
| 97 |
+
|
| 98 |
+
if not tool_calls:
|
| 99 |
+
logger.warning(
|
| 100 |
+
f"No function call used despite function spec. Fallback to text. "
|
| 101 |
+
f"Message content: {choice.message.content}"
|
| 102 |
+
)
|
| 103 |
+
output = choice.message.content
|
| 104 |
+
else:
|
| 105 |
+
first_call = tool_calls[0]
|
| 106 |
+
assert first_call.function.name == func_spec.name, (
|
| 107 |
+
f"Function name mismatch: expected {func_spec.name}, "
|
| 108 |
+
f"got {first_call.function.name}"
|
| 109 |
)
|
| 110 |
+
try:
|
| 111 |
+
output = json.loads(first_call.function.arguments)
|
| 112 |
+
except json.JSONDecodeError as e:
|
| 113 |
+
logger.error(
|
| 114 |
+
f"Error decoding function arguments:\n{first_call.function.arguments}"
|
| 115 |
+
)
|
| 116 |
+
raise e
|
| 117 |
|
| 118 |
in_tokens = completion.usage.prompt_tokens
|
| 119 |
out_tokens = completion.usage.completion_tokens
|