From 89e01016a6e6a6cdc179e686b7506aa30864de9c Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Tue, 19 Mar 2024 12:57:27 -0700 Subject: [PATCH] openai[patch]: pass message name (#17537) --- .../langchain_openai/chat_models/base.py | 44 ++++++++-------- libs/partners/openai/poetry.lock | 20 ++++---- .../tests/unit_tests/chat_models/test_base.py | 51 +++++++++++++++++++ 3 files changed, 85 insertions(+), 30 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index b3409fadb72ba..75f5e340ba3de 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -92,9 +92,10 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: The LangChain message. """ role = _dict.get("role") + name = _dict.get("name") id_ = _dict.get("id") if role == "user": - return HumanMessage(content=_dict.get("content", ""), id=id_) + return HumanMessage(content=_dict.get("content", ""), id=id_, name=name) elif role == "assistant": # Fix for azure # Also OpenAI returns None for tool invocations @@ -104,12 +105,14 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: additional_kwargs["function_call"] = dict(function_call) if tool_calls := _dict.get("tool_calls"): additional_kwargs["tool_calls"] = tool_calls - return AIMessage(content=content, additional_kwargs=additional_kwargs, id=id_) + return AIMessage( + content=content, additional_kwargs=additional_kwargs, name=name, id=id_ + ) elif role == "system": - return SystemMessage(content=_dict.get("content", ""), id=id_) + return SystemMessage(content=_dict.get("content", ""), name=name, id=id_) elif role == "function": return FunctionMessage( - content=_dict.get("content", ""), name=_dict.get("name"), id=id_ + content=_dict.get("content", ""), name=cast(str, _dict.get("name")), id=id_ ) elif role == "tool": additional_kwargs = {} @@ -117,8 +120,9 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: additional_kwargs["name"] = _dict["name"] return ToolMessage( content=_dict.get("content", ""), - tool_call_id=_dict.get("tool_call_id"), + tool_call_id=cast(str, _dict.get("tool_call_id")), additional_kwargs=additional_kwargs, + name=name, id=id_, ) else: @@ -134,13 +138,16 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: Returns: The dictionary. """ - message_dict: Dict[str, Any] + message_dict: Dict[str, Any] = { + "content": message.content, + "name": message.name, + } if isinstance(message, ChatMessage): - message_dict = {"role": message.role, "content": message.content} + message_dict["role"] = message.role elif isinstance(message, HumanMessage): - message_dict = {"role": "user", "content": message.content} + message_dict["role"] = "user" elif isinstance(message, AIMessage): - message_dict = {"role": "assistant", "content": message.content} + message_dict["role"] = "assistant" if "function_call" in message.additional_kwargs: message_dict["function_call"] = message.additional_kwargs["function_call"] # If function call only, content is None not empty string @@ -152,19 +159,16 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: if message_dict["content"] == "": message_dict["content"] = None elif isinstance(message, SystemMessage): - message_dict = {"role": "system", "content": message.content} + message_dict["role"] = "system" elif isinstance(message, FunctionMessage): - message_dict = { - "role": "function", - "content": message.content, - "name": message.name, - } + message_dict["role"] = "function" elif isinstance(message, ToolMessage): - message_dict = { - "role": "tool", - "content": message.content, - "tool_call_id": message.tool_call_id, - } + message_dict["role"] = "tool" + message_dict["tool_call_id"] = message.tool_call_id + + # tool message doesn't have name: https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages + if message_dict["name"] is None: + del message_dict["name"] else: raise TypeError(f"Got unknown type {message}") if "name" in message.additional_kwargs: diff --git a/libs/partners/openai/poetry.lock b/libs/partners/openai/poetry.lock index 96a9258f6d521..2d6c16a93da90 100644 --- a/libs/partners/openai/poetry.lock +++ b/libs/partners/openai/poetry.lock @@ -318,7 +318,7 @@ files = [ [[package]] name = "langchain-core" -version = "0.1.29" +version = "0.1.33-rc.1" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -344,13 +344,13 @@ url = "../../core" [[package]] name = "langsmith" -version = "0.1.22" +version = "0.1.29" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "langsmith-0.1.22-py3-none-any.whl", hash = "sha256:b877d302bd4cf7c79e9e6e24bedf669132abf0659143390a29350eda0945544f"}, - {file = "langsmith-0.1.22.tar.gz", hash = "sha256:2921ae2297c2fb23baa2641b9cf416914ac7fd65f4a9dd5a573bc30efb54b693"}, + {file = "langsmith-0.1.29-py3-none-any.whl", hash = "sha256:5439f5bf25b00a43602aa1ddaba0a31d413ed920e7b20494070328f7e1ecbb86"}, + {file = "langsmith-0.1.29.tar.gz", hash = "sha256:60ba0bd889c6a2683d123f66dc5043368eb2f103c4eb69e382abf7ce69a9f7d6"}, ] [package.dependencies] @@ -458,13 +458,13 @@ files = [ [[package]] name = "openai" -version = "1.13.3" +version = "1.14.2" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.13.3-py3-none-any.whl", hash = "sha256:5769b62abd02f350a8dd1a3a242d8972c947860654466171d60fb0972ae0a41c"}, - {file = "openai-1.13.3.tar.gz", hash = "sha256:ff6c6b3bc7327e715e4b3592a923a5a1c7519ff5dd764a83d69f633d49e77a7b"}, + {file = "openai-1.14.2-py3-none-any.whl", hash = "sha256:a48b3c4d635b603952189ac5a0c0c9b06c025b80eb2900396939f02bb2104ac3"}, + {file = "openai-1.14.2.tar.gz", hash = "sha256:e5642f7c02cf21994b08477d7bb2c1e46d8f335d72c26f0396c5f89b15b5b153"}, ] [package.dependencies] @@ -566,13 +566,13 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "pydantic" -version = "2.6.3" +version = "2.6.4" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic-2.6.3-py3-none-any.whl", hash = "sha256:72c6034df47f46ccdf81869fddb81aade68056003900a8724a4f160700016a2a"}, - {file = "pydantic-2.6.3.tar.gz", hash = "sha256:e07805c4c7f5c6826e33a1d4c9d47950d7eaf34868e2690f8594d2e30241f11f"}, + {file = "pydantic-2.6.4-py3-none-any.whl", hash = "sha256:cc46fce86607580867bdc3361ad462bab9c222ef042d3da86f2fb333e1d916c5"}, + {file = "pydantic-2.6.4.tar.gz", hash = "sha256:b1704e0847db01817624a6b86766967f552dd9dbf3afba4004409f908dcc84e6"}, ] [package.dependencies] diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index 550d5b729ac45..87e7111959ee8 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -1,4 +1,5 @@ """Test OpenAI Chat API wrapper.""" + import json from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -44,6 +45,13 @@ def test__convert_dict_to_message_human() -> None: assert result == expected_output +def test__convert_dict_to_message_human_with_name() -> None: + message = {"role": "user", "content": "foo", "name": "test"} + result = _convert_dict_to_message(message) + expected_output = HumanMessage(content="foo", name="test") + assert result == expected_output + + def test__convert_dict_to_message_ai() -> None: message = {"role": "assistant", "content": "foo"} result = _convert_dict_to_message(message) @@ -51,6 +59,13 @@ def test__convert_dict_to_message_ai() -> None: assert result == expected_output +def test__convert_dict_to_message_ai_with_name() -> None: + message = {"role": "assistant", "content": "foo", "name": "test"} + result = _convert_dict_to_message(message) + expected_output = AIMessage(content="foo", name="test") + assert result == expected_output + + def test__convert_dict_to_message_system() -> None: message = {"role": "system", "content": "foo"} result = _convert_dict_to_message(message) @@ -58,6 +73,13 @@ def test__convert_dict_to_message_system() -> None: assert result == expected_output +def test__convert_dict_to_message_system_with_name() -> None: + message = {"role": "system", "content": "foo", "name": "test"} + result = _convert_dict_to_message(message) + expected_output = SystemMessage(content="foo", name="test") + assert result == expected_output + + @pytest.fixture def mock_completion() -> dict: return { @@ -71,6 +93,7 @@ def mock_completion() -> dict: "message": { "role": "assistant", "content": "Bar Baz", + "name": "Erick", }, "finish_reason": "stop", } @@ -134,3 +157,31 @@ async def mock_create(*args: Any, **kwargs: Any) -> Any: def test__get_encoding_model(model: str) -> None: ChatOpenAI(model=model)._get_encoding_model() return + + +def test_openai_invoke_name(mock_completion: dict) -> None: + llm = ChatOpenAI() + + mock_client = MagicMock() + mock_client.create.return_value = mock_completion + + with patch.object( + llm, + "client", + mock_client, + ): + messages = [ + HumanMessage(content="Foo", name="Katie"), + ] + res = llm.invoke(messages) + call_args, call_kwargs = mock_client.create.call_args + assert len(call_args) == 0 # no positional args + call_messages = call_kwargs["messages"] + assert len(call_messages) == 1 + assert call_messages[0]["role"] == "user" + assert call_messages[0]["content"] == "Foo" + assert call_messages[0]["name"] == "Katie" + + # check return type has name + assert res.content == "Bar Baz" + assert res.name == "Erick"