# brainsqueeze's picture
# v3 (#2)
# f5c9c80 verified
from typing import TypedDict, Literal, Any
from collections.abc import Iterator, Sequence
from dataclasses import asdict
import logging
import json
from langchain_core.messages.tool import ToolMessage
from gradio import ChatMessage
# Module-level logging setup. basicConfig only installs a root handler if the
# root logger has none yet; the module logger itself is pinned to INFO so the
# per-tool "Called tool" messages below are emitted.
logging.basicConfig(
    format="[%(levelname)s] (%(asctime)s) :: %(message)s",
)
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)
class ToolInput(TypedDict):
    """One tool invocation entry inside a ``CalledTool`` chunk."""

    # Name of the tool being invoked (shown in the UI title by format_tool_call).
    name: str
    # Tool arguments; format_tool_call renders these as JSON.
    args: dict[str, Any]
    # Identifier of this particular call.
    id: str
    # Discriminator tag; always the literal string "tool_call".
    type: Literal["tool_call"]
class CalledTool(TypedDict):
id: str
name: Literal["tools"]
input: list[ToolInput]
triggers: tuple[str, ...]
class ToolResult(TypedDict):
    """A graph chunk for the ``"tools"`` node carrying the tools' outputs."""

    # Chunk identifier; format_tool_response uses it as each message's parent_id.
    id: str
    # Node name; always the literal string "tools".
    name: Literal["tools"]
    # NOTE(review): presumably flags a failed tool step (None if unknown) —
    # not read in this module; confirm against the producer.
    error: bool | None
    # Pairs of (label, tool messages); format_tool_response ignores the label.
    result: list[tuple[str, list[ToolMessage]]]
    # Untyped list; not read anywhere in this module.
    interrupts: list
def convert_history_for_graph_agent(history: Sequence[dict | ChatMessage]) -> list[dict]:
    """Normalize a Gradio chat history into plain dicts for the graph agent.

    ``ChatMessage`` instances are converted with ``dataclasses.asdict``; entries
    that are already dicts are passed through as-is (the same objects, not
    copies). Any message whose ``"content"`` is falsy (missing, ``None``, empty)
    is dropped. Messages that carry metadata are kept as long as they have
    content.

    Args:
        history: Chat messages as dicts and/or ``gradio.ChatMessage`` objects.

    Returns:
        A new list with only the messages that have content, in original order.
    """
    normalized = (
        asdict(message) if isinstance(message, ChatMessage) else message
        for message in history
    )
    return [message for message in normalized if message.get("content")]
def format_tool_call(input_chunk: CalledTool) -> Iterator[ChatMessage]:
    """Render each tool call in a ``CalledTool`` chunk as an assistant message.

    Args:
        input_chunk: Graph chunk listing one or more tool calls under ``"input"``.

    Yields:
        One assistant ``ChatMessage`` per entry in ``input_chunk["input"]``. Its
        content is the call's ``args`` serialized as JSON, and its metadata holds
        a "Using tool ..." title, ``status="done"``, and the chunk id as both
        ``id`` and ``parent_id``.
    """
    chunk_id = input_chunk["id"]
    for graph_input in input_chunk["input"]:
        yield ChatMessage(
            role="assistant",
            # ensure_ascii=False keeps non-ASCII argument values (accented or
            # CJK text, etc.) readable in the chat UI instead of \uXXXX escapes.
            content=json.dumps(graph_input["args"], ensure_ascii=False),
            metadata={
                "title": f"Using tool `{graph_input.get('name')}`",
                "status": "done",
                # NOTE(review): every message from this chunk gets the chunk id
                # as both its own id and its parent_id; confirm this nesting in
                # the Gradio chatbot is intended.
                "id": chunk_id,
                "parent_id": chunk_id,
            },
        )
def format_tool_response(result_chunk: ToolResult) -> Iterator[ChatMessage]:
    """Render every tool output in a ``ToolResult`` chunk as an assistant message.

    The ``(label, messages)`` pairs in ``result_chunk["result"]`` are flattened in
    order and the labels are ignored. Each tool message is logged at INFO level
    just before its chat message is yielded.

    Args:
        result_chunk: Graph chunk holding the tools' ``ToolMessage`` outputs.

    Yields:
        One assistant ``ChatMessage`` per ``ToolMessage``. Its content is the tool
        message's content, and its metadata carries a "Results from tool ..."
        title, the tool name, the message's ``artifact`` under ``"documents"``,
        ``status="done"``, and the chunk id as ``parent_id``.
    """
    tool_messages = (
        message
        for _, batch in result_chunk["result"]
        for message in batch
    )
    for message in tool_messages:
        logger.info("Called tool `%s`", message.name)
        # Extra keys (tool_name, documents) are outside Gradio's metadata type,
        # hence the pyright suppression on the argument below.
        details = {
            "title": f"Results from tool `{message.name}`",
            "tool_name": message.name,
            "documents": message.artifact,
            "status": "done",
            "parent_id": result_chunk["id"],
        }
        yield ChatMessage(
            role="assistant",
            content=message.content,
            metadata=details,  # pyright: ignore[reportArgumentType]
        )