From 76905aa043e3e604b5b34faf5e91d0aedb5ed6dd Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 11 Dec 2023 21:34:49 -0500 Subject: [PATCH] Update RunnableWithMessageHistory (#14351) This PR updates RunnableWithMessageHistory to support user-specific configuration for the factory. It extends support to passing multiple named arguments into the factory if the factory takes more than a single argument. --- libs/core/langchain_core/runnables/history.py | 169 +++++++++++++++--- .../unit_tests/runnables/test_history.py | 116 +++++++++++- 2 files changed, 254 insertions(+), 31 deletions(-) diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index e0bb9281ffd5e..4ca57d0cbaa73 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -4,13 +4,10 @@ from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, List, Optional, Sequence, Type, - Union, ) from langchain_core.chat_history import BaseChatMessageHistory @@ -28,6 +25,9 @@ from langchain_core.runnables.config import RunnableConfig from langchain_core.tracers.schemas import Run +import inspect +from typing import Callable, Dict, Union + MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]] GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory] @@ -38,8 +38,10 @@ class RunnableWithMessageHistory(RunnableBindingBase): Base runnable must have inputs and outputs that can be converted to a list of BaseMessages. - RunnableWithMessageHistory must always be called with a config that contains session_id, e.g.: - ``{"configurable": {"session_id": ""}}`` + RunnableWithMessageHistory must always be called with a config that contains + session_id, e.g.: + + ``{"configurable": {"session_id": ""}}`` Example (dict input): .. code-block:: python @@ -79,12 +81,66 @@ class RunnableWithMessageHistory(RunnableBindingBase): ) # -> "The inverse of cosine is called arccosine ..."
+ + Here's an example that uses an in-memory chat history, and a factory that + takes in two keys (user_id and conversation_id) to create a chat history instance. + + .. code-block:: python + + store = {} + + def get_session_history( + user_id: str, conversation_id: str + ) -> ChatMessageHistory: + if (user_id, conversation_id) not in store: + store[(user_id, conversation_id)] = ChatMessageHistory() + return store[(user_id, conversation_id)] + + prompt = ChatPromptTemplate.from_messages([ + ("system", "You're an assistant who's good at {ability}"), + MessagesPlaceholder(variable_name="history"), + ("human", "{question}"), + ]) + + chain = prompt | ChatAnthropic(model="claude-2") + + with_message_history = RunnableWithMessageHistory( + chain, + get_session_history=get_session_history, + input_messages_key="question", + history_messages_key="history", + history_factory_config=[ + ConfigurableFieldSpec( + id="user_id", + annotation=str, + name="User ID", + description="Unique identifier for the user.", + default="", + is_shared=True, + ), + ConfigurableFieldSpec( + id="conversation_id", + annotation=str, + name="Conversation ID", + description="Unique identifier for the conversation.", + default="", + is_shared=True, + ), + ], + ) + + with_message_history.invoke( + {"ability": "math", "question": "What does cosine mean?"}, + config={"configurable": {"user_id": "123", "conversation_id": "1"}} + ) + """ # noqa: E501 get_session_history: GetSessionHistoryCallable input_messages_key: Optional[str] = None output_messages_key: Optional[str] = None history_messages_key: Optional[str] = None + history_factory_config: Sequence[ConfigurableFieldSpec] @classmethod def get_lc_namespace(cls) -> List[str]: @@ -102,6 +158,7 @@ def __init__( input_messages_key: Optional[str] = None, output_messages_key: Optional[str] = None, history_messages_key: Optional[str] = None, + history_factory_config: Optional[Sequence[ConfigurableFieldSpec]] = None, **kwargs: Any, ) -> None: """Initialize
RunnableWithMessageHistory. @@ -121,10 +178,10 @@ def __init__( - A BaseMessage or sequence of BaseMessages - A dict with a key for a BaseMessage or sequence of BaseMessages - get_session_history: Function that returns a new BaseChatMessageHistory - given a session id. Should take a single - positional argument `session_id` which is a string and a named argument - `user_id` which can be a string or None. e.g.: + get_session_history: Function that returns a new BaseChatMessageHistory. + This function should either take a single positional argument + `session_id` of type string and return a corresponding + chat message history instance. ```python def get_session_history( @@ -135,12 +192,29 @@ def get_session_history( ... ``` + Or it should take keyword arguments that match the keys of + `history_factory_config` and return a corresponding + chat message history instance. + + ```python + def get_session_history( + *, + user_id: str, + thread_id: str, + ) -> BaseChatMessageHistory: + ... + ``` + input_messages_key: Must be specified if the base runnable accepts a dict as input. output_messages_key: Must be specified if the base runnable returns a dict as output. history_messages_key: Must be specified if the base runnable accepts a dict as input and expects a separate key for historical messages. + history_factory_config: Configure fields that should be passed to the + chat history factory. See ``ConfigurableFieldSpec`` for more details. + Specifying these allows you to pass multiple config keys + into the get_session_history factory. **kwargs: Arbitrary additional kwargs to pass to parent class ``RunnableBindingBase`` init.
""" # noqa: E501 @@ -155,29 +229,36 @@ def get_session_history( bound = ( history_chain | runnable.with_listeners(on_end=self._exit_history) ).with_config(run_name="RunnableWithMessageHistory") + + if history_factory_config: + _config_specs = history_factory_config + else: + # If not provided, then we'll use the default session_id field + _config_specs = [ + ConfigurableFieldSpec( + id="session_id", + annotation=str, + name="Session ID", + description="Unique identifier for a session.", + default="", + is_shared=True, + ), + ] + super().__init__( get_session_history=get_session_history, input_messages_key=input_messages_key, output_messages_key=output_messages_key, bound=bound, history_messages_key=history_messages_key, + history_factory_config=_config_specs, **kwargs, ) @property def config_specs(self) -> List[ConfigurableFieldSpec]: return get_unique_config_specs( - super().config_specs - + [ - ConfigurableFieldSpec( - id="session_id", - annotation=str, - name="Session ID", - description="Unique identifier for a session.", - default="", - is_shared=True, - ), - ] + super().config_specs + list(self.history_factory_config) ) def get_input_schema( @@ -278,16 +359,46 @@ def _exit_history(self, run: Run, config: RunnableConfig) -> None: def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: config = super()._merge_configs(*configs) - # extract session_id - if "session_id" not in config.get("configurable", {}): + expected_keys = [field_spec.id for field_spec in self.history_factory_config] + + configurable = config.get("configurable", {}) + + missing_keys = set(expected_keys) - set(configurable.keys()) + + if missing_keys: example_input = {self.input_messages_key: "foo"} - example_config = {"configurable": {"session_id": "123"}} + example_configurable = { + missing_key: "[your-value-here]" for missing_key in missing_keys + } + example_config = {"configurable": example_configurable} raise ValueError( - "session_id is required." 
- " Pass it in as part of the config argument to .invoke() or .stream()" - f"\neg. chain.invoke({example_input}, {example_config})" + f"Missing keys {sorted(missing_keys)} in config['configurable'] " + f"Expected keys are {sorted(expected_keys)}." + f"When using via .invoke() or .stream(), pass in a config; " + f"e.g., chain.invoke({example_input}, {example_config})" ) - # attach message_history - session_id = config["configurable"]["session_id"] - config["configurable"]["message_history"] = self.get_session_history(session_id) + + parameter_names = _get_parameter_names(self.get_session_history) + + if len(expected_keys) == 1: + # If arity = 1, then invoke function by positional arguments + message_history = self.get_session_history(configurable[expected_keys[0]]) + else: + # otherwise verify that names of keys match and invoke by named arguments + if set(expected_keys) != set(parameter_names): + raise ValueError( + f"Expected keys {sorted(expected_keys)} do not match parameter " + f"names {sorted(parameter_names)} of get_session_history."
+ ) + + message_history = self.get_session_history( + **{key: configurable[key] for key in expected_keys} + ) + config["configurable"]["message_history"] = message_history return config + + +def _get_parameter_names(callable_: GetSessionHistoryCallable) -> List[str]: + """Get the parameter names of the callable.""" + sig = inspect.signature(callable_) + return list(sig.parameters.keys()) diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index 49f68825ba239..0bec016248b1f 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -1,10 +1,11 @@ -from typing import Any, Callable, Sequence, Union +from typing import Any, Callable, Dict, List, Sequence, Union from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables.base import RunnableLambda from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.history import RunnableWithMessageHistory +from langchain_core.runnables.utils import ConfigurableFieldSpec from tests.unit_tests.fake.memory import ChatMessageHistory @@ -130,7 +131,7 @@ def test_output_messages() -> None: ) get_session_history = _get_get_session_history() with_history = RunnableWithMessageHistory( - runnable, + runnable, # type: ignore get_session_history, input_messages_key="input", history_messages_key="history", @@ -238,3 +239,114 @@ class RunnableWithChatHistoryInput(BaseModel): with_history.get_input_schema().schema() == RunnableWithChatHistoryInput.schema() ) + + +def test_using_custom_config_specs() -> None: + """Test that we can configure which keys should be passed to the session factory.""" + + def _fake_llm(input: Dict[str, Any]) -> List[BaseMessage]: + messages = input["messages"] + return [ + AIMessage( + content="you said: " + + "\n".join( + str(m.content) for m in 
messages if isinstance(m, HumanMessage) + ) + ) + ] + + runnable = RunnableLambda(_fake_llm) + store = {} + + def get_session_history(user_id: str, conversation_id: str) -> ChatMessageHistory: + if (user_id, conversation_id) not in store: + store[(user_id, conversation_id)] = ChatMessageHistory() + return store[(user_id, conversation_id)] + + with_message_history = RunnableWithMessageHistory( + runnable, # type: ignore + get_session_history=get_session_history, + input_messages_key="messages", + history_messages_key="history", + history_factory_config=[ + ConfigurableFieldSpec( + id="user_id", + annotation=str, + name="User ID", + description="Unique identifier for the user.", + default="", + is_shared=True, + ), + ConfigurableFieldSpec( + id="conversation_id", + annotation=str, + name="Conversation ID", + description="Unique identifier for the conversation.", + default=None, + is_shared=True, + ), + ], + ) + result = with_message_history.invoke( + { + "messages": [HumanMessage(content="hello")], + }, + {"configurable": {"user_id": "user1", "conversation_id": "1"}}, + ) + assert result == [ + AIMessage(content="you said: hello"), + ] + assert store == { + ("user1", "1"): ChatMessageHistory( + messages=[ + HumanMessage(content="hello"), + AIMessage(content="you said: hello"), + ] + ) + } + + result = with_message_history.invoke( + { + "messages": [HumanMessage(content="goodbye")], + }, + {"configurable": {"user_id": "user1", "conversation_id": "1"}}, + ) + assert result == [ + AIMessage(content="you said: goodbye"), + ] + assert store == { + ("user1", "1"): ChatMessageHistory( + messages=[ + HumanMessage(content="hello"), + AIMessage(content="you said: hello"), + HumanMessage(content="goodbye"), + AIMessage(content="you said: goodbye"), + ] + ) + } + + result = with_message_history.invoke( + { + "messages": [HumanMessage(content="meow")], + }, + {"configurable": {"user_id": "user2", "conversation_id": "1"}}, + ) + assert result == [ + AIMessage(content="you said: 
meow"), + ] + assert store == { + ("user1", "1"): ChatMessageHistory( + messages=[ + HumanMessage(content="hello"), + AIMessage(content="you said: hello"), + HumanMessage(content="goodbye"), + AIMessage(content="you said: goodbye"), + ] + ), + ("user2", "1"): ChatMessageHistory( + messages=[ + HumanMessage(content="meow"), + AIMessage(content="you said: meow"), + ] + ), + }