from __future__ import annotations

from typing import AsyncIterator, Callable, List
from dataclasses import dataclass, field
import json
import asyncio
import shutil
from pathlib import Path

from ollama import AsyncClient, ChatResponse, Message

from .config import (
    MAX_TOOL_CALL_DEPTH,
    MODEL_NAME,
    NUM_CTX,
    OLLAMA_HOST,
    SYSTEM_PROMPT,
    UPLOAD_DIR,
)
from .db import (
    Conversation,
    Message as DBMessage,
    User,
    _db,
    init_db,
    add_document,
)
from .log import get_logger
from .schema import Msg
from .tools import execute_terminal, execute_terminal_async, set_vm
from .vm import VMRegistry
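
# Per-conversation shared state: ChatSession instances bound to the same
# conversation (same user/session pair) look up one _SessionData record, so
# the asyncio lock, the "idle"/"generating"/"awaiting_tool" state, and any
# pending tool task are coordinated across instances.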
@dataclass
class _SessionData:
    """Shared state for each conversation session."""

    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    state: str = "idle"
    tool_task: asyncio.Task | None = None


_SESSION_DATA: dict[int, _SessionData] = {}

def _get_session_data(conv_id: int) -> _SessionData:
    data = _SESSION_DATA.get(conv_id)
    if data is None:
        data = _SessionData()
        _SESSION_DATA[conv_id] = data
    return data


_LOG = get_logger(__name__)

class ChatSession:
    def __init__(
        self,
        user: str = "default",
        session: str = "default",
        host: str = OLLAMA_HOST,
        model: str = MODEL_NAME,
        *,
        system_prompt: str = SYSTEM_PROMPT,
        tools: list[Callable] | None = None,
    ) -> None:
        init_db()
        self._client = AsyncClient(host=host)
        self._model = model
        self._user, _ = User.get_or_create(username=user)
        self._conversation, _ = Conversation.get_or_create(
            user=self._user, session_name=session
        )
        self._vm = None
        self._system_prompt = system_prompt
        self._tools = tools or [execute_terminal]
        self._tool_funcs = {func.__name__: func for func in self._tools}
        self._current_tool_name: str | None = None
        self._messages: List[Msg] = self._load_history()
        self._data = _get_session_data(self._conversation.id)
        self._lock = self._data.lock
        self._prompt_queue: asyncio.Queue[
            tuple[str, asyncio.Queue[str | None]]
        ] = asyncio.Queue()
        self._worker: asyncio.Task | None = None

    # Shared state properties -------------------------------------------------
    @property
    def _state(self) -> str:
        return self._data.state

    @_state.setter
    def _state(self, value: str) -> None:
        self._data.state = value

    @property
    def _tool_task(self) -> asyncio.Task | None:
        return self._data.tool_task

    @_tool_task.setter
    def _tool_task(self, task: asyncio.Task | None) -> None:
        self._data.tool_task = task

    async def __aenter__(self) -> "ChatSession":
        self._vm = VMRegistry.acquire(self._user.username)
        set_vm(self._vm)
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        set_vm(None)
        if self._vm:
            VMRegistry.release(self._user.username)
        if not _db.is_closed():
            _db.close()

    def upload_document(self, file_path: str) -> str:
        """Save a document for later access inside the VM.

        The file is copied into ``UPLOAD_DIR`` and recorded in the database.
        The returned path is the location inside the VM (prefixed with
        ``/data``).
        """
        src = Path(file_path)
        if not src.exists():
            raise FileNotFoundError(file_path)
        dest = Path(UPLOAD_DIR) / self._user.username
        dest.mkdir(parents=True, exist_ok=True)
        target = dest / src.name
        shutil.copy(src, target)
        add_document(self._user.username, str(target), src.name)
        return f"/data/{src.name}"

    def _load_history(self) -> List[Msg]:
        messages: List[Msg] = []
        for msg in self._conversation.messages.order_by(DBMessage.created_at):
            if msg.role == "system":
                # Skip persisted system prompts from older versions.
                continue
            if msg.role == "assistant":
                try:
                    calls = json.loads(msg.content)
                except json.JSONDecodeError:
                    messages.append({"role": "assistant", "content": msg.content})
                else:
                    messages.append(
                        {
                            "role": "assistant",
                            "tool_calls": [Message.ToolCall(**c) for c in calls],
                        }
                    )
            elif msg.role == "user":
                messages.append({"role": "user", "content": msg.content})
            else:
                messages.append({"role": "tool", "content": msg.content})
        return messages
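
    # Note: assistant tool calls round-trip through the database as JSON.
    # They are serialized by _store_assistant_message/_serialize_tool_calls on
    # write and re-hydrated into Message.ToolCall objects in _load_history on
    # read; plain replies take the JSONDecodeError branch above.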

    # ------------------------------------------------------------------
    @staticmethod
    def _serialize_tool_calls(calls: List[Message.ToolCall]) -> str:
        """Convert tool calls to a JSON string for storage or output."""
        return json.dumps([c.model_dump() for c in calls])

    @staticmethod
    def _format_output(message: Message) -> str:
        """Return the message content, or an empty string if there is none."""
        # Tool calls are intentionally not echoed back to the caller:
        # if message.tool_calls:
        #     return ChatSession._serialize_tool_calls(message.tool_calls)
        return message.content or ""

    @staticmethod
    def _remove_tool_placeholder(messages: List[Msg]) -> None:
        """Remove the pending placeholder tool message if present."""
        for i in range(len(messages) - 1, -1, -1):
            msg = messages[i]
            if (
                msg.get("role") == "tool"
                and msg.get("content") == "Awaiting tool response..."
            ):
                messages.pop(i)
                break

    @staticmethod
    def _store_assistant_message(conversation: Conversation, message: Message) -> None:
        """Persist assistant messages, storing tool calls when present."""
        if message.tool_calls:
            content = ChatSession._serialize_tool_calls(message.tool_calls)
        else:
            content = message.content or ""
        if content.strip():
            DBMessage.create(
                conversation=conversation, role="assistant", content=content
            )

    async def ask(self, messages: List[Msg], *, think: bool = True) -> ChatResponse:
        """Send a chat request, automatically prepending the system prompt."""
        if not messages or messages[0].get("role") != "system":
            payload = [{"role": "system", "content": self._system_prompt}, *messages]
        else:
            payload = messages
        return await self._client.chat(
            self._model,
            messages=payload,
            think=think,
            tools=self._tools,
            options={"num_ctx": NUM_CTX},
        )

    async def _run_tool_async(self, func, **kwargs) -> str:
        if asyncio.iscoroutinefunction(func):
            return await func(**kwargs)
        # Run synchronous tools in the default executor so they do not block
        # the event loop.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: func(**kwargs))
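
    # Tool-call handling races two tasks: the tool execution itself and a
    # speculative follow-up model request issued against a placeholder tool
    # message. If the tool finishes first, the follow-up is cancelled and a
    # fresh request is made with the real result; if the model answers first,
    # its reply is streamed and the tool result is folded in when it arrives.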
    async def _handle_tool_calls_stream(
        self,
        messages: List[Msg],
        response: ChatResponse,
        conversation: Conversation,
        depth: int = 0,
    ) -> AsyncIterator[ChatResponse]:
        if not response.message.tool_calls:
            if response.message.content:
                yield response
            async with self._lock:
                self._state = "idle"
            return
        while depth < MAX_TOOL_CALL_DEPTH and response.message.tool_calls:
            for call in response.message.tool_calls:
                func = self._tool_funcs.get(call.function.name)
                if not func:
                    _LOG.warning("Unsupported tool call: %s", call.function.name)
                    result = f"Unsupported tool: {call.function.name}"
                    name = (
                        "junior"
                        if call.function.name == "send_to_junior"
                        else call.function.name
                    )
                    messages.append({"role": "tool", "name": name, "content": result})
                    DBMessage.create(
                        conversation=conversation,
                        role="tool",
                        content=result,
                    )
                    continue
                exec_task = asyncio.create_task(
                    self._run_tool_async(func, **call.function.arguments)
                )
                self._current_tool_name = call.function.name
                placeholder = {
                    "role": "tool",
                    "name": (
                        "junior"
                        if call.function.name == "send_to_junior"
                        else call.function.name
                    ),
                    "content": "Awaiting tool response...",
                }
                messages.append(placeholder)
                follow_task = asyncio.create_task(self.ask(messages, think=True))
                async with self._lock:
                    self._state = "awaiting_tool"
                    self._tool_task = exec_task
                done, _ = await asyncio.wait(
                    {exec_task, follow_task},
                    return_when=asyncio.FIRST_COMPLETED,
                )
                if exec_task in done:
                    # Tool finished first: drop the speculative follow-up and
                    # re-ask with the real tool result.
                    follow_task.cancel()
                    try:
                        await follow_task
                    except asyncio.CancelledError:
                        pass
                    self._remove_tool_placeholder(messages)
                    result = await exec_task
                    name = (
                        "junior"
                        if call.function.name == "send_to_junior"
                        else call.function.name
                    )
                    messages.append({"role": "tool", "name": name, "content": result})
                    DBMessage.create(
                        conversation=conversation,
                        role="tool",
                        content=result,
                    )
                    async with self._lock:
                        self._state = "generating"
                        self._tool_task = None
                    nxt = await self.ask(messages, think=True)
                    self._store_assistant_message(conversation, nxt.message)
                    messages.append(nxt.message.model_dump())
                    response = nxt
                    yield nxt
                else:
                    # Model answered against the placeholder first: stream that
                    # reply, then wait for the tool and re-ask.
                    followup = await follow_task
                    self._store_assistant_message(conversation, followup.message)
                    messages.append(followup.message.model_dump())
                    yield followup
                    result = await exec_task
                    self._remove_tool_placeholder(messages)
                    name = (
                        "junior"
                        if call.function.name == "send_to_junior"
                        else call.function.name
                    )
                    messages.append({"role": "tool", "name": name, "content": result})
                    DBMessage.create(
                        conversation=conversation,
                        role="tool",
                        content=result,
                    )
                    async with self._lock:
                        self._state = "generating"
                        self._tool_task = None
                    nxt = await self.ask(messages, think=True)
                    self._store_assistant_message(conversation, nxt.message)
                    messages.append(nxt.message.model_dump())
                    response = nxt
                    yield nxt
            depth += 1
        async with self._lock:
            self._state = "idle"

    async def _generate_stream(self, prompt: str) -> AsyncIterator[str]:
        async with self._lock:
            awaiting_tool = (
                self._state == "awaiting_tool" and self._tool_task is not None
            )
            if not awaiting_tool:
                self._state = "generating"
        if awaiting_tool:
            # A tool is still running. _chat_during_tool acquires the lock
            # itself and asyncio.Lock is not reentrant, so it must be entered
            # only after the block above has released the lock.
            async for part in self._chat_during_tool(prompt):
                yield part
            return
        DBMessage.create(conversation=self._conversation, role="user", content=prompt)
        self._messages.append({"role": "user", "content": prompt})
        response = await self.ask(self._messages)
        self._messages.append(response.message.model_dump())
        self._store_assistant_message(self._conversation, response.message)
        async for resp in self._handle_tool_calls_stream(
            self._messages, response, self._conversation
        ):
            text = self._format_output(resp.message)
            if text:
                yield text
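
    # Prompts are queued together with a per-prompt result queue and drained
    # one at a time by a single worker task, so concurrent chat_stream calls
    # on the same session are serialized rather than interleaved.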
    async def _process_prompt_queue(self) -> None:
        try:
            while not self._prompt_queue.empty():
                prompt, result_q = await self._prompt_queue.get()
                try:
                    async for part in self._generate_stream(prompt):
                        await result_q.put(part)
                except Exception as exc:  # pragma: no cover - unforeseen errors
                    _LOG.exception("Error processing prompt: %s", exc)
                    await result_q.put(f"Error: {exc}")
                finally:
                    await result_q.put(None)
        finally:
            self._worker = None
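
    # chat_stream enqueues the prompt, lazily (re)starts the worker, and
    # forwards parts until the ``None`` sentinel marks end of stream.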
    async def chat_stream(self, prompt: str) -> AsyncIterator[str]:
        result_q: asyncio.Queue[str | None] = asyncio.Queue()
        await self._prompt_queue.put((prompt, result_q))
        if not self._worker or self._worker.done():
            self._worker = asyncio.create_task(self._process_prompt_queue())
        while True:
            part = await result_q.get()
            if part is None:
                break
            yield part

    async def continue_stream(self) -> AsyncIterator[str]:
        async with self._lock:
            if self._state != "idle":
                return
            self._state = "generating"
        response = await self.ask(self._messages)
        self._messages.append(response.message.model_dump())
        self._store_assistant_message(self._conversation, response.message)
        async for resp in self._handle_tool_calls_stream(
            self._messages, response, self._conversation
        ):
            text = self._format_output(resp.message)
            if text:
                yield text
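
    # Handles a prompt that arrives while a tool call is still in flight: the
    # pending tool task is raced against a model reply to the new prompt, and
    # whichever finishes first decides how the conversation resumes.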
    async def _chat_during_tool(self, prompt: str) -> AsyncIterator[str]:
        DBMessage.create(conversation=self._conversation, role="user", content=prompt)
        self._messages.append({"role": "user", "content": prompt})
        user_task = asyncio.create_task(self.ask(self._messages))
        exec_task = self._tool_task
        done, _ = await asyncio.wait(
            {exec_task, user_task},
            return_when=asyncio.FIRST_COMPLETED,
        )
        if exec_task in done:
            # The tool finished before the model replied to the new prompt:
            # discard the speculative reply and re-ask with the tool result.
            user_task.cancel()
            try:
                await user_task
            except asyncio.CancelledError:
                pass
            self._remove_tool_placeholder(self._messages)
            result = await exec_task
            self._tool_task = None
            name = self._current_tool_name or "tool"
            self._current_tool_name = None
            self._messages.append({"role": "tool", "name": name, "content": result})
            DBMessage.create(
                conversation=self._conversation, role="tool", content=result
            )
            async with self._lock:
                self._state = "generating"
            nxt = await self.ask(self._messages, think=True)
            self._store_assistant_message(self._conversation, nxt.message)
            self._messages.append(nxt.message.model_dump())
            text = self._format_output(nxt.message)
            if text:
                yield text
            async for part in self._handle_tool_calls_stream(
                self._messages, nxt, self._conversation
            ):
                text = self._format_output(part.message)
                if text:
                    yield text
        else:
            # The model replied first: stream that answer, then wait for the
            # tool to finish and fold its result back into the conversation.
            resp = await user_task
            self._store_assistant_message(self._conversation, resp.message)
            self._messages.append(resp.message.model_dump())
            async with self._lock:
                self._state = "awaiting_tool"
            text = self._format_output(resp.message)
            if text:
                yield text
            result = await exec_task
            self._tool_task = None
            self._remove_tool_placeholder(self._messages)
            name = self._current_tool_name or "tool"
            self._current_tool_name = None
            self._messages.append({"role": "tool", "name": name, "content": result})
            DBMessage.create(
                conversation=self._conversation, role="tool", content=result
            )
            async with self._lock:
                self._state = "generating"
            nxt = await self.ask(self._messages, think=True)
            self._store_assistant_message(self._conversation, nxt.message)
            self._messages.append(nxt.message.model_dump())
            text = self._format_output(nxt.message)
            if text:
                yield text
            async for part in self._handle_tool_calls_stream(
                self._messages, nxt, self._conversation
            ):
                text = self._format_output(part.message)
                if text:
                    yield text
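

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library surface. It assumes an
    # Ollama server is reachable at OLLAMA_HOST, that the database and VM
    # backends are configured, and that this file is run as a module (the
    # relative imports require it), e.g. ``python -m <package>.chat``. The
    # prompt below is only illustrative.
    async def _demo() -> None:
        async with ChatSession(user="demo", session="scratch") as chat:
            async for part in chat.chat_stream("List the files in /data."):
                print(part, end="", flush=True)
            print()

    asyncio.run(_demo())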