tech-envision committed
Commit · f7c8c98
Parent: c8dee25

Add multi-agent team with communication

Files changed:
- api_app/__init__.py +3 -3
- bot/discord_bot.py +3 -3
- run.py +2 -2
- src/__init__.py +5 -0
- src/chat.py +43 -6
- src/config.py +27 -20
- src/team.py +107 -0
api_app/__init__.py
CHANGED
@@ -9,7 +9,7 @@ import os
 import tempfile
 from pathlib import Path

-from src.chat import ChatSession
+from src.team import TeamChatSession
 from src.log import get_logger
 from src.db import list_sessions, list_sessions_info

@@ -29,7 +29,7 @@ def create_app() -> FastAPI:
 @app.post("/chat/stream")
 async def chat_stream(req: ChatRequest):
     async def stream() -> asyncio.AsyncIterator[str]:
-        async with ChatSession(user=req.user, session=req.session) as chat:
+        async with TeamChatSession(user=req.user, session=req.session) as chat:
             try:
                 async for part in chat.chat_stream(req.prompt):
                     yield part
@@ -45,7 +45,7 @@ def create_app() -> FastAPI:
     session: str = Form("default"),
     file: UploadFile = File(...),
 ):
-    async with ChatSession(user=user, session=session) as chat:
+    async with TeamChatSession(user=user, session=session) as chat:
         tmpdir = tempfile.mkdtemp(prefix="upload_")
         tmp_path = Path(tmpdir) / file.filename
         try:
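For orientation, the streaming endpoint above can be exercised with any HTTP client that supports chunked responses. A minimal client sketch, assuming the app is served on localhost:8000 and that ChatRequest carries the user, session, and prompt fields used in the handler:

# Minimal client sketch for POST /chat/stream. The host/port and the
# ChatRequest field names (user, session, prompt) are assumptions taken
# from the handler above, not verified against the full app.
import asyncio

import httpx


async def main() -> None:
    payload = {"user": "demo_user", "session": "default", "prompt": "hello"}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:8000/chat/stream", json=payload
        ) as resp:
            async for chunk in resp.aiter_text():
                print(chunk, end="", flush=True)


asyncio.run(main())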
bot/discord_bot.py
CHANGED
@@ -8,7 +8,7 @@ import discord
 from discord.ext import commands
 from dotenv import load_dotenv

-from src.chat import ChatSession
+from src.team import TeamChatSession
 from src.db import reset_history
 from src.log import get_logger

@@ -34,7 +34,7 @@ async def reset(ctx: commands.Context) -> None:
     await ctx.reply(f"Chat history cleared ({deleted} messages deleted).")


-async def _handle_attachments(chat: ChatSession, message: discord.Message) -> list[tuple[str, str]]:
+async def _handle_attachments(chat: TeamChatSession, message: discord.Message) -> list[tuple[str, str]]:
     if not message.attachments:
         return []

@@ -61,7 +61,7 @@ async def on_message(message: discord.Message) -> None:
     if message.content.startswith("!"):
         return

-    async with ChatSession(user=str(message.author.id), session=str(message.channel.id)) as chat:
+    async with TeamChatSession(user=str(message.author.id), session=str(message.channel.id)) as chat:
         docs = await _handle_attachments(chat, message)
         if docs:
             info = "\n".join(f"{name} -> {path}" for name, path in docs)
run.py
CHANGED
@@ -2,12 +2,12 @@ from __future__ import annotations

 import asyncio

-from src.chat import ChatSession
+from src.team import TeamChatSession
 from src.vm import VMRegistry


 async def _main() -> None:
-    async with ChatSession(user="demo_user", session="demo_session") as chat:
+    async with TeamChatSession(user="demo_user", session="demo_session") as chat:
         # doc_path = chat.upload_document("note.pdf")
         async for resp in chat.chat_stream("using python, execute a code to remind me in 30 seconds to take a break."):
             print("\n>>>", resp)
src/__init__.py
CHANGED
@@ -1,12 +1,17 @@
 from .chat import ChatSession
+from .team import TeamChatSession, send_to_junior, send_to_junior_async, set_team
 from .tools import execute_terminal, execute_terminal_async, set_vm
 from .utils import limit_chars
 from .vm import LinuxVM

 __all__ = [
     "ChatSession",
+    "TeamChatSession",
     "execute_terminal",
     "execute_terminal_async",
+    "send_to_junior",
+    "send_to_junior_async",
+    "set_team",
     "set_vm",
     "LinuxVM",
     "limit_chars",
src/chat.py
CHANGED
@@ -61,6 +61,9 @@ class ChatSession:
         session: str = "default",
         host: str = OLLAMA_HOST,
         model: str = MODEL_NAME,
+        *,
+        system_prompt: str = SYSTEM_PROMPT,
+        tools: list[callable] | None = None,
     ) -> None:
         init_db()
         self._client = AsyncClient(host=host)
@@ -70,6 +73,10 @@ class ChatSession:
             user=self._user, session_name=session
         )
         self._vm = None
+        self._system_prompt = system_prompt
+        self._tools = tools or [execute_terminal]
+        self._tool_funcs = {func.__name__: func for func in self._tools}
+        self._current_tool_name: str | None = None
         self._messages: List[Msg] = self._load_history()
         self._data = _get_session_data(self._conversation.id)
         self._lock = self._data.lock
@@ -190,7 +197,7 @@ class ChatSession:
         """Send a chat request, automatically prepending the system prompt."""

         if not messages or messages[0].get("role") != "system":
-            payload = [{"role": "system", "content": SYSTEM_PROMPT}, *messages]
+            payload = [{"role": "system", "content": self._system_prompt}, *messages]
         else:
             payload = messages

@@ -198,10 +205,16 @@ class ChatSession:
             self._model,
             messages=payload,
             think=think,
-            tools=[execute_terminal],
+            tools=self._tools,
             options={"num_ctx": NUM_CTX},
         )

+    async def _run_tool_async(self, func, **kwargs) -> str:
+        if asyncio.iscoroutinefunction(func):
+            return await func(**kwargs)
+        loop = asyncio.get_running_loop()
+        return await loop.run_in_executor(None, lambda: func(**kwargs))
+
     async def _handle_tool_calls_stream(
         self,
         messages: List[Msg],
@@ -217,7 +230,8 @@ class ChatSession:
             return
         while depth < MAX_TOOL_CALL_DEPTH and response.message.tool_calls:
             for call in response.message.tool_calls:
-                if call.function.name != "execute_terminal":
+                func = self._tool_funcs.get(call.function.name)
+                if not func:
                     _LOG.warning("Unsupported tool call: %s", call.function.name)
                     result = f"Unsupported tool: {call.function.name}"
                     messages.append(
@@ -235,9 +249,11 @@ class ChatSession:
                     continue

                 exec_task = asyncio.create_task(
-                    execute_terminal_async(**call.function.arguments)
+                    self._run_tool_async(func, **call.function.arguments)
                 )

+                self._current_tool_name = call.function.name
+
                 placeholder = {
                     "role": "tool",
                     "name": call.function.name,
@@ -343,6 +359,23 @@ class ChatSession:
         if text:
             yield text

+    async def continue_stream(self) -> AsyncIterator[str]:
+        async with self._lock:
+            if self._state != "idle":
+                return
+            self._state = "generating"
+
+        response = await self.ask(self._messages)
+        self._messages.append(response.message.model_dump())
+        self._store_assistant_message(self._conversation, response.message)
+
+        async for resp in self._handle_tool_calls_stream(
+            self._messages, response, self._conversation
+        ):
+            text = self._format_output(resp.message)
+            if text:
+                yield text
+
     async def _chat_during_tool(self, prompt: str) -> AsyncIterator[str]:
         DBMessage.create(conversation=self._conversation, role="user", content=prompt)
         self._messages.append({"role": "user", "content": prompt})
@@ -364,8 +397,10 @@ class ChatSession:
         self._remove_tool_placeholder(self._messages)
         result = await exec_task
         self._tool_task = None
+        name = self._current_tool_name or "tool"
+        self._current_tool_name = None
         self._messages.append(
-            {"role": "tool", "name": "execute_terminal", "content": result}
+            {"role": "tool", "name": name, "content": result}
         )
         DBMessage.create(
             conversation=self._conversation, role="tool", content=result
@@ -396,8 +431,10 @@ class ChatSession:
         result = await exec_task
         self._tool_task = None
         self._remove_tool_placeholder(self._messages)
+        name = self._current_tool_name or "tool"
+        self._current_tool_name = None
         self._messages.append(
-            {"role": "tool", "name": "execute_terminal", "content": result}
+            {"role": "tool", "name": name, "content": result}
         )
         DBMessage.create(
             conversation=self._conversation, role="tool", content=result
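The core of this change is that ChatSession no longer hardcodes execute_terminal: tool calls are resolved by name through self._tool_funcs and dispatched via _run_tool_async, which awaits coroutine tools and pushes blocking ones onto the default executor so the event loop keeps streaming. A standalone sketch of that dispatch pattern (lookup_weather and fetch_notes are hypothetical tools, not part of this repo):

# Sketch of the sync/async tool dispatch added above, in isolation.
# Mirrors ChatSession._run_tool_async: coroutine tools are awaited,
# plain functions run in a thread so the event loop stays free.
import asyncio


async def run_tool_async(func, **kwargs) -> str:
    if asyncio.iscoroutinefunction(func):
        return await func(**kwargs)
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, lambda: func(**kwargs))


def lookup_weather(city: str) -> str:  # hypothetical blocking tool
    return f"Sunny in {city}"


async def fetch_notes(topic: str) -> str:  # hypothetical async tool
    await asyncio.sleep(0.1)
    return f"Notes on {topic}"


async def demo() -> None:
    print(await run_tool_async(lookup_weather, city="Oslo"))
    print(await run_tool_async(fetch_notes, topic="queues"))


asyncio.run(demo())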
src/config.py
CHANGED
@@ -18,27 +18,34 @@ VM_STATE_DIR: Final[str] = os.getenv(
 )

 SYSTEM_PROMPT: Final[str] = (
-    "You are Starlette, …"
+    "You are Starlette, the senior agent leading a two-agent team. "
+    "A junior agent named Starlette Jr. assists you but never speaks to the user. "
+    "Use the send_to_junior tool whenever you want the junior's help. "
+    "Messages from the junior arrive as tool outputs named 'junior'. "
+    "Handle them when you are not actively generating so replies are never interrupted. "
+    "Both agents operate asynchronously and communicate through queued messages. "
     "You were developed by Envision to assist users with a wide range of tasks. "
     "Always analyze the user's objective before responding. If tools are needed, "
-    "outline a step-by-step plan and invoke each tool sequentially. Use "
-    "execute_terminal with its built-in Python whenever possible to perform "
+    "outline a step-by-step plan and invoke each tool sequentially. "
+    "Use execute_terminal with its built-in Python whenever possible to perform "
     "calculations, inspect files and search the web. Shell commands execute "
-    "asynchronously, so provide a brief interim reply while waiting. Once a tool "
-    "returns its result you will receive a tool message and must continue from "
-    "there. …"
-    […old prompt lines 29–43 truncated in the source…]
+    "asynchronously, so provide a brief interim reply while waiting. "
+    "Once a tool returns its result you will receive a tool message and must continue from there. "
+    "If the result arrives before your interim reply is complete, cancel the reply and incorporate the tool output instead. "
+    "Uploaded files live under /data and are accessible via the execute_terminal tool. "
+    "When a user prompt ends with '/think', ignore that suffix. "
+    "When you are unsure about any detail, use execute_terminal to search the internet or inspect files before answering. "
+    "Continue using tools until you have gathered everything required to produce an accurate answer, "
+    "then craft a clear and precise final response that fully addresses the request. "
+    "Always assume the user has no knowledge of computers or programming, so take the initiative to run terminal commands yourself and minimize the steps the user must perform. "
+    "When replying, avoid technical jargon entirely. Speak in plain language that anyone can understand, explaining concepts as simply as possible. "
+    "Remember, you must always prioritize using execute_terminal tool for everything unless it is absolutely unnecessary or impossible to do so. "
+    "Even if you have executed a command before, always re-run it to ensure you have the most up-to-date information upon user request."
 ).strip()
+
+JUNIOR_PROMPT: Final[str] = (
+    "You are Starlette Jr., a junior assistant working under the senior agent. "
+    "You never communicate with the user directly. All messages from the senior agent "
+    "arrive as tool outputs named 'senior'. Provide concise, helpful responses and "
+    "use execute_terminal whenever necessary. Your replies are sent back to the senior agent as tool messages."
+).strip()
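Both prompts rely on the same convention: inter-agent traffic is injected into the history as role "tool" messages whose name field identifies the sender ("junior" or "senior"). Expressed as one of the plain message dicts ChatSession keeps in self._messages, a junior reply delivered to the senior would look like this (the content string is illustrative):

# Illustrative shape of an inter-agent message, matching what
# TeamChatSession._deliver_junior_messages appends in src/team.py below
# (the content value is made up for the example).
junior_reply = {
    "role": "tool",    # delivered like any other tool result
    "name": "junior",  # sender tag the senior prompt tells the model to expect
    "content": "Checked /data/report.txt; it lists three open tasks.",
}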
src/team.py
ADDED
@@ -0,0 +1,107 @@
+from __future__ import annotations
+
+import asyncio
+from typing import AsyncIterator, Optional
+
+from .chat import ChatSession
+from .config import OLLAMA_HOST, MODEL_NAME, SYSTEM_PROMPT, JUNIOR_PROMPT
+from .tools import execute_terminal
+from .db import Message as DBMessage
+
+__all__ = [
+    "TeamChatSession",
+    "send_to_junior",
+    "send_to_junior_async",
+    "set_team",
+]
+
+_TEAM: Optional["TeamChatSession"] = None
+
+
+def set_team(team: "TeamChatSession" | None) -> None:
+    global _TEAM
+    _TEAM = team
+
+
+def send_to_junior(message: str) -> str:
+    if _TEAM is None:
+        return "No active team"
+    _TEAM.queue_message_to_junior(message)
+    return "Message sent to junior"
+
+
+async def send_to_junior_async(message: str) -> str:
+    return send_to_junior(message)
+
+
+class TeamChatSession:
+    def __init__(
+        self,
+        user: str = "default",
+        session: str = "default",
+        host: str = OLLAMA_HOST,
+        model: str = MODEL_NAME,
+    ) -> None:
+        self._to_junior: asyncio.Queue[str] = asyncio.Queue()
+        self._to_senior: asyncio.Queue[str] = asyncio.Queue()
+        self._junior_task: asyncio.Task | None = None
+        self.senior = ChatSession(
+            user=user,
+            session=session,
+            host=host,
+            model=model,
+            system_prompt=SYSTEM_PROMPT,
+            tools=[execute_terminal, send_to_junior],
+        )
+        self.junior = ChatSession(
+            user=user,
+            session=f"{session}-junior",
+            host=host,
+            model=model,
+            system_prompt=JUNIOR_PROMPT,
+            tools=[execute_terminal],
+        )
+
+    async def __aenter__(self) -> "TeamChatSession":
+        await self.senior.__aenter__()
+        await self.junior.__aenter__()
+        set_team(self)
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb) -> None:
+        set_team(None)
+        await self.senior.__aexit__(exc_type, exc, tb)
+        await self.junior.__aexit__(exc_type, exc, tb)
+
+    def upload_document(self, file_path: str) -> str:
+        return self.senior.upload_document(file_path)
+
+    def queue_message_to_junior(self, message: str) -> None:
+        self._to_junior.put_nowait(message)
+        if not self._junior_task or self._junior_task.done():
+            self._junior_task = asyncio.create_task(self._process_junior())
+
+    async def _process_junior(self) -> None:
+        while not self._to_junior.empty():
+            msg = await self._to_junior.get()
+            self.junior._messages.append({"role": "tool", "name": "senior", "content": msg})
+            DBMessage.create(conversation=self.junior._conversation, role="tool", content=msg)
+            parts = []
+            async for part in self.junior.continue_stream():
+                if part:
+                    parts.append(part)
+            result = "\n".join(parts)
+            if result.strip():
+                await self._to_senior.put(result)
+
+    async def _deliver_junior_messages(self) -> None:
+        while not self._to_senior.empty():
+            msg = await self._to_senior.get()
+            self.senior._messages.append({"role": "tool", "name": "junior", "content": msg})
+            DBMessage.create(conversation=self.senior._conversation, role="tool", content=msg)
+
+    async def chat_stream(self, prompt: str) -> AsyncIterator[str]:
+        await self._deliver_junior_messages()
+        async for part in self.senior.chat_stream(prompt):
+            yield part
+        await self._deliver_junior_messages()
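The coordination logic in TeamChatSession boils down to two asyncio.Queue objects and one background task: send_to_junior enqueues work and lazily spawns _process_junior, whose results are drained into the senior's history between turns by _deliver_junior_messages. A toy version of that relay, with a stub EchoAgent standing in for ChatSession purely for illustration:

# Toy model of the TeamChatSession relay: an inbox queue, a lazily
# spawned worker task, and an outbox drained between turns.
# EchoAgent is a stand-in for ChatSession and exists only for this demo.
import asyncio


class EchoAgent:
    async def continue_stream(self, msg: str):
        yield f"junior saw: {msg}"


class ToyTeam:
    def __init__(self) -> None:
        self._to_junior: asyncio.Queue[str] = asyncio.Queue()
        self._to_senior: asyncio.Queue[str] = asyncio.Queue()
        self._junior_task: asyncio.Task | None = None
        self.junior = EchoAgent()

    def queue_message_to_junior(self, message: str) -> None:
        self._to_junior.put_nowait(message)
        if not self._junior_task or self._junior_task.done():
            self._junior_task = asyncio.create_task(self._process_junior())

    async def _process_junior(self) -> None:
        while not self._to_junior.empty():
            msg = await self._to_junior.get()
            parts = [p async for p in self.junior.continue_stream(msg)]
            await self._to_senior.put("\n".join(parts))

    async def drain_to_senior(self) -> list[str]:
        out: list[str] = []
        while not self._to_senior.empty():
            out.append(await self._to_senior.get())
        return out


async def demo() -> None:
    team = ToyTeam()
    team.queue_message_to_junior("summarize the logs")
    await team._junior_task  # wait for the worker to finish
    print(await team.drain_to_senior())


asyncio.run(demo())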