Spaces:
Runtime error
Runtime error
tech-envision
commited on
Commit
·
258e41b
1
Parent(s):
8828dae
Modify tool response handling
Browse files- src/chat.py +22 -3
- src/vm.py +0 -1
src/chat.py
CHANGED
|
@@ -161,6 +161,19 @@ class ChatSession:
|
|
| 161 |
# return ChatSession._serialize_tool_calls(message.tool_calls)
|
| 162 |
return message.content or ""
|
| 163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
@staticmethod
|
| 165 |
def _store_assistant_message(conversation: Conversation, message: Message) -> None:
|
| 166 |
"""Persist assistant messages, storing tool calls when present."""
|
|
@@ -225,9 +238,12 @@ class ChatSession:
|
|
| 225 |
execute_terminal_async(**call.function.arguments)
|
| 226 |
)
|
| 227 |
|
| 228 |
-
placeholder = {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
messages.append(placeholder)
|
| 230 |
-
yield ChatResponse(message=Message(**placeholder))
|
| 231 |
|
| 232 |
follow_task = asyncio.create_task(self.ask(messages, think=True))
|
| 233 |
|
|
@@ -246,7 +262,7 @@ class ChatSession:
|
|
| 246 |
await follow_task
|
| 247 |
except asyncio.CancelledError:
|
| 248 |
pass
|
| 249 |
-
|
| 250 |
result = await exec_task
|
| 251 |
messages.append(
|
| 252 |
{
|
|
@@ -274,6 +290,7 @@ class ChatSession:
|
|
| 274 |
messages.append(followup.message.model_dump())
|
| 275 |
yield followup
|
| 276 |
result = await exec_task
|
|
|
|
| 277 |
messages.append(
|
| 278 |
{
|
| 279 |
"role": "tool",
|
|
@@ -344,6 +361,7 @@ class ChatSession:
|
|
| 344 |
await user_task
|
| 345 |
except asyncio.CancelledError:
|
| 346 |
pass
|
|
|
|
| 347 |
result = await exec_task
|
| 348 |
self._tool_task = None
|
| 349 |
self._messages.append(
|
|
@@ -377,6 +395,7 @@ class ChatSession:
|
|
| 377 |
yield text
|
| 378 |
result = await exec_task
|
| 379 |
self._tool_task = None
|
|
|
|
| 380 |
self._messages.append(
|
| 381 |
{"role": "tool", "name": "execute_terminal", "content": result}
|
| 382 |
)
|
|
|
|
| 161 |
# return ChatSession._serialize_tool_calls(message.tool_calls)
|
| 162 |
return message.content or ""
|
| 163 |
|
| 164 |
+
@staticmethod
|
| 165 |
+
def _remove_tool_placeholder(messages: List[Msg]) -> None:
|
| 166 |
+
"""Remove the pending placeholder tool message if present."""
|
| 167 |
+
|
| 168 |
+
for i in range(len(messages) - 1, -1, -1):
|
| 169 |
+
msg = messages[i]
|
| 170 |
+
if (
|
| 171 |
+
msg.get("role") == "tool"
|
| 172 |
+
and msg.get("content") == "Awaiting tool response..."
|
| 173 |
+
):
|
| 174 |
+
messages.pop(i)
|
| 175 |
+
break
|
| 176 |
+
|
| 177 |
@staticmethod
|
| 178 |
def _store_assistant_message(conversation: Conversation, message: Message) -> None:
|
| 179 |
"""Persist assistant messages, storing tool calls when present."""
|
|
|
|
| 238 |
execute_terminal_async(**call.function.arguments)
|
| 239 |
)
|
| 240 |
|
| 241 |
+
placeholder = {
|
| 242 |
+
"role": "tool",
|
| 243 |
+
"name": call.function.name,
|
| 244 |
+
"content": "Awaiting tool response...",
|
| 245 |
+
}
|
| 246 |
messages.append(placeholder)
|
|
|
|
| 247 |
|
| 248 |
follow_task = asyncio.create_task(self.ask(messages, think=True))
|
| 249 |
|
|
|
|
| 262 |
await follow_task
|
| 263 |
except asyncio.CancelledError:
|
| 264 |
pass
|
| 265 |
+
self._remove_tool_placeholder(messages)
|
| 266 |
result = await exec_task
|
| 267 |
messages.append(
|
| 268 |
{
|
|
|
|
| 290 |
messages.append(followup.message.model_dump())
|
| 291 |
yield followup
|
| 292 |
result = await exec_task
|
| 293 |
+
self._remove_tool_placeholder(messages)
|
| 294 |
messages.append(
|
| 295 |
{
|
| 296 |
"role": "tool",
|
|
|
|
| 361 |
await user_task
|
| 362 |
except asyncio.CancelledError:
|
| 363 |
pass
|
| 364 |
+
self._remove_tool_placeholder(self._messages)
|
| 365 |
result = await exec_task
|
| 366 |
self._tool_task = None
|
| 367 |
self._messages.append(
|
|
|
|
| 395 |
yield text
|
| 396 |
result = await exec_task
|
| 397 |
self._tool_task = None
|
| 398 |
+
self._remove_tool_placeholder(self._messages)
|
| 399 |
self._messages.append(
|
| 400 |
{"role": "tool", "name": "execute_terminal", "content": result}
|
| 401 |
)
|
src/vm.py
CHANGED
|
@@ -3,7 +3,6 @@ from __future__ import annotations
|
|
| 3 |
import subprocess
|
| 4 |
import asyncio
|
| 5 |
from functools import partial
|
| 6 |
-
import uuid
|
| 7 |
from pathlib import Path
|
| 8 |
|
| 9 |
from threading import Lock
|
|
|
|
| 3 |
import subprocess
|
| 4 |
import asyncio
|
| 5 |
from functools import partial
|
|
|
|
| 6 |
from pathlib import Path
|
| 7 |
|
| 8 |
from threading import Lock
|