Dixing (Dex) Xu
committed on
:bug: better handling for function calling errors (#44)
Browse files* remove the supported model lists
* update error handling #43
* update code block display for html
- aide/backend/backend_openai.py +56 -56
- aide/utils/tree_export.py +14 -1
aide/backend/backend_openai.py
CHANGED
|
@@ -19,23 +19,6 @@ OPENAI_TIMEOUT_EXCEPTIONS = (
|
|
| 19 |
openai.InternalServerError,
|
| 20 |
)
|
| 21 |
|
| 22 |
-
# (docs) https://platform.openai.com/docs/guides/function-calling/supported-models
|
| 23 |
-
SUPPORTED_FUNCTION_CALL_MODELS = {
|
| 24 |
-
"gpt-4o",
|
| 25 |
-
"gpt-4o-2024-08-06",
|
| 26 |
-
"gpt-4o-2024-05-13",
|
| 27 |
-
"gpt-4o-mini",
|
| 28 |
-
"gpt-4o-mini-2024-07-18",
|
| 29 |
-
"gpt-4-turbo",
|
| 30 |
-
"gpt-4-turbo-2024-04-09",
|
| 31 |
-
"gpt-4-turbo-preview",
|
| 32 |
-
"gpt-4-0125-preview",
|
| 33 |
-
"gpt-4-1106-preview",
|
| 34 |
-
"gpt-3.5-turbo",
|
| 35 |
-
"gpt-3.5-turbo-0125",
|
| 36 |
-
"gpt-3.5-turbo-1106",
|
| 37 |
-
}
|
| 38 |
-
|
| 39 |
|
| 40 |
@once
|
| 41 |
def _setup_openai_client():
|
|
@@ -43,11 +26,6 @@ def _setup_openai_client():
|
|
| 43 |
_client = openai.OpenAI(max_retries=0)
|
| 44 |
|
| 45 |
|
| 46 |
-
def is_function_call_supported(model_name: str) -> bool:
|
| 47 |
-
"""Return True if the model supports function calling."""
|
| 48 |
-
return model_name in SUPPORTED_FUNCTION_CALL_MODELS
|
| 49 |
-
|
| 50 |
-
|
| 51 |
def query(
|
| 52 |
system_message: str | None,
|
| 53 |
user_message: str | None,
|
|
@@ -56,64 +34,86 @@ def query(
|
|
| 56 |
) -> tuple[OutputType, float, int, int, dict]:
|
| 57 |
"""
|
| 58 |
Query the OpenAI API, optionally with function calling.
|
| 59 |
-
|
| 60 |
"""
|
| 61 |
_setup_openai_client()
|
| 62 |
filtered_kwargs: dict = select_values(notnone, model_kwargs)
|
| 63 |
-
model_name = filtered_kwargs.get("model", "")
|
| 64 |
-
logger.debug(f"OpenAI query called with model='{model_name}'")
|
| 65 |
|
|
|
|
| 66 |
messages = opt_messages_to_list(system_message, user_message)
|
| 67 |
|
|
|
|
| 68 |
if func_spec is not None:
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
if not is_function_call_supported(model_name):
|
| 72 |
-
logger.warning(
|
| 73 |
-
f"Review function calling was requested, but model '{model_name}' "
|
| 74 |
-
"does not support function calling. Falling back to plain text generation."
|
| 75 |
-
)
|
| 76 |
-
filtered_kwargs.pop("tools", None)
|
| 77 |
-
filtered_kwargs.pop("tool_choice", None)
|
| 78 |
-
else:
|
| 79 |
-
filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
|
| 80 |
-
filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict
|
| 81 |
|
|
|
|
| 82 |
t0 = time.time()
|
| 83 |
-
completion = backoff_create(
|
| 84 |
-
_client.chat.completions.create,
|
| 85 |
-
OPENAI_TIMEOUT_EXCEPTIONS,
|
| 86 |
-
messages=messages,
|
| 87 |
-
**filtered_kwargs,
|
| 88 |
-
)
|
| 89 |
-
req_time = time.time() - t0
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
choice = completion.choices[0]
|
| 92 |
|
|
|
|
| 93 |
if func_spec is None or "tools" not in filtered_kwargs:
|
|
|
|
| 94 |
output = choice.message.content
|
| 95 |
else:
|
|
|
|
| 96 |
tool_calls = getattr(choice.message, "tool_calls", None)
|
| 97 |
-
|
| 98 |
if not tool_calls:
|
| 99 |
logger.warning(
|
| 100 |
-
|
| 101 |
f"Message content: {choice.message.content}"
|
| 102 |
)
|
| 103 |
output = choice.message.content
|
| 104 |
else:
|
| 105 |
first_call = tool_calls[0]
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
output = json.loads(first_call.function.arguments)
|
| 112 |
-
except json.JSONDecodeError as e:
|
| 113 |
-
logger.error(
|
| 114 |
-
f"Error decoding function arguments:\n{first_call.function.arguments}"
|
| 115 |
)
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
in_tokens = completion.usage.prompt_tokens
|
| 119 |
out_tokens = completion.usage.completion_tokens
|
|
|
|
| 19 |
openai.InternalServerError,
|
| 20 |
)
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
@once
|
| 24 |
def _setup_openai_client():
|
|
|
|
| 26 |
_client = openai.OpenAI(max_retries=0)
|
| 27 |
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
def query(
|
| 30 |
system_message: str | None,
|
| 31 |
user_message: str | None,
|
|
|
|
| 34 |
) -> tuple[OutputType, float, int, int, dict]:
|
| 35 |
"""
|
| 36 |
Query the OpenAI API, optionally with function calling.
|
| 37 |
+
If the model doesn't support function calling, gracefully degrade to text generation.
|
| 38 |
"""
|
| 39 |
_setup_openai_client()
|
| 40 |
filtered_kwargs: dict = select_values(notnone, model_kwargs)
|
|
|
|
|
|
|
| 41 |
|
| 42 |
+
# Convert system/user messages to the format required by the client
|
| 43 |
messages = opt_messages_to_list(system_message, user_message)
|
| 44 |
|
| 45 |
+
# If function calling is requested, attach the function spec
|
| 46 |
if func_spec is not None:
|
| 47 |
+
filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
|
| 48 |
+
filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
+
completion = None
|
| 51 |
t0 = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
+
# Attempt the API call
|
| 54 |
+
try:
|
| 55 |
+
completion = backoff_create(
|
| 56 |
+
_client.chat.completions.create,
|
| 57 |
+
OPENAI_TIMEOUT_EXCEPTIONS,
|
| 58 |
+
messages=messages,
|
| 59 |
+
**filtered_kwargs,
|
| 60 |
+
)
|
| 61 |
+
except openai.error.InvalidRequestError as e:
|
| 62 |
+
# Check whether the error indicates that function calling is not supported
|
| 63 |
+
if "function calling" in str(e).lower() or "tools" in str(e).lower():
|
| 64 |
+
logger.warning(
|
| 65 |
+
"Function calling was attempted but is not supported by this model. "
|
| 66 |
+
"Falling back to plain text generation."
|
| 67 |
+
)
|
| 68 |
+
# Remove function-calling parameters and retry
|
| 69 |
+
filtered_kwargs.pop("tools", None)
|
| 70 |
+
filtered_kwargs.pop("tool_choice", None)
|
| 71 |
+
|
| 72 |
+
# Retry without function calling
|
| 73 |
+
completion = backoff_create(
|
| 74 |
+
_client.chat.completions.create,
|
| 75 |
+
OPENAI_TIMEOUT_EXCEPTIONS,
|
| 76 |
+
messages=messages,
|
| 77 |
+
**filtered_kwargs,
|
| 78 |
+
)
|
| 79 |
+
else:
|
| 80 |
+
# If it's some other error, re-raise
|
| 81 |
+
raise
|
| 82 |
+
|
| 83 |
+
req_time = time.time() - t0
|
| 84 |
choice = completion.choices[0]
|
| 85 |
|
| 86 |
+
# Decide how to parse the response
|
| 87 |
if func_spec is None or "tools" not in filtered_kwargs:
|
| 88 |
+
# No function calling was ultimately used
|
| 89 |
output = choice.message.content
|
| 90 |
else:
|
| 91 |
+
# Attempt to extract tool calls
|
| 92 |
tool_calls = getattr(choice.message, "tool_calls", None)
|
|
|
|
| 93 |
if not tool_calls:
|
| 94 |
logger.warning(
|
| 95 |
+
"No function call was used despite function spec. Fallback to text.\n"
|
| 96 |
f"Message content: {choice.message.content}"
|
| 97 |
)
|
| 98 |
output = choice.message.content
|
| 99 |
else:
|
| 100 |
first_call = tool_calls[0]
|
| 101 |
+
# Optional: verify that the function name matches
|
| 102 |
+
if first_call.function.name != func_spec.name:
|
| 103 |
+
logger.warning(
|
| 104 |
+
f"Function name mismatch: expected {func_spec.name}, "
|
| 105 |
+
f"got {first_call.function.name}. Fallback to text."
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
)
|
| 107 |
+
output = choice.message.content
|
| 108 |
+
else:
|
| 109 |
+
try:
|
| 110 |
+
output = json.loads(first_call.function.arguments)
|
| 111 |
+
except json.JSONDecodeError as ex:
|
| 112 |
+
logger.error(
|
| 113 |
+
"Error decoding function arguments:\n"
|
| 114 |
+
f"{first_call.function.arguments}"
|
| 115 |
+
)
|
| 116 |
+
raise ex
|
| 117 |
|
| 118 |
in_tokens = completion.usage.prompt_tokens
|
| 119 |
out_tokens = completion.usage.completion_tokens
|
aide/utils/tree_export.py
CHANGED
|
@@ -38,6 +38,19 @@ def normalize_layout(layout: np.ndarray):
|
|
| 38 |
return layout
|
| 39 |
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
def cfg_to_tree_struct(cfg, jou: Journal):
|
| 42 |
edges = list(get_edges(jou))
|
| 43 |
layout = normalize_layout(generate_layout(len(jou), edges))
|
|
@@ -52,7 +65,7 @@ def cfg_to_tree_struct(cfg, jou: Journal):
|
|
| 52 |
edges=edges,
|
| 53 |
layout=layout.tolist(),
|
| 54 |
plan=[textwrap.fill(n.plan, width=80) for n in jou.nodes],
|
| 55 |
-
code=[n.code for n in jou],
|
| 56 |
term_out=[n.term_out for n in jou],
|
| 57 |
analysis=[n.analysis for n in jou],
|
| 58 |
exp_name=cfg.exp_name,
|
|
|
|
| 38 |
return layout
|
| 39 |
|
| 40 |
|
| 41 |
+
def strip_code_markers(code: str) -> str:
|
| 42 |
+
"""Remove markdown code block markers if present."""
|
| 43 |
+
code = code.strip()
|
| 44 |
+
if code.startswith("```"):
|
| 45 |
+
# Remove opening backticks and optional language identifier
|
| 46 |
+
first_newline = code.find("\n")
|
| 47 |
+
if first_newline != -1:
|
| 48 |
+
code = code[first_newline:].strip()
|
| 49 |
+
if code.endswith("```"):
|
| 50 |
+
code = code[:-3].strip()
|
| 51 |
+
return code
|
| 52 |
+
|
| 53 |
+
|
| 54 |
def cfg_to_tree_struct(cfg, jou: Journal):
|
| 55 |
edges = list(get_edges(jou))
|
| 56 |
layout = normalize_layout(generate_layout(len(jou), edges))
|
|
|
|
| 65 |
edges=edges,
|
| 66 |
layout=layout.tolist(),
|
| 67 |
plan=[textwrap.fill(n.plan, width=80) for n in jou.nodes],
|
| 68 |
+
code=[strip_code_markers(n.code) for n in jou],
|
| 69 |
term_out=[n.term_out for n in jou],
|
| 70 |
analysis=[n.analysis for n in jou],
|
| 71 |
exp_name=cfg.exp_name,
|