diff --git a/langchain_zhipuai/agent_toolkits/all_tools/drawing_tool.py b/langchain_zhipuai/agent_toolkits/all_tools/drawing_tool.py new file mode 100644 index 0000000..5d86d4a --- /dev/null +++ b/langchain_zhipuai/agent_toolkits/all_tools/drawing_tool.py @@ -0,0 +1,89 @@ +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +from langchain_core.agents import AgentAction +from langchain_core.callbacks import ( + AsyncCallbackManagerForChainRun, + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) + +from langchain_zhipuai.agent_toolkits import AdapterAllTool +from langchain_zhipuai.agent_toolkits.all_tools.struct_type import AdapterAllToolStructType +from langchain_zhipuai.agent_toolkits.all_tools.tool import ( + AllToolExecutor, + BaseToolOutput, +) + +logger = logging.getLogger(__name__) + + +class DrawingToolOutput(BaseToolOutput): + platform_params: Dict[str, Any] + + def __init__( + self, + data: Any, + platform_params: Dict[str, Any], + **extras: Any, + ) -> None: + super().__init__(data, "", "", **extras) + self.platform_params = platform_params + + +@dataclass +class DrawingAllToolExecutor(AllToolExecutor): + """platform adapter tool for code interpreter tool""" + + name: str + + def run( + self, + tool: str, + tool_input: str, + log: str, + outputs: List[Union[str, dict]] = None, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> DrawingToolOutput: + if outputs is None or str(outputs).strip() == "": + raise ValueError( + f"Tool {self.name} is server error" + ) + + return DrawingToolOutput( + data=f"""Access:{tool}, Message: {tool_input},{log}""", + platform_params=self.platform_params, + ) + + async def arun( + self, + tool: str, + tool_input: str, + log: str, + outputs: List[Union[str, dict]] = None, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> DrawingToolOutput: + """Use the tool asynchronously.""" + if outputs is None or str(outputs).strip() == "" or len(outputs) == 0: + raise ValueError( + f"Tool {self.name} is server error" + ) + + return DrawingToolOutput( + data=f"""Access:{tool}, Message: {tool_input},{log}""", + platform_params=self.platform_params, + ) + + +class DrawingAdapterAllTool(AdapterAllTool[DrawingAllToolExecutor]): + @classmethod + def get_type(cls) -> str: + return "DrawingAdapterAllTool" + + def _build_adapter_all_tool( + self, platform_params: Dict[str, Any] + ) -> DrawingAllToolExecutor: + return DrawingAllToolExecutor( + name=AdapterAllToolStructType.DRAWING_TOOL, platform_params=platform_params + ) diff --git a/langchain_zhipuai/agent_toolkits/all_tools/registry.py b/langchain_zhipuai/agent_toolkits/all_tools/registry.py index 8b55d5c..fc3b340 100644 --- a/langchain_zhipuai/agent_toolkits/all_tools/registry.py +++ b/langchain_zhipuai/agent_toolkits/all_tools/registry.py @@ -4,10 +4,14 @@ from langchain_zhipuai.agent_toolkits.all_tools.code_interpreter_tool import ( CodeInterpreterAdapterAllTool, ) +from langchain_zhipuai.agent_toolkits.all_tools.drawing_tool import DrawingAdapterAllTool from langchain_zhipuai.agent_toolkits.all_tools.struct_type import ( AdapterAllToolStructType, ) +from langchain_zhipuai.agent_toolkits.all_tools.web_browser_tool import WebBrowserAdapterAllTool TOOL_STRUCT_TYPE_TO_TOOL_CLASS: Dict[AdapterAllToolStructType, Type[AdapterAllTool]] = { AdapterAllToolStructType.CODE_INTERPRETER: CodeInterpreterAdapterAllTool, + AdapterAllToolStructType.DRAWING_TOOL: DrawingAdapterAllTool, + AdapterAllToolStructType.WEB_BROWSER: WebBrowserAdapterAllTool, } diff --git a/langchain_zhipuai/agent_toolkits/all_tools/struct_type.py b/langchain_zhipuai/agent_toolkits/all_tools/struct_type.py index 8575bd0..ebebd50 100644 --- a/langchain_zhipuai/agent_toolkits/all_tools/struct_type.py +++ b/langchain_zhipuai/agent_toolkits/all_tools/struct_type.py @@ -14,3 +14,5 @@ class AdapterAllToolStructType(str, Enum): # TODO: refactor so these are properties on the base class CODE_INTERPRETER = "code_interpreter" + DRAWING_TOOL = "drawing_tool" + WEB_BROWSER = "web_browser" diff --git a/langchain_zhipuai/agent_toolkits/all_tools/tool.py b/langchain_zhipuai/agent_toolkits/all_tools/tool.py index 201d62f..6f106b9 100644 --- a/langchain_zhipuai/agent_toolkits/all_tools/tool.py +++ b/langchain_zhipuai/agent_toolkits/all_tools/tool.py @@ -31,7 +31,10 @@ ) from langchain_core.tools import BaseTool -from langchain_zhipuai.agents.output_parsers.tools import CodeInterpreterAgentAction +from langchain_zhipuai.agent_toolkits.all_tools.struct_type import AdapterAllToolStructType +from langchain_zhipuai.agents.output_parsers.code_interpreter import CodeInterpreterAgentAction +from langchain_zhipuai.agents.output_parsers.drawing_tool import DrawingToolAgentAction +from langchain_zhipuai.agents.output_parsers.web_browser import WebBrowserAgentAction logger = logging.getLogger(__name__) @@ -142,16 +145,40 @@ def _run( run_manager: Optional[AsyncCallbackManagerForChainRun] = None, **tool_run_kwargs: Any, ) -> Any: - if "code_interpreter" in agent_action.tool: + if AdapterAllToolStructType.CODE_INTERPRETER == agent_action.tool and isinstance( + agent_action, CodeInterpreterAgentAction + ): + return self.adapter_all_tool.run( + **{ + "tool": agent_action.tool, + "tool_input": agent_action.tool_input, + "log": agent_action.log, + "outputs": agent_action.outputs, + }, + **tool_run_kwargs, + ) + elif AdapterAllToolStructType.DRAWING_TOOL == agent_action.tool and isinstance( + agent_action, DrawingToolAgentAction + ): return self.adapter_all_tool.run( - { + **{ + "tool": agent_action.tool, + "tool_input": agent_action.tool_input, + "log": agent_action.log, + "outputs": agent_action.outputs, + }, + **tool_run_kwargs, + ) + elif AdapterAllToolStructType.WEB_BROWSER == agent_action.tool and isinstance( + agent_action, WebBrowserAgentAction + ): + return self.adapter_all_tool.run( + **{ "tool": agent_action.tool, "tool_input": agent_action.tool_input, "log": agent_action.log, + "outputs": agent_action.outputs, }, - verbose=self.verbose, - color="red", - callbacks=run_manager.get_child() if run_manager else None, **tool_run_kwargs, ) else: @@ -163,7 +190,7 @@ async def _arun( run_manager: Optional[AsyncCallbackManagerForChainRun] = None, **tool_run_kwargs: Any, ) -> Any: - if "code_interpreter" in agent_action.tool and isinstance( + if AdapterAllToolStructType.CODE_INTERPRETER == agent_action.tool and isinstance( agent_action, CodeInterpreterAgentAction ): return await self.adapter_all_tool.arun( @@ -175,5 +202,30 @@ async def _arun( }, **tool_run_kwargs, ) + + elif AdapterAllToolStructType.DRAWING_TOOL == agent_action.tool and isinstance( + agent_action, DrawingToolAgentAction + ): + return await self.adapter_all_tool.arun( + **{ + "tool": agent_action.tool, + "tool_input": agent_action.tool_input, + "log": agent_action.log, + "outputs": agent_action.outputs, + }, + **tool_run_kwargs, + ) + elif AdapterAllToolStructType.WEB_BROWSER == agent_action.tool and isinstance( + agent_action, WebBrowserAgentAction + ): + return await self.adapter_all_tool.arun( + **{ + "tool": agent_action.tool, + "tool_input": agent_action.tool_input, + "log": agent_action.log, + "outputs": agent_action.outputs, + }, + **tool_run_kwargs, + ) else: raise KeyError() diff --git a/langchain_zhipuai/agent_toolkits/all_tools/web_browser_tool.py b/langchain_zhipuai/agent_toolkits/all_tools/web_browser_tool.py new file mode 100644 index 0000000..c2beafe --- /dev/null +++ b/langchain_zhipuai/agent_toolkits/all_tools/web_browser_tool.py @@ -0,0 +1,89 @@ +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +from langchain_core.agents import AgentAction +from langchain_core.callbacks import ( + AsyncCallbackManagerForChainRun, + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) + +from langchain_zhipuai.agent_toolkits import AdapterAllTool +from langchain_zhipuai.agent_toolkits.all_tools.struct_type import AdapterAllToolStructType +from langchain_zhipuai.agent_toolkits.all_tools.tool import ( + AllToolExecutor, + BaseToolOutput, +) + +logger = logging.getLogger(__name__) + + +class WebBrowserToolOutput(BaseToolOutput): + platform_params: Dict[str, Any] + + def __init__( + self, + data: Any, + platform_params: Dict[str, Any], + **extras: Any, + ) -> None: + super().__init__(data, "", "", **extras) + self.platform_params = platform_params + + +@dataclass +class WebBrowserAllToolExecutor(AllToolExecutor): + """platform adapter tool for code interpreter tool""" + + name: str + + def run( + self, + tool: str, + tool_input: str, + log: str, + outputs: List[Union[str, dict]] = None, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> WebBrowserToolOutput: + if outputs is None or str(outputs).strip() == "": + raise ValueError( + f"Tool {self.name} is server error" + ) + + return WebBrowserToolOutput( + data=f"""Access:{tool}, Message: {tool_input},{log}""", + platform_params=self.platform_params, + ) + + async def arun( + self, + tool: str, + tool_input: str, + log: str, + outputs: List[Union[str, dict]] = None, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> WebBrowserToolOutput: + """Use the tool asynchronously.""" + if outputs is None or str(outputs).strip() == "" or len(outputs) == 0: + raise ValueError( + f"Tool {self.name} is server error" + ) + + return WebBrowserToolOutput( + data=f"""Access:{tool}, Message: {tool_input},{log}""", + platform_params=self.platform_params, + ) + + +class WebBrowserAdapterAllTool(AdapterAllTool[WebBrowserAllToolExecutor]): + @classmethod + def get_type(cls) -> str: + return "WebBrowserAdapterAllTool" + + def _build_adapter_all_tool( + self, platform_params: Dict[str, Any] + ) -> WebBrowserAllToolExecutor: + return WebBrowserAllToolExecutor( + name=AdapterAllToolStructType.WEB_BROWSER, platform_params=platform_params + ) diff --git a/langchain_zhipuai/agents/format_scratchpad/all_tools.py b/langchain_zhipuai/agents/format_scratchpad/all_tools.py index 2a7843b..5fe5c3c 100644 --- a/langchain_zhipuai/agents/format_scratchpad/all_tools.py +++ b/langchain_zhipuai/agents/format_scratchpad/all_tools.py @@ -13,7 +13,11 @@ from langchain_zhipuai.agent_toolkits.all_tools.code_interpreter_tool import ( CodeInterpreterToolOutput, ) -from langchain_zhipuai.agents.output_parsers.tools import CodeInterpreterAgentAction +from langchain_zhipuai.agent_toolkits.all_tools.drawing_tool import DrawingToolOutput +from langchain_zhipuai.agent_toolkits.all_tools.web_browser_tool import WebBrowserToolOutput +from langchain_zhipuai.agents.output_parsers.code_interpreter import CodeInterpreterAgentAction +from langchain_zhipuai.agents.output_parsers.drawing_tool import DrawingToolAgentAction +from langchain_zhipuai.agents.output_parsers.web_browser import WebBrowserAgentAction def _create_tool_message( @@ -67,6 +71,18 @@ def format_to_zhipuai_all_tool_messages( else: raise ValueError(f"Unknown observation type: {type(observation)}") + elif isinstance(agent_action, DrawingToolAgentAction): + if isinstance(observation, DrawingToolOutput): + messages.append(AIMessage(content=str(observation))) + else: + raise ValueError(f"Unknown observation type: {type(observation)}") + + elif isinstance(agent_action, WebBrowserAgentAction): + if isinstance(observation, WebBrowserToolOutput): + messages.append(AIMessage(content=str(observation))) + else: + raise ValueError(f"Unknown observation type: {type(observation)}") + elif isinstance(agent_action, ToolAgentAction): new_messages = list(agent_action.message_log) + [ _create_tool_message(agent_action, observation) diff --git a/langchain_zhipuai/agents/output_parsers/base.py b/langchain_zhipuai/agents/output_parsers/base.py new file mode 100644 index 0000000..c505283 --- /dev/null +++ b/langchain_zhipuai/agents/output_parsers/base.py @@ -0,0 +1,15 @@ +from zhipuai.core import BaseModel +from typing import Optional, Dict, Any + + +class AllToolsMessageToolCall(BaseModel): + name: Optional[str] + args: Optional[Dict[str, Any]] + id: Optional[str] + + +class AllToolsMessageToolCallChunk(BaseModel): + name: Optional[str] + args: Optional[Dict[str, Any]] + id: Optional[str] + index: Optional[int] diff --git a/langchain_zhipuai/agents/output_parsers/code_interpreter.py b/langchain_zhipuai/agents/output_parsers/code_interpreter.py new file mode 100644 index 0000000..c18f7c8 --- /dev/null +++ b/langchain_zhipuai/agents/output_parsers/code_interpreter.py @@ -0,0 +1,107 @@ +import json +import logging +from json import JSONDecodeError +from typing import Any, Dict, List, Optional, Union + +from langchain.agents.output_parsers.tools import ToolAgentAction +from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish +from langchain_core.exceptions import OutputParserException +from langchain_core.messages import ( + AIMessage, + BaseMessage, + ToolCall, +) +from langchain_core.utils.json import ( + parse_partial_json, +) +from zhipuai.core import BaseModel + +from langchain_zhipuai.agent_toolkits.all_tools.struct_type import AdapterAllToolStructType +from langchain_zhipuai.agents.output_parsers.base import AllToolsMessageToolCall, AllToolsMessageToolCallChunk +from langchain_zhipuai.chat_models.all_tools_message import ALLToolsMessageChunk + +logger = logging.getLogger(__name__) + + +class CodeInterpreterAgentAction(ToolAgentAction): + outputs: List[Union[str, dict]] = None + """Output of the tool call.""" + platform_params: dict = None + + +def _best_effort_parse_code_interpreter_tool_calls( + tool_call_chunks: List[dict], +) -> List[Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk]]: + code_interpreter_chunk: List[ + Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk] + ] = [] + # Best-effort parsing allready parsed tool calls + for code_interpreter in tool_call_chunks: + if AdapterAllToolStructType.CODE_INTERPRETER == code_interpreter["name"]: + if isinstance(code_interpreter["args"], str): + args_ = parse_partial_json(code_interpreter["args"]) + else: + args_ = code_interpreter["args"] + if not isinstance(args_, dict): + raise ValueError("Malformed args.") + + if "outputs" in args_: + code_interpreter_chunk.append( + AllToolsMessageToolCall( + name=code_interpreter["name"], + args=args_, + id=code_interpreter["id"], + ) + ) + else: + code_interpreter_chunk.append( + AllToolsMessageToolCallChunk( + name=code_interpreter["name"], + args=args_, + id=code_interpreter["id"], + index=code_interpreter.get("index"), + ) + ) + + return code_interpreter_chunk + + +def _paser_code_interpreter_chunk_input( + message: BaseMessage, + code_interpreter_chunk: List[ + Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk] + ], +) -> CodeInterpreterAgentAction: + try: + input_log_chunk = [] + + outputs = [] + for interpreter_chunk in code_interpreter_chunk: + interpreter_chunk_args = interpreter_chunk.args + + if "input" in interpreter_chunk_args: + input_log_chunk.append(interpreter_chunk_args["input"]) + if "outputs" in interpreter_chunk_args: + outputs.extend(interpreter_chunk_args["outputs"]) + + out_logs = [logs["logs"] for logs in outputs if "logs" in logs] + log = f"{''.join(input_log_chunk)}\n{''.join(out_logs)}\n" + tool_call_id = ( + code_interpreter_chunk[0].id if code_interpreter_chunk[0].id else "abc" + ) + code_interpreter_action = CodeInterpreterAgentAction( + tool=AdapterAllToolStructType.CODE_INTERPRETER, + tool_input="".join(input_log_chunk), + outputs=outputs, + log=log, + message_log=[message], + tool_call_id=tool_call_id, + ) + + return code_interpreter_action + except Exception as e: + logger.error(f"Error parsing code_interpreter_chunk: {e}", exc_info=True) + raise OutputParserException( + f"Could not parse tool input: code_interpreter because " + f"the `arguments` is not valid JSON." + ) diff --git a/langchain_zhipuai/agents/output_parsers/drawing_tool.py b/langchain_zhipuai/agents/output_parsers/drawing_tool.py new file mode 100644 index 0000000..0c476cb --- /dev/null +++ b/langchain_zhipuai/agents/output_parsers/drawing_tool.py @@ -0,0 +1,107 @@ +import json +import logging +from json import JSONDecodeError +from typing import Any, Dict, List, Optional, Union + +from langchain.agents.output_parsers.tools import ToolAgentAction +from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish +from langchain_core.exceptions import OutputParserException +from langchain_core.messages import ( + AIMessage, + BaseMessage, + ToolCall, +) +from langchain_core.utils.json import ( + parse_partial_json, +) +from zhipuai.core import BaseModel + +from langchain_zhipuai.agent_toolkits.all_tools.struct_type import AdapterAllToolStructType +from langchain_zhipuai.agents.output_parsers.base import AllToolsMessageToolCall, AllToolsMessageToolCallChunk +from langchain_zhipuai.chat_models.all_tools_message import ALLToolsMessageChunk + +logger = logging.getLogger(__name__) + + +class DrawingToolAgentAction(ToolAgentAction): + outputs: List[Union[str, dict]] = None + """Output of the tool call.""" + platform_params: dict = None + + +def _best_effort_parse_drawing_tool_tool_calls( + tool_call_chunks: List[dict], +) -> List[Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk]]: + drawing_tool_chunk: List[ + Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk] + ] = [] + # Best-effort parsing allready parsed tool calls + for drawing_tool in tool_call_chunks: + if AdapterAllToolStructType.DRAWING_TOOL == drawing_tool["name"]: + if isinstance(drawing_tool["args"], str): + args_ = parse_partial_json(drawing_tool["args"]) + else: + args_ = drawing_tool["args"] + if not isinstance(args_, dict): + raise ValueError("Malformed args.") + + if "outputs" in args_: + drawing_tool_chunk.append( + AllToolsMessageToolCall( + name=drawing_tool["name"], + args=args_, + id=drawing_tool["id"], + ) + ) + else: + drawing_tool_chunk.append( + AllToolsMessageToolCallChunk( + name=drawing_tool["name"], + args=args_, + id=drawing_tool["id"], + index=drawing_tool.get("index"), + ) + ) + + return drawing_tool_chunk + + +def _paser_drawing_tool_chunk_input( + message: BaseMessage, + drawing_tool_chunk: List[ + Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk] + ], +) -> DrawingToolAgentAction: + try: + input_log_chunk = [] + + outputs = [] + for interpreter_chunk in drawing_tool_chunk: + interpreter_chunk_args = interpreter_chunk.args + + if "input" in interpreter_chunk_args: + input_log_chunk.append(interpreter_chunk_args["input"]) + if "outputs" in interpreter_chunk_args: + outputs.extend(interpreter_chunk_args["outputs"]) + + out_logs = [logs["image"] for logs in outputs if "image" in logs] + log = f"{''.join(input_log_chunk)}\n{''.join(out_logs)}\n" + tool_call_id = ( + drawing_tool_chunk[0].id if drawing_tool_chunk[0].id else "abc" + ) + drawing_tool_action = DrawingToolAgentAction( + tool=AdapterAllToolStructType.DRAWING_TOOL, + tool_input="".join(input_log_chunk), + outputs=outputs, + log=log, + message_log=[message], + tool_call_id=tool_call_id, + ) + + return drawing_tool_action + except Exception as e: + logger.error(f"Error parsing drawing_tool_chunk: {e}", exc_info=True) + raise OutputParserException( + f"Could not parse tool input: drawing_tool because " + f"the `arguments` is not valid JSON." + ) diff --git a/langchain_zhipuai/agents/output_parsers/tools.py b/langchain_zhipuai/agents/output_parsers/tools.py index b9de4b4..0030fde 100644 --- a/langchain_zhipuai/agents/output_parsers/tools.py +++ b/langchain_zhipuai/agents/output_parsers/tools.py @@ -16,32 +16,21 @@ ) from zhipuai.core import BaseModel +from langchain_zhipuai.agent_toolkits.all_tools.struct_type import AdapterAllToolStructType +from langchain_zhipuai.agents.output_parsers.base import AllToolsMessageToolCallChunk, AllToolsMessageToolCall +from langchain_zhipuai.agents.output_parsers.code_interpreter import _best_effort_parse_code_interpreter_tool_calls, \ + _paser_code_interpreter_chunk_input +from langchain_zhipuai.agents.output_parsers.drawing_tool import _paser_drawing_tool_chunk_input, \ + _best_effort_parse_drawing_tool_tool_calls +from langchain_zhipuai.agents.output_parsers.web_browser import _paser_web_browser_chunk_input, \ + _best_effort_parse_web_browser_tool_calls from langchain_zhipuai.chat_models.all_tools_message import ALLToolsMessageChunk logger = logging.getLogger(__name__) -class CodeInterpreterAgentAction(ToolAgentAction): - outputs: List[Union[str, dict]] = None - """Output of the tool call.""" - platform_params: dict = None - - -class AllToolsMessageToolCall(BaseModel): - name: Optional[str] - args: Optional[Dict[str, Any]] - id: Optional[str] - - -class AllToolsMessageToolCallChunk(BaseModel): - name: Optional[str] - args: Optional[Dict[str, Any]] - id: Optional[str] - index: Optional[int] - - def parse_ai_message_to_tool_action( - message: BaseMessage, + message: BaseMessage, ) -> Union[List[AgentAction], AgentFinish]: """Parse an AI message potentially containing tool_calls.""" if not isinstance(message, AIMessage): @@ -79,13 +68,13 @@ def parse_ai_message_to_tool_action( f"Could not parse tool input: {function} because " f"the `arguments` is not valid JSON." ) - elif "code_interpreter" == tool_call["type"]: - code_interpreter = tool_call["code_interpreter"] + elif tool_call["type"] in AdapterAllToolStructType.__members__.values(): + adapter_tool = tool_call[tool_call["type"]] tool_calls.append( ToolCall( - name="code_interpreter", - args=code_interpreter, + name=tool_call["type"], + args=adapter_tool if adapter_tool else {}, id=tool_call["id"] if tool_call["id"] else "abc", ) ) @@ -108,10 +97,46 @@ def parse_ai_message_to_tool_action( _paser_code_interpreter_chunk_input(message, code_interpreter_chunk) ) + drawing_tool_chunk: List[ + Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk] + ] = [] + if message.tool_calls: + if isinstance(message, ALLToolsMessageChunk): + drawing_tool_chunk = _best_effort_parse_drawing_tool_tool_calls( + message.tool_call_chunks + ) + else: + drawing_tool_chunk = _best_effort_parse_drawing_tool_tool_calls( + tool_calls + ) + + if drawing_tool_chunk and len(drawing_tool_chunk) > 1: + actions.append( + _paser_drawing_tool_chunk_input(message, drawing_tool_chunk) + ) + + web_browser_chunk: List[ + Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk] + ] = [] + if message.tool_calls: + if isinstance(message, ALLToolsMessageChunk): + web_browser_chunk = _best_effort_parse_web_browser_tool_calls( + message.tool_call_chunks + ) + else: + web_browser_chunk = _best_effort_parse_web_browser_tool_calls( + tool_calls + ) + + if web_browser_chunk and len(web_browser_chunk) > 1: + actions.append( + _paser_web_browser_chunk_input(message, web_browser_chunk) + ) + # TODO: parse platform tools built-in @langchain_zhipuai - # delete code_interpreter_chunk + # delete AdapterAllToolStructType from tool_calls tool_calls = [ - tool_call for tool_call in tool_calls if "code_interpreter" != tool_call["name"] + tool_call for tool_call in tool_calls if tool_call["name"] not in AdapterAllToolStructType.__members__.values() ] for tool_call in tool_calls: @@ -141,81 +166,3 @@ def parse_ai_message_to_tool_action( ) ) return actions - - -def _best_effort_parse_code_interpreter_tool_calls( - tool_call_chunks: List[dict], -) -> List[Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk]]: - code_interpreter_chunk: List[ - Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk] - ] = [] - # Best-effort parsing allready parsed tool calls - for code_interpreter in tool_call_chunks: - if "code_interpreter" == code_interpreter["name"]: - if isinstance(code_interpreter["args"], str): - args_ = parse_partial_json(code_interpreter["args"]) - else: - args_ = code_interpreter["args"] - if not isinstance(args_, dict): - raise ValueError("Malformed args.") - - if "outputs" in args_: - code_interpreter_chunk.append( - AllToolsMessageToolCall( - name=code_interpreter["name"], - args=args_, - id=code_interpreter["id"], - ) - ) - else: - code_interpreter_chunk.append( - AllToolsMessageToolCallChunk( - name=code_interpreter["name"], - args=args_, - id=code_interpreter["id"], - index=code_interpreter.get("index"), - ) - ) - - return code_interpreter_chunk - - -def _paser_code_interpreter_chunk_input( - message: BaseMessage, - code_interpreter_chunk: List[ - Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk] - ], -) -> CodeInterpreterAgentAction: - try: - input_log_chunk = [] - - outputs = [] - for interpreter_chunk in code_interpreter_chunk: - interpreter_chunk_args = interpreter_chunk.args - - if "input" in interpreter_chunk_args: - input_log_chunk.append(interpreter_chunk_args["input"]) - if "outputs" in interpreter_chunk_args: - outputs.extend(interpreter_chunk_args["outputs"]) - - out_logs = [logs["logs"] for logs in outputs if "logs" in logs] - log = f"{''.join(input_log_chunk)}\n{''.join(out_logs)}\n" - tool_call_id = ( - code_interpreter_chunk[0].id if code_interpreter_chunk[0].id else "abc" - ) - code_interpreter_action = CodeInterpreterAgentAction( - tool="code_interpreter", - tool_input="".join(input_log_chunk), - outputs=outputs, - log=log, - message_log=[message], - tool_call_id=tool_call_id, - ) - - return code_interpreter_action - except Exception as e: - logger.error(f"Error parsing code_interpreter_chunk: {e}", exc_info=True) - raise OutputParserException( - f"Could not parse tool input: code_interpreter because " - f"the `arguments` is not valid JSON." - ) diff --git a/langchain_zhipuai/agents/output_parsers/web_browser.py b/langchain_zhipuai/agents/output_parsers/web_browser.py new file mode 100644 index 0000000..ceb8279 --- /dev/null +++ b/langchain_zhipuai/agents/output_parsers/web_browser.py @@ -0,0 +1,107 @@ +import json +import logging +from json import JSONDecodeError +from typing import Any, Dict, List, Optional, Union + +from langchain.agents.output_parsers.tools import ToolAgentAction +from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish +from langchain_core.exceptions import OutputParserException +from langchain_core.messages import ( + AIMessage, + BaseMessage, + ToolCall, +) +from langchain_core.utils.json import ( + parse_partial_json, +) +from zhipuai.core import BaseModel + +from langchain_zhipuai.agent_toolkits.all_tools.struct_type import AdapterAllToolStructType +from langchain_zhipuai.agents.output_parsers.base import AllToolsMessageToolCall, AllToolsMessageToolCallChunk +from langchain_zhipuai.chat_models.all_tools_message import ALLToolsMessageChunk + +logger = logging.getLogger(__name__) + + +class WebBrowserAgentAction(ToolAgentAction): + outputs: List[Union[str, dict]] = None + """Output of the tool call.""" + platform_params: dict = None + + +def _best_effort_parse_web_browser_tool_calls( + tool_call_chunks: List[dict], +) -> List[Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk]]: + web_browser_chunk: List[ + Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk] + ] = [] + # Best-effort parsing allready parsed tool calls + for web_browser in tool_call_chunks: + if AdapterAllToolStructType.WEB_BROWSER == web_browser["name"]: + if isinstance(web_browser["args"], str): + args_ = parse_partial_json(web_browser["args"]) + else: + args_ = web_browser["args"] + if not isinstance(args_, dict): + raise ValueError("Malformed args.") + + if "outputs" in args_: + web_browser_chunk.append( + AllToolsMessageToolCall( + name=web_browser["name"], + args=args_, + id=web_browser["id"], + ) + ) + else: + web_browser_chunk.append( + AllToolsMessageToolCallChunk( + name=web_browser["name"], + args=args_, + id=web_browser["id"], + index=web_browser.get("index"), + ) + ) + + return web_browser_chunk + + +def _paser_web_browser_chunk_input( + message: BaseMessage, + web_browser_chunk: List[ + Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk] + ], +) -> WebBrowserAgentAction: + try: + input_log_chunk = [] + + outputs = [] + for interpreter_chunk in web_browser_chunk: + interpreter_chunk_args = interpreter_chunk.args + + if "input" in interpreter_chunk_args: + input_log_chunk.append(interpreter_chunk_args["input"]) + if "outputs" in interpreter_chunk_args: + outputs.extend(interpreter_chunk_args["outputs"]) + + out_logs = [logs["logs"] for logs in outputs if "logs" in logs] + log = f"{''.join(input_log_chunk)}\n{''.join(out_logs)}\n" + tool_call_id = ( + web_browser_chunk[0].id if web_browser_chunk[0].id else "abc" + ) + web_browser_action = WebBrowserAgentAction( + tool=AdapterAllToolStructType.DRAWING_TOOL, + tool_input="".join(input_log_chunk), + outputs=outputs, + log=log, + message_log=[message], + tool_call_id=tool_call_id, + ) + + return web_browser_action + except Exception as e: + logger.error(f"Error parsing web_browser_chunk: {e}", exc_info=True) + raise OutputParserException( + f"Could not parse tool input: web_browser because " + f"the `arguments` is not valid JSON." + ) diff --git a/langchain_zhipuai/agents/output_parsers/zhipuai_all_tools.py b/langchain_zhipuai/agents/output_parsers/zhipuai_all_tools.py index 96c1dd4..953f2c4 100644 --- a/langchain_zhipuai/agents/output_parsers/zhipuai_all_tools.py +++ b/langchain_zhipuai/agents/output_parsers/zhipuai_all_tools.py @@ -5,11 +5,13 @@ from langchain_core.messages import BaseMessage from langchain_core.outputs import ChatGeneration, Generation +from langchain_zhipuai.agents.output_parsers.code_interpreter import CodeInterpreterAgentAction +from langchain_zhipuai.agents.output_parsers.drawing_tool import DrawingToolAgentAction from langchain_zhipuai.agents.output_parsers.tools import ( - CodeInterpreterAgentAction, ToolAgentAction, parse_ai_message_to_tool_action, ) +from langchain_zhipuai.agents.output_parsers.web_browser import WebBrowserAgentAction ZhipuAiALLToolAgentAction = ToolAgentAction @@ -25,6 +27,10 @@ def parse_ai_message_to_zhipuai_all_tool_action( for action in tool_actions: if isinstance(action, CodeInterpreterAgentAction): final_actions.append(action) + elif isinstance(action, DrawingToolAgentAction): + final_actions.append(action) + elif isinstance(action, WebBrowserAgentAction): + final_actions.append(action) elif isinstance(action, ToolAgentAction): final_actions.append( ZhipuAiALLToolAgentAction( diff --git a/langchain_zhipuai/agents/zhipuai_all_tools/base.py b/langchain_zhipuai/agents/zhipuai_all_tools/base.py index 5ce6f38..8ca9c7a 100644 --- a/langchain_zhipuai/agents/zhipuai_all_tools/base.py +++ b/langchain_zhipuai/agents/zhipuai_all_tools/base.py @@ -30,6 +30,7 @@ from langchain_zhipuai.agent_toolkits.all_tools.registry import ( TOOL_STRUCT_TYPE_TO_TOOL_CLASS, ) +from langchain_zhipuai.agent_toolkits.all_tools.struct_type import AdapterAllToolStructType from langchain_zhipuai.agent_toolkits.all_tools.tool import ( AdapterAllTool, BaseToolOutput, @@ -61,7 +62,7 @@ def _is_assistants_builtin_tool( tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], ) -> bool: """platform tools built-in""" - assistants_builtin_tools = ("code_interpreter",) + assistants_builtin_tools = AdapterAllToolStructType.__members__.values() return ( isinstance(tool, dict) and ("type" in tool) @@ -73,8 +74,6 @@ def _get_assistants_tool( tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], ) -> Dict[str, Any]: """Convert a raw function/class to an ZhipuAI tool. - - such as "code_interpreter" """ if _is_assistants_builtin_tool(tool): return tool # type: ignore @@ -220,7 +219,7 @@ def create_agent_executor( for t in tools: # TODO: platform tools built-in for all tools, # load with langchain_zhipuai/agents/all_tools_agent.py:108 - # AdapterAllTool implements code_interpreter + # AdapterAllTool implements it if _is_assistants_builtin_tool(t): assistants_builtin_tools.append(cls.paser_all_tools(t, callbacks)) temp_tools.extend(assistants_builtin_tools) diff --git a/langchain_zhipuai/chat_models/all_tools_message.py b/langchain_zhipuai/chat_models/all_tools_message.py index 48ad076..8fc49f3 100644 --- a/langchain_zhipuai/chat_models/all_tools_message.py +++ b/langchain_zhipuai/chat_models/all_tools_message.py @@ -36,6 +36,22 @@ def default_all_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallCh tool_call["code_interpreter"], ensure_ascii=False ) function_name = "code_interpreter" + elif ( + "drawing_tool" in tool_call + and tool_call["drawing_tool"] is not None + ): + function_args = json.dumps( + tool_call["drawing_tool"], ensure_ascii=False + ) + function_name = "drawing_tool" + elif ( + "web_browser" in tool_call + and tool_call["web_browser"] is not None + ): + function_args = json.dumps( + tool_call["web_browser"], ensure_ascii=False + ) + function_name = "web_browser" else: function_args = None function_name = None @@ -108,6 +124,54 @@ def init_tool_calls(cls, values: dict) -> dict: if "code_interpreter" in chunk["name"]: args_ = parse_partial_json(chunk["args"]) + if not isinstance(args_, dict): + raise ValueError("Malformed args.") + + if "outputs" in args_: + tool_calls.append( + ToolCall( + name=chunk["name"] or "", + args=args_, + id=chunk["id"], + ) + ) + + else: + invalid_tool_calls.append( + InvalidToolCall( + name=chunk["name"], + args=chunk["args"], + id=chunk["id"], + error=None, + ) + ) + elif "drawing_tool" in chunk["name"]: + args_ = parse_partial_json(chunk["args"]) + + if not isinstance(args_, dict): + raise ValueError("Malformed args.") + + if "outputs" in args_: + tool_calls.append( + ToolCall( + name=chunk["name"] or "", + args=args_, + id=chunk["id"], + ) + ) + + else: + invalid_tool_calls.append( + InvalidToolCall( + name=chunk["name"], + args=chunk["args"], + id=chunk["id"], + error=None, + ) + ) + elif "web_browser" in chunk["name"]: + args_ = parse_partial_json(chunk["args"]) + if not isinstance(args_, dict): raise ValueError("Malformed args.") diff --git a/tests/integration_tests/all_tools/test_alltools.py b/tests/integration_tests/all_tools/test_alltools.py index eeeb701..bf49829 100644 --- a/tests/integration_tests/all_tools/test_alltools.py +++ b/tests/integration_tests/all_tools/test_alltools.py @@ -120,3 +120,31 @@ async def test_all_tools_code_interpreter_sandbox_none(logging_conf): elif isinstance(item, AllToolsLLMStatus): if item.status == AgentStatus.llm_end: print("llm_end:" + item.text) + + +@pytest.mark.asyncio +async def test_all_tools_drawing_tool(logging_conf): + logging.config.dictConfig(logging_conf) # type: ignore + + agent_executor = ZhipuAIAllToolsRunnable.create_agent_executor( + model_name="glm-4-alltools", + tools=[{"type": "drawing_tool"}], + ) + chat_iterator = agent_executor.invoke( + chat_input="给我画一张猫咪的图片,要是波斯猫" + ) + async for item in chat_iterator: + if isinstance(item, AllToolsAction): + print("AllToolsAction:" + str(item.to_json())) + + elif isinstance(item, AllToolsFinish): + print("AllToolsFinish:" + str(item.to_json())) + + elif isinstance(item, AllToolsActionToolStart): + print("AllToolsActionToolStart:" + str(item.to_json())) + + elif isinstance(item, AllToolsActionToolEnd): + print("AllToolsActionToolEnd:" + str(item.to_json())) + elif isinstance(item, AllToolsLLMStatus): + if item.status == AgentStatus.llm_end: + print("llm_end:" + item.text)