From b88b77bf1920dfc4ab6fa224de1c40352d60a63a Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Tue, 25 Jun 2024 20:12:10 +0800 Subject: [PATCH] =?UTF-8?q?=E5=90=8D=E7=A7=B0=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- langchain_zhipuai/agents/all_tools_agent.py | 2 +- .../agents/format_scratchpad/all_tools.py | 10 ++++- langchain_zhipuai/embeddings/base.py | 2 +- .../all_tools/test_alltools.py | 20 +--------- .../embeddings/test_embeddings.py | 38 +++++++++---------- 5 files changed, 30 insertions(+), 42 deletions(-) diff --git a/langchain_zhipuai/agents/all_tools_agent.py b/langchain_zhipuai/agents/all_tools_agent.py index cf22ccf..8efa112 100644 --- a/langchain_zhipuai/agents/all_tools_agent.py +++ b/langchain_zhipuai/agents/all_tools_agent.py @@ -56,7 +56,7 @@ def _perform_agent_action( # We then call the tool on the tool input to get an observation # TODO: platform adapter tool for all tools, # view tools binding langchain_zhipuai/agents/zhipuai_all_tools/base.py:188 - if "code_interpreter" in agent_action.tool: + if agent_action.tool in AdapterAllToolStructType.__members__.values(): observation = tool.run( { "agent_action": agent_action, diff --git a/langchain_zhipuai/agents/format_scratchpad/all_tools.py b/langchain_zhipuai/agents/format_scratchpad/all_tools.py index 5fe5c3c..6015065 100644 --- a/langchain_zhipuai/agents/format_scratchpad/all_tools.py +++ b/langchain_zhipuai/agents/format_scratchpad/all_tools.py @@ -73,13 +73,19 @@ def format_to_zhipuai_all_tool_messages( elif isinstance(agent_action, DrawingToolAgentAction): if isinstance(observation, DrawingToolOutput): - messages.append(AIMessage(content=str(observation))) + new_messages = list(agent_action.message_log) + [ + _create_tool_message(agent_action, observation) + ] + messages.extend([new for new in new_messages if new not in messages]) else: raise ValueError(f"Unknown observation type: {type(observation)}") elif isinstance(agent_action, WebBrowserAgentAction): if isinstance(observation, WebBrowserToolOutput): - messages.append(AIMessage(content=str(observation))) + new_messages = list(agent_action.message_log) + [ + _create_tool_message(agent_action, observation) + ] + messages.extend([new for new in new_messages if new not in messages]) else: raise ValueError(f"Unknown observation type: {type(observation)}") diff --git a/langchain_zhipuai/embeddings/base.py b/langchain_zhipuai/embeddings/base.py index b388f64..e856af1 100644 --- a/langchain_zhipuai/embeddings/base.py +++ b/langchain_zhipuai/embeddings/base.py @@ -38,7 +38,7 @@ logger = logging.getLogger(__name__) -class ZhipuAIAIEmbeddings(BaseModel, Embeddings): +class ZhipuAIEmbeddings(BaseModel, Embeddings): """ZhipuAI embedding models. To use, you should have the diff --git a/tests/integration_tests/all_tools/test_alltools.py b/tests/integration_tests/all_tools/test_alltools.py index bf49829..e6830bf 100644 --- a/tests/integration_tests/all_tools/test_alltools.py +++ b/tests/integration_tests/all_tools/test_alltools.py @@ -35,7 +35,7 @@ async def test_all_tools_code_interpreter(logging_conf): agent_executor = ZhipuAIAllToolsRunnable.create_agent_executor( model_name="glm-4-alltools", - tools=[{"type": "code_interpreter"}, shell], + tools=[shell], ) chat_iterator = agent_executor.invoke( chat_input="看下本地文件有哪些,告诉我你用的是什么文件,查看当前目录" @@ -56,24 +56,6 @@ async def test_all_tools_code_interpreter(logging_conf): if item.status == AgentStatus.llm_end: print("llm_end:" + item.text) - chat_iterator = agent_executor.invoke(chat_input="打印下test_alltools.py") - async for item in chat_iterator: - if isinstance(item, AllToolsAction): - print("AllToolsAction:" + str(item.to_json())) - - elif isinstance(item, AllToolsFinish): - print("AllToolsFinish:" + str(item.to_json())) - - elif isinstance(item, AllToolsActionToolStart): - print("AllToolsActionToolStart:" + str(item.to_json())) - - elif isinstance(item, AllToolsActionToolEnd): - print("AllToolsActionToolEnd:" + str(item.to_json())) - elif isinstance(item, AllToolsLLMStatus): - if item.status == AgentStatus.llm_end: - print("llm_end:" + item.text) - - @pytest.mark.asyncio async def test_all_tools_code_interpreter_sandbox_none(logging_conf): logging.config.dictConfig(logging_conf) # type: ignore diff --git a/tests/integration_tests/embeddings/test_embeddings.py b/tests/integration_tests/embeddings/test_embeddings.py index f670890..c2ca51b 100644 --- a/tests/integration_tests/embeddings/test_embeddings.py +++ b/tests/integration_tests/embeddings/test_embeddings.py @@ -1,25 +1,25 @@ -"""Test openai embeddings.""" +"""Test zhipuai embeddings.""" import numpy as np import pytest -from langchain_zhipuai.embeddings.base import ZhipuAIAIEmbeddings +from langchain_zhipuai.embeddings.base import ZhipuAIEmbeddings @pytest.mark.scheduled -def test_openai_embedding_documents() -> None: - """Test openai embeddings.""" +def test_zhipuai_embedding_documents() -> None: + """Test zhipuai embeddings.""" documents = ["foo bar"] - embedding = ZhipuAIAIEmbeddings() + embedding = ZhipuAIEmbeddings() output = embedding.embed_documents(documents) assert len(output) == 1 assert len(output[0]) == 1024 @pytest.mark.scheduled -def test_openai_embedding_documents_multiple() -> None: - """Test openai embeddings.""" +def test_zhipuai_embedding_documents_multiple() -> None: + """Test zhipuai embeddings.""" documents = ["foo bar", "bar foo", "foo"] - embedding = ZhipuAIAIEmbeddings(chunk_size=2) + embedding = ZhipuAIEmbeddings(chunk_size=2) embedding.embedding_ctx_length = 8191 output = embedding.embed_documents(documents) assert len(output) == 3 @@ -29,10 +29,10 @@ def test_openai_embedding_documents_multiple() -> None: @pytest.mark.scheduled -async def test_openai_embedding_documents_async_multiple() -> None: - """Test openai embeddings.""" +async def test_zhipuai_embedding_documents_async_multiple() -> None: + """Test zhipuai embeddings.""" documents = ["foo bar", "bar foo", "foo"] - embedding = ZhipuAIAIEmbeddings(chunk_size=2) + embedding = ZhipuAIEmbeddings(chunk_size=2) embedding.embedding_ctx_length = 8191 output = await embedding.aembed_documents(documents) assert len(output) == 3 @@ -42,30 +42,30 @@ async def test_openai_embedding_documents_async_multiple() -> None: @pytest.mark.scheduled -def test_openai_embedding_query() -> None: - """Test openai embeddings.""" +def test_zhipuai_embedding_query() -> None: + """Test zhipuai embeddings.""" document = "foo bar" - embedding = ZhipuAIAIEmbeddings() + embedding = ZhipuAIEmbeddings() output = embedding.embed_query(document) assert len(output) == 1024 @pytest.mark.scheduled -async def test_openai_embedding_async_query() -> None: - """Test openai embeddings.""" +async def test_zhipuai_embedding_async_query() -> None: + """Test zhipuai embeddings.""" document = "foo bar" - embedding = ZhipuAIAIEmbeddings() + embedding = ZhipuAIEmbeddings() output = await embedding.aembed_query(document) assert len(output) == 1024 @pytest.mark.scheduled def test_embed_documents_normalized() -> None: - output = ZhipuAIAIEmbeddings().embed_documents(["foo walked to the market"]) + output = ZhipuAIEmbeddings().embed_documents(["foo walked to the market"]) assert np.isclose(np.linalg.norm(output[0]), 1.0) @pytest.mark.scheduled def test_embed_query_normalized() -> None: - output = ZhipuAIAIEmbeddings().embed_query("foo walked to the market") + output = ZhipuAIEmbeddings().embed_query("foo walked to the market") assert np.isclose(np.linalg.norm(output), 1.0) \ No newline at end of file