Skip to content

Commit

Permalink
support messages in messages out (#20862)
Browse files Browse the repository at this point in the history
  • Loading branch information
hwchase17 authored Apr 24, 2024
1 parent a1614b8 commit 43c041c
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 11 deletions.
42 changes: 31 additions & 11 deletions libs/core/langchain_core/runnables/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)

if TYPE_CHECKING:
from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import BaseMessage
from langchain_core.runnables.config import RunnableConfig
from langchain_core.tracers.schemas import Run
Expand Down Expand Up @@ -228,9 +229,12 @@ def get_lc_namespace(cls) -> List[str]:

def __init__(
self,
runnable: Runnable[
MessagesOrDictWithMessages,
Union[str, BaseMessage, MessagesOrDictWithMessages],
runnable: Union[
Runnable[
Union[MessagesOrDictWithMessages],
Union[str, BaseMessage, MessagesOrDictWithMessages],
],
LanguageModelLike,
],
get_session_history: GetSessionHistoryCallable,
*,
Expand Down Expand Up @@ -364,10 +368,19 @@ def get_input_schema(
return super_schema

def _get_input_messages(
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage]]
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
) -> List[BaseMessage]:
from langchain_core.messages import BaseMessage

if isinstance(input_val, dict):
if self.input_messages_key:
key = self.input_messages_key
elif len(input_val) == 1:
key = list(input_val.keys())[0]
else:
key = "input"
input_val = input_val[key]

if isinstance(input_val, str):
from langchain_core.messages import HumanMessage

Expand All @@ -388,7 +401,18 @@ def _get_output_messages(
from langchain_core.messages import BaseMessage

if isinstance(output_val, dict):
output_val = output_val[self.output_messages_key or "output"]
if self.output_messages_key:
key = self.output_messages_key
elif len(output_val) == 1:
key = list(output_val.keys())[0]
else:
key = "output"
# If you are wrapping a chat model directly
# The output is actually this weird generations object
if key not in output_val and "generations" in output_val:
output_val = output_val["generations"][0][0]["message"]
else:
output_val = output_val[key]

if isinstance(output_val, str):
from langchain_core.messages import AIMessage
Expand All @@ -407,10 +431,7 @@ def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage

if not self.history_messages_key:
# return all messages
input_val = (
input if not self.input_messages_key else input[self.input_messages_key]
)
messages += self._get_input_messages(input_val)
messages += self._get_input_messages(input)
return messages

async def _aenter_history(
Expand All @@ -432,8 +453,7 @@ def _exit_history(self, run: Run, config: RunnableConfig) -> None:

# Get the input messages
inputs = load(run.inputs)
input_val = inputs[self.input_messages_key or "input"]
input_messages = self._get_input_messages(input_val)
input_messages = self._get_input_messages(inputs)

# If historic messages were prepended to the input messages, remove them to
# avoid adding duplicate messages to history.
Expand Down
40 changes: 40 additions & 0 deletions libs/core/tests/unit_tests/runnables/test_history.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import RunnableLambda
from langchain_core.runnables.config import RunnableConfig
Expand Down Expand Up @@ -127,6 +132,41 @@ def test_output_message() -> None:
assert output == AIMessage(content="you said: hello\ngood bye")


def test_input_messages_output_message() -> None:
class LengthChatModel(BaseChatModel):
"""A fake chat model that returns the length of the messages passed in."""

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
return ChatResult(
generations=[
ChatGeneration(message=AIMessage(content=str(len(messages))))
]
)

@property
def _llm_type(self) -> str:
return "length-fake-chat-model"

runnable = LengthChatModel()
get_session_history = _get_get_session_history()
with_history = RunnableWithMessageHistory(
runnable,
get_session_history,
)
config: RunnableConfig = {"configurable": {"session_id": "4"}}
output = with_history.invoke([HumanMessage(content="hi")], config)
assert output.content == "1"
output = with_history.invoke([HumanMessage(content="hi")], config)
assert output.content == "3"


def test_output_messages() -> None:
runnable = RunnableLambda(
lambda input: [
Expand Down

0 comments on commit 43c041c

Please sign in to comment.