Commit 52ccae3

Accept message-like things in Chat models, LLMs and MessagesPlaceholder (#16418)
nfcampos authored Jan 26, 2024
1 parent 570b4f8 commit 52ccae3
Showing 10 changed files with 214 additions and 8 deletions.
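
In practice, this change means that anywhere a chat model, LLM, or MessagesPlaceholder previously required a Sequence[BaseMessage], it now also accepts "message-like" values: bare strings, (role, content) 2-tuples, and OpenAI-style dicts, all normalized through the new convert_to_messages helper. A minimal sketch of the resulting API surface, assuming chat is any BaseChatModel instance (the variable name is illustrative):

from langchain_core.messages import HumanMessage

chat.invoke("hello")                                     # str (already supported)
chat.invoke([HumanMessage(content="hello")])             # BaseMessage (already supported)
chat.invoke([("system", "Be terse."), ("human", "hi")])  # new: (role, content) tuples
chat.invoke([{"role": "user", "content": "hi"}])         # new: OpenAI-style dicts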
9 changes: 7 additions & 2 deletions libs/core/langchain_core/language_models/base.py
@@ -16,7 +16,12 @@
 from typing_extensions import TypeAlias
 
 from langchain_core._api import deprecated
-from langchain_core.messages import AnyMessage, BaseMessage, get_buffer_string
+from langchain_core.messages import (
+    AnyMessage,
+    BaseMessage,
+    MessageLikeRepresentation,
+    get_buffer_string,
+)
 from langchain_core.prompt_values import PromptValue
 from langchain_core.runnables import Runnable, RunnableSerializable
 from langchain_core.utils import get_pydantic_field_names
@@ -49,7 +54,7 @@ def _get_token_ids_default_method(text: str) -> List[int]:
     return tokenizer.encode(text)
 
 
-LanguageModelInput = Union[PromptValue, str, Sequence[BaseMessage]]
+LanguageModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation]]
 LanguageModelOutput = Union[BaseMessage, str]
 LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput]
 LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str)
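
The widened alias means all of the following now type-check as LanguageModelInput. A sketch using only names from the hunks above (the ok_inputs variable is illustrative):

from typing import List

from langchain_core.language_models.base import LanguageModelInput
from langchain_core.messages import HumanMessage
from langchain_core.prompt_values import StringPromptValue

ok_inputs: List[LanguageModelInput] = [
    StringPromptValue(text="hi"),         # PromptValue
    "hi",                                 # str
    [HumanMessage(content="hi")],         # Sequence[BaseMessage]
    [("human", "hi")],                    # Sequence of (role, content) tuples
    [{"role": "user", "content": "hi"}],  # Sequence of message dicts
]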
3 changes: 2 additions & 1 deletion libs/core/langchain_core/language_models/chat_models.py
@@ -34,6 +34,7 @@
     BaseMessage,
     BaseMessageChunk,
     HumanMessage,
+    convert_to_messages,
     message_chunk_to_message,
 )
 from langchain_core.outputs import (
@@ -144,7 +145,7 @@ def _convert_input(self, input: LanguageModelInput) -> PromptValue:
         elif isinstance(input, str):
             return StringPromptValue(text=input)
         elif isinstance(input, Sequence):
-            return ChatPromptValue(messages=input)
+            return ChatPromptValue(messages=convert_to_messages(input))
         else:
             raise ValueError(
                 f"Invalid input type {type(input)}. "
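
The Sequence branch above is the behavioral change: message-like items are normalized before the prompt value is built. A sketch of what the new branch evaluates to, assuming the imports shown:

from langchain_core.messages import AIMessage, HumanMessage, convert_to_messages
from langchain_core.prompt_values import ChatPromptValue

pv = ChatPromptValue(messages=convert_to_messages([("ai", "hi"), "hello"]))
assert pv.to_messages() == [AIMessage(content="hi"), HumanMessage(content="hello")]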
9 changes: 7 additions & 2 deletions libs/core/langchain_core/language_models/llms.py
@@ -48,7 +48,12 @@
 from langchain_core.globals import get_llm_cache
 from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
 from langchain_core.load import dumpd
-from langchain_core.messages import AIMessage, BaseMessage, get_buffer_string
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    convert_to_messages,
+    get_buffer_string,
+)
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
 from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
 from langchain_core.pydantic_v1 import Field, root_validator, validator
@@ -210,7 +215,7 @@ def _convert_input(self, input: LanguageModelInput) -> PromptValue:
         elif isinstance(input, str):
             return StringPromptValue(text=input)
         elif isinstance(input, Sequence):
-            return ChatPromptValue(messages=input)
+            return ChatPromptValue(messages=convert_to_messages(input))
         else:
             raise ValueError(
                 f"Invalid input type {type(input)}. "
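
Completion-style LLMs get the same normalization; the resulting ChatPromptValue is then rendered to a single string prompt (via get_buffer_string) before being sent to the model. A sketch of that rendering:

from langchain_core.messages import convert_to_messages
from langchain_core.prompt_values import ChatPromptValue

pv = ChatPromptValue(messages=convert_to_messages([("system", "Be terse."), "hi"]))
print(pv.to_string())
# System: Be terse.
# Human: hi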
107 changes: 106 additions & 1 deletion libs/core/langchain_core/messages/__init__.py
@@ -1,4 +1,4 @@
-from typing import List, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 from langchain_core.messages.ai import AIMessage, AIMessageChunk
 from langchain_core.messages.base import (
@@ -117,6 +117,110 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
     )
 
 
+MessageLikeRepresentation = Union[BaseMessage, Tuple[str, str], str, Dict[str, Any]]
+
+
+def _create_message_from_message_type(
+    message_type: str,
+    content: str,
+    name: Optional[str] = None,
+    tool_call_id: Optional[str] = None,
+    **additional_kwargs: Any,
+) -> BaseMessage:
+    """Create a message from a message type and content string.
+
+    Args:
+        message_type: str the type of the message (e.g., "human", "ai", etc.)
+        content: str the content string.
+
+    Returns:
+        a message of the appropriate type.
+    """
+    kwargs: Dict[str, Any] = {}
+    if name is not None:
+        kwargs["name"] = name
+    if tool_call_id is not None:
+        kwargs["tool_call_id"] = tool_call_id
+    if additional_kwargs:
+        kwargs["additional_kwargs"] = additional_kwargs  # type: ignore[assignment]
+    if message_type in ("human", "user"):
+        message: BaseMessage = HumanMessage(content=content, **kwargs)
+    elif message_type in ("ai", "assistant"):
+        message = AIMessage(content=content, **kwargs)
+    elif message_type == "system":
+        message = SystemMessage(content=content, **kwargs)
+    elif message_type == "function":
+        message = FunctionMessage(content=content, **kwargs)
+    elif message_type == "tool":
+        message = ToolMessage(content=content, **kwargs)
+    else:
+        raise ValueError(
+            f"Unexpected message type: {message_type}. Use one of 'human',"
+            f" 'user', 'ai', 'assistant', 'system', 'function', or 'tool'."
+        )
+    return message
+
+
+def _convert_to_message(
+    message: MessageLikeRepresentation,
+) -> BaseMessage:
+    """Instantiate a message from a variety of message formats.
+
+    The message format can be one of the following:
+
+    - BaseMessage
+    - 2-tuple of (role string, content); e.g., ("human", "hello")
+    - dict: a message dict with role and content keys
+    - string: shorthand for ("human", content); e.g., "hello"
+
+    Args:
+        message: a representation of a message in one of the supported formats
+
+    Returns:
+        an instance of a message
+    """
+    if isinstance(message, BaseMessage):
+        _message = message
+    elif isinstance(message, str):
+        _message = _create_message_from_message_type("human", message)
+    elif isinstance(message, tuple):
+        if len(message) != 2:
+            raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
+        message_type_str, template = message
+        _message = _create_message_from_message_type(message_type_str, template)
+    elif isinstance(message, dict):
+        msg_kwargs = message.copy()
+        try:
+            msg_type = msg_kwargs.pop("role")
+            msg_content = msg_kwargs.pop("content")
+        except KeyError:
+            raise ValueError(
+                f"Message dict must contain 'role' and 'content' keys, got {message}"
+            )
+        _message = _create_message_from_message_type(
+            msg_type, msg_content, **msg_kwargs
+        )
+    else:
+        raise NotImplementedError(f"Unsupported message type: {type(message)}")
+
+    return _message
+
+
+def convert_to_messages(
+    messages: Sequence[MessageLikeRepresentation],
+) -> List[BaseMessage]:
+    """Convert a sequence of messages to a list of messages.
+
+    Args:
+        messages: Sequence of messages to convert.
+
+    Returns:
+        List of messages (BaseMessages).
+    """
+    return [_convert_to_message(m) for m in messages]
+
+
 __all__ = [
     "AIMessage",
     "AIMessageChunk",
@@ -133,6 +237,7 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
     "SystemMessageChunk",
     "ToolMessage",
     "ToolMessageChunk",
+    "convert_to_messages",
     "get_buffer_string",
     "message_chunk_to_message",
     "messages_from_dict",
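
convert_to_messages can also be called directly. Note how dict keys beyond role and content are routed by _create_message_from_message_type: name and tool_call_id become message fields, and everything else lands in additional_kwargs. A sketch (the "foo" key is illustrative):

from langchain_core.messages import AIMessage, convert_to_messages

msgs = convert_to_messages(
    [{"role": "assistant", "content": "hi", "name": "bot", "foo": "bar"}]
)
assert msgs == [AIMessage(content="hi", name="bot", additional_kwargs={"foo": "bar"})]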
3 changes: 2 additions & 1 deletion libs/core/langchain_core/prompts/chat.py
@@ -27,6 +27,7 @@
     ChatMessage,
     HumanMessage,
     SystemMessage,
+    convert_to_messages,
 )
 from langchain_core.messages.base import get_msg_title_repr
 from langchain_core.prompt_values import ChatPromptValue, PromptValue
@@ -126,7 +127,7 @@ def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
                 f"variable {self.variable_name} should be a list of base messages, "
                 f"got {value}"
             )
-        for v in value:
+        for v in convert_to_messages(value):
            if not isinstance(v, BaseMessage):
                raise ValueError(
                    f"variable {self.variable_name} should be a list of base messages,"
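
With this change, MessagesPlaceholder accepts the same message-like values at format time. A sketch of typical usage, assuming a standard ChatPromptTemplate (variable names illustrative):

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

prompt = ChatPromptTemplate.from_messages(
    [("system", "You are helpful."), MessagesPlaceholder("history"), ("human", "{q}")]
)
prompt.invoke({"history": [("human", "2+2?"), ("ai", "4")], "q": "Double it."})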
21 changes: 21 additions & 0 deletions libs/core/tests/unit_tests/fake/chat_model.py
@@ -301,3 +301,24 @@ async def _astream(
     @property
     def _llm_type(self) -> str:
         return "generic-fake-chat-model"
+
+
+class ParrotFakeChatModel(BaseChatModel):
+    """A generic fake chat model that can be used to test the chat model interface.
+
+    * Chat model should be usable in both sync and async tests
+    """
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Top Level call"""
+        return ChatResult(generations=[ChatGeneration(message=messages[-1])])
+
+    @property
+    def _llm_type(self) -> str:
+        return "parrot-fake-chat-model"
11 changes: 10 additions & 1 deletion libs/core/tests/unit_tests/fake/test_fake_chat_model.py
@@ -5,8 +5,9 @@
 
 from langchain_core.callbacks.base import AsyncCallbackHandler
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
+from langchain_core.messages.human import HumanMessage
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
-from tests.unit_tests.fake.chat_model import GenericFakeChatModel
+from tests.unit_tests.fake.chat_model import GenericFakeChatModel, ParrotFakeChatModel
 
 
 def test_generic_fake_chat_model_invoke() -> None:
@@ -182,3 +183,11 @@ async def on_llm_new_token(
         AIMessageChunk(content="goodbye"),
     ]
     assert tokens == ["hello", " ", "goodbye"]
+
+
+def test_chat_model_inputs() -> None:
+    fake = ParrotFakeChatModel()
+
+    assert fake.invoke("hello") == HumanMessage(content="hello")
+    assert fake.invoke([("ai", "blah")]) == AIMessage(content="blah")
+    assert fake.invoke([AIMessage(content="blah")]) == AIMessage(content="blah")
1 change: 1 addition & 0 deletions libs/core/tests/unit_tests/messages/test_imports.py
@@ -16,6 +16,7 @@
     "SystemMessageChunk",
     "ToolMessage",
     "ToolMessageChunk",
+    "convert_to_messages",
     "get_buffer_string",
     "message_chunk_to_message",
     "messages_from_dict",
6 changes: 6 additions & 0 deletions libs/core/tests/unit_tests/prompts/test_chat.py
@@ -369,3 +369,9 @@ def test_messages_placeholder() -> None:
         prompt.format_messages()
     prompt = MessagesPlaceholder("history", optional=True)
     assert prompt.format_messages() == []
+    assert prompt.format_messages(
+        history=[("system", "You are an AI assistant."), "Hello!"]
+    ) == [
+        SystemMessage(content="You are an AI assistant."),
+        HumanMessage(content="Hello!"),
+    ]
52 changes: 52 additions & 0 deletions libs/core/tests/unit_tests/test_messages.py
@@ -14,6 +14,7 @@
     HumanMessageChunk,
     SystemMessage,
     ToolMessage,
+    convert_to_messages,
     get_buffer_string,
     message_chunk_to_message,
     messages_from_dict,
@@ -428,3 +429,54 @@ def test_tool_calls_merge() -> None:
             ]
         },
     )
+
+
+def test_convert_to_messages() -> None:
+    # dicts
+    assert convert_to_messages(
+        [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Hello!"},
+            {"role": "ai", "content": "Hi!"},
+            {"role": "human", "content": "Hello!", "name": "Jane"},
+            {
+                "role": "assistant",
+                "content": "Hi!",
+                "name": "JaneBot",
+                "function_call": {"name": "greet", "arguments": '{"name": "Jane"}'},
+            },
+            {"role": "function", "name": "greet", "content": "Hi!"},
+            {"role": "tool", "tool_call_id": "tool_id", "content": "Hi!"},
+        ]
+    ) == [
+        SystemMessage(content="You are a helpful assistant."),
+        HumanMessage(content="Hello!"),
+        AIMessage(content="Hi!"),
+        HumanMessage(content="Hello!", name="Jane"),
+        AIMessage(
+            content="Hi!",
+            name="JaneBot",
+            additional_kwargs={
+                "function_call": {"name": "greet", "arguments": '{"name": "Jane"}'}
+            },
+        ),
+        FunctionMessage(name="greet", content="Hi!"),
+        ToolMessage(tool_call_id="tool_id", content="Hi!"),
+    ]
+
+    # tuples
+    assert convert_to_messages(
+        [
+            ("system", "You are a helpful assistant."),
+            "hello!",
+            ("ai", "Hi!"),
+            ("human", "Hello!"),
+            ("assistant", "Hi!"),
+        ]
+    ) == [
+        SystemMessage(content="You are a helpful assistant."),
+        HumanMessage(content="hello!"),
+        AIMessage(content="Hi!"),
+        HumanMessage(content="Hello!"),
+        AIMessage(content="Hi!"),
+    ]
