Skip to content

Commit

Permalink
openai[patch]: pass message name (#17537)
Browse files Browse the repository at this point in the history
  • Loading branch information
efriis authored and hinthornw committed Apr 26, 2024
1 parent 366283a commit 89e0101
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 30 deletions.
44 changes: 24 additions & 20 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,10 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
The LangChain message.
"""
role = _dict.get("role")
name = _dict.get("name")
id_ = _dict.get("id")
if role == "user":
return HumanMessage(content=_dict.get("content", ""), id=id_)
return HumanMessage(content=_dict.get("content", ""), id=id_, name=name)
elif role == "assistant":
# Fix for azure
# Also OpenAI returns None for tool invocations
Expand All @@ -104,21 +105,24 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
additional_kwargs["function_call"] = dict(function_call)
if tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = tool_calls
return AIMessage(content=content, additional_kwargs=additional_kwargs, id=id_)
return AIMessage(
content=content, additional_kwargs=additional_kwargs, name=name, id=id_
)
elif role == "system":
return SystemMessage(content=_dict.get("content", ""), id=id_)
return SystemMessage(content=_dict.get("content", ""), name=name, id=id_)
elif role == "function":
return FunctionMessage(
content=_dict.get("content", ""), name=_dict.get("name"), id=id_
content=_dict.get("content", ""), name=cast(str, _dict.get("name")), id=id_
)
elif role == "tool":
additional_kwargs = {}
if "name" in _dict:
additional_kwargs["name"] = _dict["name"]
return ToolMessage(
content=_dict.get("content", ""),
tool_call_id=_dict.get("tool_call_id"),
tool_call_id=cast(str, _dict.get("tool_call_id")),
additional_kwargs=additional_kwargs,
name=name,
id=id_,
)
else:
Expand All @@ -134,13 +138,16 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
Returns:
The dictionary.
"""
message_dict: Dict[str, Any]
message_dict: Dict[str, Any] = {
"content": message.content,
"name": message.name,
}
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
message_dict["role"] = message.role
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
message_dict["role"] = "user"
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
message_dict["role"] = "assistant"
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
# If function call only, content is None not empty string
Expand All @@ -152,19 +159,16 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
if message_dict["content"] == "":
message_dict["content"] = None
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
message_dict["role"] = "system"
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
message_dict["role"] = "function"
elif isinstance(message, ToolMessage):
message_dict = {
"role": "tool",
"content": message.content,
"tool_call_id": message.tool_call_id,
}
message_dict["role"] = "tool"
message_dict["tool_call_id"] = message.tool_call_id

# tool message doesn't have name: https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
if message_dict["name"] is None:
del message_dict["name"]
else:
raise TypeError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
Expand Down
20 changes: 10 additions & 10 deletions libs/partners/openai/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 51 additions & 0 deletions libs/partners/openai/tests/unit_tests/chat_models/test_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test OpenAI Chat API wrapper."""

import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
Expand Down Expand Up @@ -44,20 +45,41 @@ def test__convert_dict_to_message_human() -> None:
assert result == expected_output


def test__convert_dict_to_message_human_with_name() -> None:
message = {"role": "user", "content": "foo", "name": "test"}
result = _convert_dict_to_message(message)
expected_output = HumanMessage(content="foo", name="test")
assert result == expected_output


def test__convert_dict_to_message_ai() -> None:
message = {"role": "assistant", "content": "foo"}
result = _convert_dict_to_message(message)
expected_output = AIMessage(content="foo")
assert result == expected_output


def test__convert_dict_to_message_ai_with_name() -> None:
message = {"role": "assistant", "content": "foo", "name": "test"}
result = _convert_dict_to_message(message)
expected_output = AIMessage(content="foo", name="test")
assert result == expected_output


def test__convert_dict_to_message_system() -> None:
message = {"role": "system", "content": "foo"}
result = _convert_dict_to_message(message)
expected_output = SystemMessage(content="foo")
assert result == expected_output


def test__convert_dict_to_message_system_with_name() -> None:
message = {"role": "system", "content": "foo", "name": "test"}
result = _convert_dict_to_message(message)
expected_output = SystemMessage(content="foo", name="test")
assert result == expected_output


@pytest.fixture
def mock_completion() -> dict:
return {
Expand All @@ -71,6 +93,7 @@ def mock_completion() -> dict:
"message": {
"role": "assistant",
"content": "Bar Baz",
"name": "Erick",
},
"finish_reason": "stop",
}
Expand Down Expand Up @@ -134,3 +157,31 @@ async def mock_create(*args: Any, **kwargs: Any) -> Any:
def test__get_encoding_model(model: str) -> None:
ChatOpenAI(model=model)._get_encoding_model()
return


def test_openai_invoke_name(mock_completion: dict) -> None:
llm = ChatOpenAI()

mock_client = MagicMock()
mock_client.create.return_value = mock_completion

with patch.object(
llm,
"client",
mock_client,
):
messages = [
HumanMessage(content="Foo", name="Katie"),
]
res = llm.invoke(messages)
call_args, call_kwargs = mock_client.create.call_args
assert len(call_args) == 0 # no positional args
call_messages = call_kwargs["messages"]
assert len(call_messages) == 1
assert call_messages[0]["role"] == "user"
assert call_messages[0]["content"] == "Foo"
assert call_messages[0]["name"] == "Katie"

# check return type has name
assert res.content == "Bar Baz"
assert res.name == "Erick"

0 comments on commit 89e0101

Please sign in to comment.