Skip to content

Commit

Permalink
supper drawing_tool \web_browser
Browse files Browse the repository at this point in the history
  • Loading branch information
glide-the committed Jun 25, 2024
1 parent 95ee1cd commit e44290e
Show file tree
Hide file tree
Showing 15 changed files with 749 additions and 117 deletions.
89 changes: 89 additions & 0 deletions langchain_zhipuai/agent_toolkits/all_tools/drawing_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from langchain_core.agents import AgentAction
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)

from langchain_zhipuai.agent_toolkits import AdapterAllTool
from langchain_zhipuai.agent_toolkits.all_tools.struct_type import AdapterAllToolStructType
from langchain_zhipuai.agent_toolkits.all_tools.tool import (
AllToolExecutor,
BaseToolOutput,
)

logger = logging.getLogger(__name__)


class DrawingToolOutput(BaseToolOutput):
platform_params: Dict[str, Any]

def __init__(
self,
data: Any,
platform_params: Dict[str, Any],
**extras: Any,
) -> None:
super().__init__(data, "", "", **extras)
self.platform_params = platform_params


@dataclass
class DrawingAllToolExecutor(AllToolExecutor):
"""platform adapter tool for code interpreter tool"""

name: str

def run(
self,
tool: str,
tool_input: str,
log: str,
outputs: List[Union[str, dict]] = None,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> DrawingToolOutput:
if outputs is None or str(outputs).strip() == "":
raise ValueError(
f"Tool {self.name} is server error"
)

return DrawingToolOutput(
data=f"""Access:{tool}, Message: {tool_input},{log}""",
platform_params=self.platform_params,
)

async def arun(
self,
tool: str,
tool_input: str,
log: str,
outputs: List[Union[str, dict]] = None,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> DrawingToolOutput:
"""Use the tool asynchronously."""
if outputs is None or str(outputs).strip() == "" or len(outputs) == 0:
raise ValueError(
f"Tool {self.name} is server error"
)

return DrawingToolOutput(
data=f"""Access:{tool}, Message: {tool_input},{log}""",
platform_params=self.platform_params,
)


class DrawingAdapterAllTool(AdapterAllTool[DrawingAllToolExecutor]):
@classmethod
def get_type(cls) -> str:
return "DrawingAdapterAllTool"

def _build_adapter_all_tool(
self, platform_params: Dict[str, Any]
) -> DrawingAllToolExecutor:
return DrawingAllToolExecutor(
name=AdapterAllToolStructType.DRAWING_TOOL, platform_params=platform_params
)
4 changes: 4 additions & 0 deletions langchain_zhipuai/agent_toolkits/all_tools/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
from langchain_zhipuai.agent_toolkits.all_tools.code_interpreter_tool import (
CodeInterpreterAdapterAllTool,
)
from langchain_zhipuai.agent_toolkits.all_tools.drawing_tool import DrawingAdapterAllTool
from langchain_zhipuai.agent_toolkits.all_tools.struct_type import (
AdapterAllToolStructType,
)
from langchain_zhipuai.agent_toolkits.all_tools.web_browser_tool import WebBrowserAdapterAllTool

TOOL_STRUCT_TYPE_TO_TOOL_CLASS: Dict[AdapterAllToolStructType, Type[AdapterAllTool]] = {
AdapterAllToolStructType.CODE_INTERPRETER: CodeInterpreterAdapterAllTool,
AdapterAllToolStructType.DRAWING_TOOL: DrawingAdapterAllTool,
AdapterAllToolStructType.WEB_BROWSER: WebBrowserAdapterAllTool,
}
2 changes: 2 additions & 0 deletions langchain_zhipuai/agent_toolkits/all_tools/struct_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ class AdapterAllToolStructType(str, Enum):
# TODO: refactor so these are properties on the base class

CODE_INTERPRETER = "code_interpreter"
DRAWING_TOOL = "drawing_tool"
WEB_BROWSER = "web_browser"
66 changes: 59 additions & 7 deletions langchain_zhipuai/agent_toolkits/all_tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
)
from langchain_core.tools import BaseTool

from langchain_zhipuai.agents.output_parsers.tools import CodeInterpreterAgentAction
from langchain_zhipuai.agent_toolkits.all_tools.struct_type import AdapterAllToolStructType
from langchain_zhipuai.agents.output_parsers.code_interpreter import CodeInterpreterAgentAction
from langchain_zhipuai.agents.output_parsers.drawing_tool import DrawingToolAgentAction
from langchain_zhipuai.agents.output_parsers.web_browser import WebBrowserAgentAction

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -142,16 +145,40 @@ def _run(
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
**tool_run_kwargs: Any,
) -> Any:
if "code_interpreter" in agent_action.tool:
if AdapterAllToolStructType.CODE_INTERPRETER == agent_action.tool and isinstance(
agent_action, CodeInterpreterAgentAction
):
return self.adapter_all_tool.run(
**{
"tool": agent_action.tool,
"tool_input": agent_action.tool_input,
"log": agent_action.log,
"outputs": agent_action.outputs,
},
**tool_run_kwargs,
)
elif AdapterAllToolStructType.DRAWING_TOOL == agent_action.tool and isinstance(
agent_action, DrawingToolAgentAction
):
return self.adapter_all_tool.run(
{
**{
"tool": agent_action.tool,
"tool_input": agent_action.tool_input,
"log": agent_action.log,
"outputs": agent_action.outputs,
},
**tool_run_kwargs,
)
elif AdapterAllToolStructType.WEB_BROWSER == agent_action.tool and isinstance(
agent_action, WebBrowserAgentAction
):
return self.adapter_all_tool.run(
**{
"tool": agent_action.tool,
"tool_input": agent_action.tool_input,
"log": agent_action.log,
"outputs": agent_action.outputs,
},
verbose=self.verbose,
color="red",
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
else:
Expand All @@ -163,7 +190,7 @@ async def _arun(
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
**tool_run_kwargs: Any,
) -> Any:
if "code_interpreter" in agent_action.tool and isinstance(
if AdapterAllToolStructType.CODE_INTERPRETER == agent_action.tool and isinstance(
agent_action, CodeInterpreterAgentAction
):
return await self.adapter_all_tool.arun(
Expand All @@ -175,5 +202,30 @@ async def _arun(
},
**tool_run_kwargs,
)

elif AdapterAllToolStructType.DRAWING_TOOL == agent_action.tool and isinstance(
agent_action, DrawingToolAgentAction
):
return await self.adapter_all_tool.arun(
**{
"tool": agent_action.tool,
"tool_input": agent_action.tool_input,
"log": agent_action.log,
"outputs": agent_action.outputs,
},
**tool_run_kwargs,
)
elif AdapterAllToolStructType.WEB_BROWSER == agent_action.tool and isinstance(
agent_action, WebBrowserAgentAction
):
return await self.adapter_all_tool.arun(
**{
"tool": agent_action.tool,
"tool_input": agent_action.tool_input,
"log": agent_action.log,
"outputs": agent_action.outputs,
},
**tool_run_kwargs,
)
else:
raise KeyError()
89 changes: 89 additions & 0 deletions langchain_zhipuai/agent_toolkits/all_tools/web_browser_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from langchain_core.agents import AgentAction
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)

from langchain_zhipuai.agent_toolkits import AdapterAllTool
from langchain_zhipuai.agent_toolkits.all_tools.struct_type import AdapterAllToolStructType
from langchain_zhipuai.agent_toolkits.all_tools.tool import (
AllToolExecutor,
BaseToolOutput,
)

logger = logging.getLogger(__name__)


class WebBrowserToolOutput(BaseToolOutput):
platform_params: Dict[str, Any]

def __init__(
self,
data: Any,
platform_params: Dict[str, Any],
**extras: Any,
) -> None:
super().__init__(data, "", "", **extras)
self.platform_params = platform_params


@dataclass
class WebBrowserAllToolExecutor(AllToolExecutor):
"""platform adapter tool for code interpreter tool"""

name: str

def run(
self,
tool: str,
tool_input: str,
log: str,
outputs: List[Union[str, dict]] = None,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> WebBrowserToolOutput:
if outputs is None or str(outputs).strip() == "":
raise ValueError(
f"Tool {self.name} is server error"
)

return WebBrowserToolOutput(
data=f"""Access:{tool}, Message: {tool_input},{log}""",
platform_params=self.platform_params,
)

async def arun(
self,
tool: str,
tool_input: str,
log: str,
outputs: List[Union[str, dict]] = None,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> WebBrowserToolOutput:
"""Use the tool asynchronously."""
if outputs is None or str(outputs).strip() == "" or len(outputs) == 0:
raise ValueError(
f"Tool {self.name} is server error"
)

return WebBrowserToolOutput(
data=f"""Access:{tool}, Message: {tool_input},{log}""",
platform_params=self.platform_params,
)


class WebBrowserAdapterAllTool(AdapterAllTool[WebBrowserAllToolExecutor]):
@classmethod
def get_type(cls) -> str:
return "WebBrowserAdapterAllTool"

def _build_adapter_all_tool(
self, platform_params: Dict[str, Any]
) -> WebBrowserAllToolExecutor:
return WebBrowserAllToolExecutor(
name=AdapterAllToolStructType.WEB_BROWSER, platform_params=platform_params
)
18 changes: 17 additions & 1 deletion langchain_zhipuai/agents/format_scratchpad/all_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
from langchain_zhipuai.agent_toolkits.all_tools.code_interpreter_tool import (
CodeInterpreterToolOutput,
)
from langchain_zhipuai.agents.output_parsers.tools import CodeInterpreterAgentAction
from langchain_zhipuai.agent_toolkits.all_tools.drawing_tool import DrawingToolOutput
from langchain_zhipuai.agent_toolkits.all_tools.web_browser_tool import WebBrowserToolOutput
from langchain_zhipuai.agents.output_parsers.code_interpreter import CodeInterpreterAgentAction
from langchain_zhipuai.agents.output_parsers.drawing_tool import DrawingToolAgentAction
from langchain_zhipuai.agents.output_parsers.web_browser import WebBrowserAgentAction


def _create_tool_message(
Expand Down Expand Up @@ -67,6 +71,18 @@ def format_to_zhipuai_all_tool_messages(
else:
raise ValueError(f"Unknown observation type: {type(observation)}")

elif isinstance(agent_action, DrawingToolAgentAction):
if isinstance(observation, DrawingToolOutput):
messages.append(AIMessage(content=str(observation)))
else:
raise ValueError(f"Unknown observation type: {type(observation)}")

elif isinstance(agent_action, WebBrowserAgentAction):
if isinstance(observation, WebBrowserToolOutput):
messages.append(AIMessage(content=str(observation)))
else:
raise ValueError(f"Unknown observation type: {type(observation)}")

elif isinstance(agent_action, ToolAgentAction):
new_messages = list(agent_action.message_log) + [
_create_tool_message(agent_action, observation)
Expand Down
15 changes: 15 additions & 0 deletions langchain_zhipuai/agents/output_parsers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from zhipuai.core import BaseModel
from typing import Optional, Dict, Any


class AllToolsMessageToolCall(BaseModel):
name: Optional[str]
args: Optional[Dict[str, Any]]
id: Optional[str]


class AllToolsMessageToolCallChunk(BaseModel):
name: Optional[str]
args: Optional[Dict[str, Any]]
id: Optional[str]
index: Optional[int]
Loading

0 comments on commit e44290e

Please sign in to comment.