Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

openai[patch]: pass message name #17537

Merged
merged 8 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,10 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
The LangChain message.
"""
role = _dict.get("role")
name = _dict.get("name")
id_ = _dict.get("id")
if role == "user":
return HumanMessage(content=_dict.get("content", ""), id=id_)
return HumanMessage(content=_dict.get("content", ""), id=id_, name=name)
elif role == "assistant":
# Fix for azure
# Also OpenAI returns None for tool invocations
Expand All @@ -104,21 +105,24 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
additional_kwargs["function_call"] = dict(function_call)
if tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = tool_calls
return AIMessage(content=content, additional_kwargs=additional_kwargs, id=id_)
return AIMessage(
content=content, additional_kwargs=additional_kwargs, name=name, id=id_
)
elif role == "system":
return SystemMessage(content=_dict.get("content", ""), id=id_)
return SystemMessage(content=_dict.get("content", ""), name=name, id=id_)
elif role == "function":
return FunctionMessage(
content=_dict.get("content", ""), name=_dict.get("name"), id=id_
content=_dict.get("content", ""), name=cast(str, _dict.get("name")), id=id_
)
elif role == "tool":
additional_kwargs = {}
if "name" in _dict:
additional_kwargs["name"] = _dict["name"]
return ToolMessage(
content=_dict.get("content", ""),
tool_call_id=_dict.get("tool_call_id"),
tool_call_id=cast(str, _dict.get("tool_call_id")),
additional_kwargs=additional_kwargs,
name=name,
id=id_,
)
else:
Expand All @@ -134,13 +138,16 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
Returns:
The dictionary.
"""
message_dict: Dict[str, Any]
message_dict: Dict[str, Any] = {
"content": message.content,
"name": message.name,
}
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
message_dict["role"] = message.role
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
message_dict["role"] = "user"
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
message_dict["role"] = "assistant"
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
# If function call only, content is None not empty string
Expand All @@ -152,19 +159,16 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
if message_dict["content"] == "":
message_dict["content"] = None
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
message_dict["role"] = "system"
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
message_dict["role"] = "function"
elif isinstance(message, ToolMessage):
message_dict = {
"role": "tool",
"content": message.content,
"tool_call_id": message.tool_call_id,
}
message_dict["role"] = "tool"
message_dict["tool_call_id"] = message.tool_call_id

# tool message doesn't have name: https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
if message_dict["name"] is None:
del message_dict["name"]
else:
raise TypeError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
Expand Down
20 changes: 10 additions & 10 deletions libs/partners/openai/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 51 additions & 0 deletions libs/partners/openai/tests/unit_tests/chat_models/test_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test OpenAI Chat API wrapper."""

import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
Expand Down Expand Up @@ -44,20 +45,41 @@ def test__convert_dict_to_message_human() -> None:
assert result == expected_output


def test__convert_dict_to_message_human_with_name() -> None:
message = {"role": "user", "content": "foo", "name": "test"}
result = _convert_dict_to_message(message)
expected_output = HumanMessage(content="foo", name="test")
assert result == expected_output


def test__convert_dict_to_message_ai() -> None:
message = {"role": "assistant", "content": "foo"}
result = _convert_dict_to_message(message)
expected_output = AIMessage(content="foo")
assert result == expected_output


def test__convert_dict_to_message_ai_with_name() -> None:
message = {"role": "assistant", "content": "foo", "name": "test"}
result = _convert_dict_to_message(message)
expected_output = AIMessage(content="foo", name="test")
assert result == expected_output


def test__convert_dict_to_message_system() -> None:
message = {"role": "system", "content": "foo"}
result = _convert_dict_to_message(message)
expected_output = SystemMessage(content="foo")
assert result == expected_output


def test__convert_dict_to_message_system_with_name() -> None:
message = {"role": "system", "content": "foo", "name": "test"}
result = _convert_dict_to_message(message)
expected_output = SystemMessage(content="foo", name="test")
assert result == expected_output


@pytest.fixture
def mock_completion() -> dict:
return {
Expand All @@ -71,6 +93,7 @@ def mock_completion() -> dict:
"message": {
"role": "assistant",
"content": "Bar Baz",
"name": "Erick",
},
"finish_reason": "stop",
}
Expand Down Expand Up @@ -134,3 +157,31 @@ async def mock_create(*args: Any, **kwargs: Any) -> Any:
def test__get_encoding_model(model: str) -> None:
ChatOpenAI(model=model)._get_encoding_model()
return


def test_openai_invoke_name(mock_completion: dict) -> None:
llm = ChatOpenAI()

mock_client = MagicMock()
mock_client.create.return_value = mock_completion

with patch.object(
llm,
"client",
mock_client,
):
messages = [
HumanMessage(content="Foo", name="Katie"),
]
res = llm.invoke(messages)
call_args, call_kwargs = mock_client.create.call_args
assert len(call_args) == 0 # no positional args
call_messages = call_kwargs["messages"]
assert len(call_messages) == 1
assert call_messages[0]["role"] == "user"
assert call_messages[0]["content"] == "Foo"
assert call_messages[0]["name"] == "Katie"

# check return type has name
assert res.content == "Bar Baz"
assert res.name == "Erick"
Loading