Skip to content

Commit

Permalink
Update RunnableWithMessageHistory (#14351)
Browse files Browse the repository at this point in the history
This PR updates RunnableWithMessage history to support user specific
configuration for the factory.

It extends support to passing multiple named arguments into the factory
if the factory takes more than a single argument.
  • Loading branch information
eyurtsev authored Dec 12, 2023
1 parent 8a126c5 commit 76905aa
Show file tree
Hide file tree
Showing 2 changed files with 254 additions and 31 deletions.
169 changes: 140 additions & 29 deletions libs/core/langchain_core/runnables/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Type,
Union,
)

from langchain_core.chat_history import BaseChatMessageHistory
Expand All @@ -28,6 +25,9 @@
from langchain_core.runnables.config import RunnableConfig
from langchain_core.tracers.schemas import Run

import inspect
from typing import Callable, Dict, Union

MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]]
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]

Expand All @@ -38,8 +38,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):
Base runnable must have inputs and outputs that can be converted to a list of
BaseMessages.
RunnableWithMessageHistory must always be called with a config that contains session_id, e.g.:
``{"configurable": {"session_id": "<SESSION_ID>"}}``
RunnableWithMessageHistory must always be called with a config that contains
session_id, e.g.:
``{"configurable": {"session_id": "<SESSION_ID>"}}`
Example (dict input):
.. code-block:: python
Expand Down Expand Up @@ -79,12 +81,66 @@ class RunnableWithMessageHistory(RunnableBindingBase):
)
# -> "The inverse of cosine is called arccosine ..."
Here's an example that uses an in memory chat history, and a factory that
takes in two keys (user_id and conversation id) to create a chat history instance.
.. code-block:: python
store = {}
def get_session_history(
user_id: str, conversation_id: str
) -> ChatMessageHistory:
if (user_id, conversation_id) not in store:
store[(user_id, conversation_id)] = ChatMessageHistory()
return store[(user_id, conversation_id)]
prompt = ChatPromptTemplate.from_messages([
("system", "You're an assistant who's good at {ability}"),
MessagesPlaceholder(variable_name="history"),
("human", "{question}"),
])
chain = prompt | ChatAnthropic(model="claude-2")
with_message_history = RunnableWithMessageHistory(
chain,
get_session_history=get_session_history,
input_messages_key="messages",
history_messages_key="history",
history_factory_config=[
ConfigurableFieldSpec(
id="user_id",
annotation=str,
name="User ID",
description="Unique identifier for the user.",
default="",
is_shared=True,
),
ConfigurableFieldSpec(
id="conversation_id",
annotation=str,
name="Conversation ID",
description="Unique identifier for the conversation.",
default="",
is_shared=True,
),
],
)
chain_with_history.invoke(
{"ability": "math", "question": "What does cosine mean?"},
config={"configurable": {"user_id": "123", "conversation_id": "1"}}
)
""" # noqa: E501

get_session_history: GetSessionHistoryCallable
input_messages_key: Optional[str] = None
output_messages_key: Optional[str] = None
history_messages_key: Optional[str] = None
history_factory_config: Sequence[ConfigurableFieldSpec]

@classmethod
def get_lc_namespace(cls) -> List[str]:
Expand All @@ -102,6 +158,7 @@ def __init__(
input_messages_key: Optional[str] = None,
output_messages_key: Optional[str] = None,
history_messages_key: Optional[str] = None,
history_factory_config: Optional[Sequence[ConfigurableFieldSpec]] = None,
**kwargs: Any,
) -> None:
"""Initialize RunnableWithMessageHistory.
Expand All @@ -121,10 +178,10 @@ def __init__(
- A BaseMessage or sequence of BaseMessages
- A dict with a key for a BaseMessage or sequence of BaseMessages
get_session_history: Function that returns a new BaseChatMessageHistory
given a session id. Should take a single
positional argument `session_id` which is a string and a named argument
`user_id` which can be a string or None. e.g.:
get_session_history: Function that returns a new BaseChatMessageHistory.
This function should either take a single positional argument
`session_id` of type string and return a corresponding
chat message history instance.
```python
def get_session_history(
Expand All @@ -135,12 +192,29 @@ def get_session_history(
...
```
Or it should take keyword arguments that match the keys of
`session_history_config_specs` and return a corresponding
chat message history instance.
```python
def get_session_history(
*,
user_id: str,
thread_id: str,
) -> BaseChatMessageHistory:
...
```
input_messages_key: Must be specified if the base runnable accepts a dict
as input.
output_messages_key: Must be specified if the base runnable returns a dict
as output.
history_messages_key: Must be specified if the base runnable accepts a dict
as input and expects a separate key for historical messages.
history_factory_config: Configure fields that should be passed to the
chat history factory. See ``ConfigurableFieldSpec`` for more details.
Specifying these allows you to pass multiple config keys
into the get_session_history factory.
**kwargs: Arbitrary additional kwargs to pass to parent class
``RunnableBindingBase`` init.
""" # noqa: E501
Expand All @@ -155,29 +229,36 @@ def get_session_history(
bound = (
history_chain | runnable.with_listeners(on_end=self._exit_history)
).with_config(run_name="RunnableWithMessageHistory")

if history_factory_config:
_config_specs = history_factory_config
else:
# If not provided, then we'll use the default session_id field
_config_specs = [
ConfigurableFieldSpec(
id="session_id",
annotation=str,
name="Session ID",
description="Unique identifier for a session.",
default="",
is_shared=True,
),
]

super().__init__(
get_session_history=get_session_history,
input_messages_key=input_messages_key,
output_messages_key=output_messages_key,
bound=bound,
history_messages_key=history_messages_key,
history_factory_config=_config_specs,
**kwargs,
)

@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
super().config_specs
+ [
ConfigurableFieldSpec(
id="session_id",
annotation=str,
name="Session ID",
description="Unique identifier for a session.",
default="",
is_shared=True,
),
]
super().config_specs + list(self.history_factory_config)
)

def get_input_schema(
Expand Down Expand Up @@ -278,16 +359,46 @@ def _exit_history(self, run: Run, config: RunnableConfig) -> None:

def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
config = super()._merge_configs(*configs)
# extract session_id
if "session_id" not in config.get("configurable", {}):
expected_keys = [field_spec.id for field_spec in self.history_factory_config]

configurable = config.get("configurable", {})

missing_keys = set(expected_keys) - set(configurable.keys())

if missing_keys:
example_input = {self.input_messages_key: "foo"}
example_config = {"configurable": {"session_id": "123"}}
example_configurable = {
missing_key: "[your-value-here]" for missing_key in missing_keys
}
example_config = {"configurable": example_configurable}
raise ValueError(
"session_id is required."
" Pass it in as part of the config argument to .invoke() or .stream()"
f"\neg. chain.invoke({example_input}, {example_config})"
f"Missing keys {sorted(missing_keys)} in config['configurable'] "
f"Expected keys are {sorted(expected_keys)}."
f"When using via .invoke() or .stream(), pass in a config; "
f"e.g., chain.invoke({example_input}, {example_config})"
)
# attach message_history
session_id = config["configurable"]["session_id"]
config["configurable"]["message_history"] = self.get_session_history(session_id)

parameter_names = _get_parameter_names(self.get_session_history)

if len(expected_keys) == 1:
# If arity = 1, then invoke function by positional arguments
message_history = self.get_session_history(configurable[expected_keys[0]])
else:
# otherwise verify that names of keys patch and invoke by named arguments
if set(expected_keys) != set(parameter_names):
raise ValueError(
f"Expected keys {sorted(expected_keys)} do not match parameter "
f"names {sorted(parameter_names)} of get_session_history."
)

message_history = self.get_session_history(
**{key: configurable[key] for key in expected_keys}
)
config["configurable"]["message_history"] = message_history
return config


def _get_parameter_names(callable_: GetSessionHistoryCallable) -> List[str]:
"""Get the parameter names of the callable."""
sig = inspect.signature(callable_)
return list(sig.parameters.keys())
116 changes: 114 additions & 2 deletions libs/core/tests/unit_tests/runnables/test_history.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Any, Callable, Sequence, Union
from typing import Any, Callable, Dict, List, Sequence, Union

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import RunnableLambda
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.utils import ConfigurableFieldSpec
from tests.unit_tests.fake.memory import ChatMessageHistory


Expand Down Expand Up @@ -130,7 +131,7 @@ def test_output_messages() -> None:
)
get_session_history = _get_get_session_history()
with_history = RunnableWithMessageHistory(
runnable,
runnable, # type: ignore
get_session_history,
input_messages_key="input",
history_messages_key="history",
Expand Down Expand Up @@ -238,3 +239,114 @@ class RunnableWithChatHistoryInput(BaseModel):
with_history.get_input_schema().schema()
== RunnableWithChatHistoryInput.schema()
)


def test_using_custom_config_specs() -> None:
"""Test that we can configure which keys should be passed to the session factory."""

def _fake_llm(input: Dict[str, Any]) -> List[BaseMessage]:
messages = input["messages"]
return [
AIMessage(
content="you said: "
+ "\n".join(
str(m.content) for m in messages if isinstance(m, HumanMessage)
)
)
]

runnable = RunnableLambda(_fake_llm)
store = {}

def get_session_history(user_id: str, conversation_id: str) -> ChatMessageHistory:
if (user_id, conversation_id) not in store:
store[(user_id, conversation_id)] = ChatMessageHistory()
return store[(user_id, conversation_id)]

with_message_history = RunnableWithMessageHistory(
runnable, # type: ignore
get_session_history=get_session_history,
input_messages_key="messages",
history_messages_key="history",
history_factory_config=[
ConfigurableFieldSpec(
id="user_id",
annotation=str,
name="User ID",
description="Unique identifier for the user.",
default="",
is_shared=True,
),
ConfigurableFieldSpec(
id="conversation_id",
annotation=str,
name="Conversation ID",
description="Unique identifier for the conversation.",
default=None,
is_shared=True,
),
],
)
result = with_message_history.invoke(
{
"messages": [HumanMessage(content="hello")],
},
{"configurable": {"user_id": "user1", "conversation_id": "1"}},
)
assert result == [
AIMessage(content="you said: hello"),
]
assert store == {
("user1", "1"): ChatMessageHistory(
messages=[
HumanMessage(content="hello"),
AIMessage(content="you said: hello"),
]
)
}

result = with_message_history.invoke(
{
"messages": [HumanMessage(content="goodbye")],
},
{"configurable": {"user_id": "user1", "conversation_id": "1"}},
)
assert result == [
AIMessage(content="you said: goodbye"),
]
assert store == {
("user1", "1"): ChatMessageHistory(
messages=[
HumanMessage(content="hello"),
AIMessage(content="you said: hello"),
HumanMessage(content="goodbye"),
AIMessage(content="you said: goodbye"),
]
)
}

result = with_message_history.invoke(
{
"messages": [HumanMessage(content="meow")],
},
{"configurable": {"user_id": "user2", "conversation_id": "1"}},
)
assert result == [
AIMessage(content="you said: meow"),
]
assert store == {
("user1", "1"): ChatMessageHistory(
messages=[
HumanMessage(content="hello"),
AIMessage(content="you said: hello"),
HumanMessage(content="goodbye"),
AIMessage(content="you said: goodbye"),
]
),
("user2", "1"): ChatMessageHistory(
messages=[
HumanMessage(content="meow"),
AIMessage(content="you said: meow"),
]
),
}

0 comments on commit 76905aa

Please sign in to comment.