From 5156233f973a6c4541012f679de44faaa503afdb Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 6 Dec 2023 12:22:08 -0500 Subject: [PATCH 01/12] x --- libs/core/langchain_core/runnables/history.py | 120 ++++++++++++++---- 1 file changed, 95 insertions(+), 25 deletions(-) diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 52c04c0e104c7..dd235aad1207f 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -4,13 +4,10 @@ from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, List, Optional, Sequence, Type, - Union, ) from langchain_core.chat_history import BaseChatMessageHistory @@ -28,8 +25,36 @@ from langchain_core.runnables.config import RunnableConfig from langchain_core.tracers.schemas import Run +import inspect +from typing import Callable, Dict, Union + MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]] -GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory] +GetSessionHistoryCallable = Union[Callable[..., BaseChatMessageHistory]] + + +def check_callable_argument_type(callable_: GetSessionHistoryCallable) -> str: + """Check if the callable accepts a single argument of type 'str' or 'Dict'.""" + sig = inspect.signature(callable_) + + if len(sig.parameters) == 1: + param_name, param = sig.parameters.popitem() + if param.annotation is str: + return ( + f"The callable accepts a single argument of type 'str' " + f"named '{param_name}'." + ) + elif param.annotation is Dict: + return ( + f"The callable accepts a single argument of type 'Dict' " + f"named '{param_name}'." + ) + else: + return ( + "The callable accepts a single argument, but its type is " + "not 'str' or 'Dict'." + ) + else: + return "The callable does not accept a single argument." class RunnableWithMessageHistory(RunnableBindingBase): @@ -38,7 +63,9 @@ class RunnableWithMessageHistory(RunnableBindingBase): Base runnable must have inputs and outputs that can be converted to a list of BaseMessages. - RunnableWithMessageHistory must always be called with a config that contains session_id, e.g.: + RunnableWithMessageHistory must always be called with a config that contains + session_id, e.g.: + ``{"configurable": {"session_id": ""}}`` Example (dict input): @@ -84,6 +111,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): input_messages_key: Optional[str] = None output_messages_key: Optional[str] = None history_messages_key: Optional[str] = None + session_history_config_specs: Sequence[ConfigurableFieldSpec] = None def __init__( self, @@ -96,6 +124,7 @@ def __init__( input_messages_key: Optional[str] = None, output_messages_key: Optional[str] = None, history_messages_key: Optional[str] = None, + session_history_config_specs: Optional[Sequence[ConfigurableFieldSpec]] = None, **kwargs: Any, ) -> None: """Initialize RunnableWithMessageHistory. @@ -135,6 +164,8 @@ def get_session_history( as output. history_messages_key: Must be specified if the base runnable accepts a dict as input and expects a separate key for historical messages. + custom_config_specs: Configure fields that should be passed to the chat + history factory. See ``ConfigurableFieldSpec`` for more details. **kwargs: Arbitrary additional kwargs to pass to parent class ``RunnableBindingBase`` init. """ # noqa: E501 @@ -149,29 +180,36 @@ def get_session_history( bound = ( history_chain | runnable.with_listeners(on_end=self._exit_history) ).with_config(run_name="RunnableWithMessageHistory") + + if session_history_config_specs: + _config_specs = session_history_config_specs + else: + # If not provided, then we'll use the default session_id field + _config_specs = [ + ConfigurableFieldSpec( + id="session_id", + annotation=str, + name="Session ID", + description="Unique identifier for a session.", + default="", + is_shared=True, + ), + ] + super().__init__( get_session_history=get_session_history, input_messages_key=input_messages_key, output_messages_key=output_messages_key, bound=bound, history_messages_key=history_messages_key, + session_history_config_specs=_config_specs, **kwargs, ) @property def config_specs(self) -> List[ConfigurableFieldSpec]: return get_unique_config_specs( - super().config_specs - + [ - ConfigurableFieldSpec( - id="session_id", - annotation=str, - name="Session ID", - description="Unique identifier for a session.", - default="", - is_shared=True, - ), - ] + super().config_specs + list(self.session_history_config_specs) ) def get_input_schema( @@ -274,16 +312,48 @@ def _exit_history(self, run: Run, config: RunnableConfig) -> None: def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: config = super()._merge_configs(*configs) - # extract session_id - if "session_id" not in config.get("configurable", {}): + expected_keys = [ + field_spec.id for field_spec in self.session_history_config_specs + ] + + configurable = config.get("configurable", {}) + + missing_keys = set(expected_keys) - set(configurable.keys()) + + if missing_keys: example_input = {self.input_messages_key: "foo"} - example_config = {"configurable": {"session_id": "123"}} + example_configurable = { + missing_key: "[your-value-here]" for missing_key in missing_keys + } + example_config = {"configurable": example_configurable} raise ValueError( - "session_id_id is required." - " Pass it in as part of the config argument to .invoke() or .stream()" - f"\neg. chain.invoke({example_input}, {example_config})" + f"Missing keys {sorted(missing_keys)} in config['configurable'] " + f"Expected keys are {sorted(expected_keys)}." + f"When using via .invoke() or .stream(), pass in a config; " + f"e.g., chain.invoke({example_input}, {example_config})" + ) + + parameter_names = _get_parameter_names(self.get_session_history) + + if len(parameter_names) == 1 and len(expected_keys) == 1: + # If arity = 1, then invoke function by positional arguments + message_history = self.get_session_history(configurable[expected_keys[0]]) + else: + # otherwise verify that names of keys patch and invoke by named arguments + if set(expected_keys) != set(parameter_names): + raise ValueError( + f"Expected keys {sorted(expected_keys)} do not match parameter " + f"names {sorted(parameter_names)} of get_session_history." + ) + + message_history = self.get_session_history( + **{key: configurable[key] for key in expected_keys} ) - # attach message_history - session_id = config["configurable"]["session_id"] - config["configurable"]["message_history"] = self.get_session_history(session_id) + config["configurable"]["message_history"] = message_history return config + + +def _get_parameter_names(callable_: GetSessionHistoryCallable) -> List[str]: + """Get the parameter names of the callable.""" + sig = inspect.signature(callable_) + return list(sig.parameters.keys()) From 7f55b675e220a9c3be91248f95394b89f13b2f85 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 6 Dec 2023 12:28:25 -0500 Subject: [PATCH 02/12] x --- libs/core/langchain_core/runnables/history.py | 25 ------------------- 1 file changed, 25 deletions(-) diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index dd235aad1207f..f5c5963be8067 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -32,31 +32,6 @@ GetSessionHistoryCallable = Union[Callable[..., BaseChatMessageHistory]] -def check_callable_argument_type(callable_: GetSessionHistoryCallable) -> str: - """Check if the callable accepts a single argument of type 'str' or 'Dict'.""" - sig = inspect.signature(callable_) - - if len(sig.parameters) == 1: - param_name, param = sig.parameters.popitem() - if param.annotation is str: - return ( - f"The callable accepts a single argument of type 'str' " - f"named '{param_name}'." - ) - elif param.annotation is Dict: - return ( - f"The callable accepts a single argument of type 'Dict' " - f"named '{param_name}'." - ) - else: - return ( - "The callable accepts a single argument, but its type is " - "not 'str' or 'Dict'." - ) - else: - return "The callable does not accept a single argument." - - class RunnableWithMessageHistory(RunnableBindingBase): """A runnable that manages chat message history for another runnable. From 5410e4a3aa1e62c271b67842265d5a39c264856c Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 8 Dec 2023 12:36:29 -0500 Subject: [PATCH 03/12] x --- libs/core/langchain_core/runnables/history.py | 34 ++++++--- .../unit_tests/runnables/test_history.py | 70 +++++++++++++++++++ 2 files changed, 95 insertions(+), 9 deletions(-) diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index f5c5963be8067..8566bf5ea4b3f 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -48,8 +48,9 @@ class RunnableWithMessageHistory(RunnableBindingBase): from typing import Optional - from langchain_core.chat_models import ChatAnthropic - from langchain_core.memory.chat_message_histories import RedisChatMessageHistory + from langchain.chat_models import ChatAnthropic + from langchain.memory.chat_message_histories import RedisChatMessageHistory + from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables.history import RunnableWithMessageHistory @@ -119,10 +120,10 @@ def __init__( - A BaseMessage or sequence of BaseMessages - A dict with a key for a BaseMessage or sequence of BaseMessages - get_session_history: Function that returns a new BaseChatMessageHistory - given a session id. Should take a single - positional argument `session_id` which is a string and a named argument - `user_id` which can be a string or None. e.g.: + get_session_history: Function that returns a new BaseChatMessageHistory. + This function should either take a single positional argument + `session_id` of type string and return a corresponding + chat message history instance. ```python def get_session_history( @@ -133,14 +134,29 @@ def get_session_history( ... ``` + Or it should take keyword arguments that match the keys of + `session_history_config_specs` and return a corresponding + chat message history instance. + + ```python + def get_session_history( + *, + user_id: str, + thread_id: str, + ) -> BaseChatMessageHistory: + ... + ``` + input_messages_key: Must be specified if the base runnable accepts a dict as input. output_messages_key: Must be specified if the base runnable returns a dict as output. history_messages_key: Must be specified if the base runnable accepts a dict as input and expects a separate key for historical messages. - custom_config_specs: Configure fields that should be passed to the chat - history factory. See ``ConfigurableFieldSpec`` for more details. + session_history_config_specs: Configure fields that should be passed to the + chat history factory. See ``ConfigurableFieldSpec`` for more details. + Specifying these allows you to pass multiple config keys + into the get_session_history factory. **kwargs: Arbitrary additional kwargs to pass to parent class ``RunnableBindingBase`` init. """ # noqa: E501 @@ -310,7 +326,7 @@ def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: parameter_names = _get_parameter_names(self.get_session_history) - if len(parameter_names) == 1 and len(expected_keys) == 1: + if len(expected_keys) == 1: # If arity = 1, then invoke function by positional arguments message_history = self.get_session_history(configurable[expected_keys[0]]) else: diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index c5f02c931337a..efc57450be641 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -5,6 +5,7 @@ from langchain_core.runnables.base import RunnableLambda from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.history import RunnableWithMessageHistory +from langchain_core.runnables.utils import ConfigurableFieldSpec from tests.unit_tests.fake.memory import ChatMessageHistory @@ -239,3 +240,72 @@ class RunnableWithChatHistoryInput(BaseModel): with_history.get_input_schema().schema() == RunnableWithChatHistoryInput.schema() ) + + +def test_using_custom_config_specs() -> None: + """Test that we can configure which keys should be passed to the session factory.""" + runnable = RunnableLambda( + lambda messages: { + "output": [ + AIMessage( + content="you said: " + + "\n".join( + [ + str(m.content) + for m in messages + if isinstance(m, HumanMessage) + ] + ) + ) + ] + } + ) + + store = {} + + def get_session_history(user_id: str, conversation_id: str) -> ChatMessageHistory: + if (user_id, conversation_id) not in store: + store[(user_id, conversation_id)] = ChatMessageHistory() + return store[(user_id, conversation_id)] + + with_message_history = RunnableWithMessageHistory( + runnable, + get_session_history=get_session_history, + input_messages_key="input", + session_history_config_specs=[ + ConfigurableFieldSpec( + id="user_id", + annotation=str, + name="User ID", + description="Unique identifier for the user.", + default="", + is_shared=True, + ), + ConfigurableFieldSpec( + id="conversation_id", + annotation=str, + name="Conversation ID", + description="Unique identifier for the conversation.", + # None means that the conversation ID will be generated automatically + default=None, + is_shared=True, + ), + ], + ) + result = with_message_history.invoke( + "hello", {"configurable": {"user_id": "user1", "conversation_id": "1"}} + ) + assert result == { + "output": [ + AIMessage(content="you said: hello"), + ] + } + with_message_history.invoke( + {"input": "meow"}, {"configurable": {"user_id": "user1", "conversation_id": "1"}} + ) + + with_message_history.invoke( + {"input": "goodbye"}, {"configurable": {"user_id": "user2", "conversation_id": "1"}} + ) + assert store == { + } From 408f64d93dc6cd2cf043b82432d03606cf357b95 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 8 Dec 2023 12:41:13 -0500 Subject: [PATCH 04/12] x --- .../unit_tests/runnables/test_history.py | 41 ++++++++++++++----- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index efc57450be641..98176e9bdba5a 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Sequence, Union +from typing import Any, Callable, Sequence, Union, Tuple, Optional, Dict from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.pydantic_v1 import BaseModel @@ -18,8 +18,10 @@ def test_interfaces() -> None: assert str(history) == "System: system\nHuman: human 1\nAI: ai\nHuman: human 2" -def _get_get_session_history() -> Callable[..., ChatMessageHistory]: - chat_history_store = {} +def _get_get_session_history( + *, store: Optional[Dict[str, Any]] = None +) -> Callable[..., ChatMessageHistory]: + chat_history_store = store if store is not None else {} def get_session_history(session_id: str, **kwargs: Any) -> ChatMessageHistory: if session_id not in chat_history_store: @@ -34,13 +36,15 @@ def test_input_messages() -> None: lambda messages: "you said: " + "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage)) ) - get_session_history = _get_get_session_history() + store = {} + get_session_history = _get_get_session_history(store=store) with_history = RunnableWithMessageHistory(runnable, get_session_history) config: RunnableConfig = {"configurable": {"session_id": "1"}} output = with_history.invoke([HumanMessage(content="hello")], config) assert output == "you said: hello" output = with_history.invoke([HumanMessage(content="good bye")], config) assert output == "you said: hello\ngood bye" + assert store == {} def test_input_dict() -> None: @@ -271,7 +275,6 @@ def get_session_history(user_id: str, conversation_id: str) -> ChatMessageHistor with_message_history = RunnableWithMessageHistory( runnable, get_session_history=get_session_history, - input_messages_key="input", session_history_config_specs=[ ConfigurableFieldSpec( id="user_id", @@ -300,12 +303,30 @@ def get_session_history(user_id: str, conversation_id: str) -> ChatMessageHistor AIMessage(content="you said: hello"), ] } - with_message_history.invoke( - {"input": "meow"}, {"configurable": {"user_id": "user1", "conversation_id": "1"}} - ) + assert store == { + ("user1", "1"): ChatMessageHistory( + messages=[ + HumanMessage(content="hello"), + AIMessage(content="you said: hello"), + ] + ) + } - with_message_history.invoke( - {"input": "goodbye"}, {"configurable": {"user_id": "user2", "conversation_id": "1"}} + result = with_message_history.invoke( + "goodbye", {"configurable": {"user_id": "user1", "conversation_id": "1"}} ) + assert result == { + "output": [ + AIMessage(content="you said: hello\ngoodbye"), + ] + } assert store == { + ("user1", "1"): ChatMessageHistory( + messages=[ + HumanMessage(content="hello"), + AIMessage(content="you said: hello"), + HumanMessage(content="goodbye"), + AIMessage(content="you said: hello\ngoodbye"), + ] + ) } From 5931b3e79d268f5232e1587b4280f2cd1c9485c7 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 8 Dec 2023 12:44:36 -0500 Subject: [PATCH 05/12] x --- libs/core/tests/unit_tests/runnables/test_history.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index 98176e9bdba5a..8f7c99cbaa3be 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -44,7 +44,16 @@ def test_input_messages() -> None: assert output == "you said: hello" output = with_history.invoke([HumanMessage(content="good bye")], config) assert output == "you said: hello\ngood bye" - assert store == {} + assert store == { + "1": ChatMessageHistory( + messages=[ + HumanMessage(content="hello"), + AIMessage(content="you said: hello"), + HumanMessage(content="good bye"), + AIMessage(content="you said: hello\ngood bye"), + ] + ) + } def test_input_dict() -> None: From c52d397f64ad1d68728e258ec8d90e0e366611e4 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 8 Dec 2023 14:05:41 -0500 Subject: [PATCH 06/12] x --- libs/core/langchain_core/runnables/history.py | 53 ++++++++++++ .../unit_tests/runnables/test_history.py | 85 ++++++++++++------- 2 files changed, 108 insertions(+), 30 deletions(-) diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 8566bf5ea4b3f..276550dd537db 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -81,6 +81,59 @@ class RunnableWithMessageHistory(RunnableBindingBase): ) # -> "The inverse of cosine is called arccosine ..." + + Here's an example that uses an in memory chat history, and a factory that + takes in two keys (user_id and conversation id) to create a chat history instance. + + .. code-block:: python + + store = {} + + def get_session_history( + user_id: str, conversation_id: str + ) -> ChatMessageHistory: + if (user_id, conversation_id) not in store: + store[(user_id, conversation_id)] = ChatMessageHistory() + return store[(user_id, conversation_id)] + + prompt = ChatPromptTemplate.from_messages([ + ("system", "You're an assistant who's good at {ability}"), + MessagesPlaceholder(variable_name="history"), + ("human", "{question}"), + ]) + + chain = prompt | ChatAnthropic(model="claude-2") + + with_message_history = RunnableWithMessageHistory( + chain, + get_session_history=get_session_history, + input_messages_key="messages", + history_messages_key="history", + session_history_config_specs=[ + ConfigurableFieldSpec( + id="user_id", + annotation=str, + name="User ID", + description="Unique identifier for the user.", + default="", + is_shared=True, + ), + ConfigurableFieldSpec( + id="conversation_id", + annotation=str, + name="Conversation ID", + description="Unique identifier for the conversation.", + default="", + is_shared=True, + ), + ], + ) + + chain_with_history.invoke( + {"ability": "math", "question": "What does cosine mean?"}, + config={"configurable": {"user_id": "123", "conversation_id": "1"}} + ) + """ # noqa: E501 get_session_history: GetSessionHistoryCallable diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index 8f7c99cbaa3be..d932693c2a210 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Sequence, Union, Tuple, Optional, Dict +from typing import Any, Callable, Dict, List, Optional, Sequence, Union from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.pydantic_v1 import BaseModel @@ -257,23 +257,19 @@ class RunnableWithChatHistoryInput(BaseModel): def test_using_custom_config_specs() -> None: """Test that we can configure which keys should be passed to the session factory.""" - runnable = RunnableLambda( - lambda messages: { - "output": [ - AIMessage( - content="you said: " - + "\n".join( - [ - str(m.content) - for m in messages - if isinstance(m, HumanMessage) - ] - ) + + def _fake_llm(input) -> List[BaseMessage]: + messages = input["messages"] + return [ + AIMessage( + content="you said: " + + "\n".join( + str(m.content) for m in messages if isinstance(m, HumanMessage) ) - ] - } - ) + ) + ] + runnable = RunnableLambda(_fake_llm) store = {} def get_session_history(user_id: str, conversation_id: str) -> ChatMessageHistory: @@ -284,6 +280,8 @@ def get_session_history(user_id: str, conversation_id: str) -> ChatMessageHistor with_message_history = RunnableWithMessageHistory( runnable, get_session_history=get_session_history, + input_messages_key="messages", + history_messages_key="history", session_history_config_specs=[ ConfigurableFieldSpec( id="user_id", @@ -298,20 +296,20 @@ def get_session_history(user_id: str, conversation_id: str) -> ChatMessageHistor annotation=str, name="Conversation ID", description="Unique identifier for the conversation.", - # None means that the conversation ID will be generated automatically default=None, is_shared=True, ), ], ) result = with_message_history.invoke( - "hello", {"configurable": {"user_id": "user1", "conversation_id": "1"}} + { + "messages": [HumanMessage(content="hello")], + }, + {"configurable": {"user_id": "user1", "conversation_id": "1"}}, ) - assert result == { - "output": [ - AIMessage(content="you said: hello"), - ] - } + assert result == [ + AIMessage(content="you said: hello"), + ] assert store == { ("user1", "1"): ChatMessageHistory( messages=[ @@ -322,20 +320,47 @@ def get_session_history(user_id: str, conversation_id: str) -> ChatMessageHistor } result = with_message_history.invoke( - "goodbye", {"configurable": {"user_id": "user1", "conversation_id": "1"}} + { + "messages": [HumanMessage(content="goodbye")], + }, + {"configurable": {"user_id": "user1", "conversation_id": "1"}}, ) - assert result == { - "output": [ - AIMessage(content="you said: hello\ngoodbye"), - ] - } + assert result == [ + AIMessage(content="you said: goodbye"), + ] assert store == { ("user1", "1"): ChatMessageHistory( messages=[ HumanMessage(content="hello"), AIMessage(content="you said: hello"), HumanMessage(content="goodbye"), - AIMessage(content="you said: hello\ngoodbye"), + AIMessage(content="you said: goodbye"), ] ) } + + result = with_message_history.invoke( + { + "messages": [HumanMessage(content="meow")], + }, + {"configurable": {"user_id": "user2", "conversation_id": "1"}}, + ) + assert result == [ + AIMessage(content="you said: meow"), + ] + assert store == { + ("user1", "1"): ChatMessageHistory( + messages=[ + HumanMessage(content="hello"), + AIMessage(content="you said: hello"), + HumanMessage(content="goodbye"), + AIMessage(content="you said: goodbye"), + ] + ), + ("user2", "1"): ChatMessageHistory( + messages=[ + HumanMessage(content="meow"), + AIMessage(content="you said: meow"), + ] + ), + } From 7e8a3725f1d73ac1d46719d1a29df8c992144eab Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 8 Dec 2023 14:08:51 -0500 Subject: [PATCH 07/12] x --- libs/core/langchain_core/runnables/history.py | 22 +++++++++---------- .../unit_tests/runnables/test_history.py | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 276550dd537db..31c864ad271b3 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -29,7 +29,7 @@ from typing import Callable, Dict, Union MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]] -GetSessionHistoryCallable = Union[Callable[..., BaseChatMessageHistory]] +GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory] class RunnableWithMessageHistory(RunnableBindingBase): @@ -41,7 +41,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): RunnableWithMessageHistory must always be called with a config that contains session_id, e.g.: - ``{"configurable": {"session_id": ""}}`` + ``{"configurable": {"session_id": ""}}` Example (dict input): .. code-block:: python @@ -109,7 +109,7 @@ def get_session_history( get_session_history=get_session_history, input_messages_key="messages", history_messages_key="history", - session_history_config_specs=[ + history_factory_config=[ ConfigurableFieldSpec( id="user_id", annotation=str, @@ -140,7 +140,7 @@ def get_session_history( input_messages_key: Optional[str] = None output_messages_key: Optional[str] = None history_messages_key: Optional[str] = None - session_history_config_specs: Sequence[ConfigurableFieldSpec] = None + history_factory_config: Sequence[ConfigurableFieldSpec] = None def __init__( self, @@ -153,7 +153,7 @@ def __init__( input_messages_key: Optional[str] = None, output_messages_key: Optional[str] = None, history_messages_key: Optional[str] = None, - session_history_config_specs: Optional[Sequence[ConfigurableFieldSpec]] = None, + history_factory_config: Optional[Sequence[ConfigurableFieldSpec]] = None, **kwargs: Any, ) -> None: """Initialize RunnableWithMessageHistory. @@ -206,7 +206,7 @@ def get_session_history( as output. history_messages_key: Must be specified if the base runnable accepts a dict as input and expects a separate key for historical messages. - session_history_config_specs: Configure fields that should be passed to the + history_factory_config: Configure fields that should be passed to the chat history factory. See ``ConfigurableFieldSpec`` for more details. Specifying these allows you to pass multiple config keys into the get_session_history factory. @@ -225,8 +225,8 @@ def get_session_history( history_chain | runnable.with_listeners(on_end=self._exit_history) ).with_config(run_name="RunnableWithMessageHistory") - if session_history_config_specs: - _config_specs = session_history_config_specs + if history_factory_config: + _config_specs = history_factory_config else: # If not provided, then we'll use the default session_id field _config_specs = [ @@ -246,14 +246,14 @@ def get_session_history( output_messages_key=output_messages_key, bound=bound, history_messages_key=history_messages_key, - session_history_config_specs=_config_specs, + history_factory_config=_config_specs, **kwargs, ) @property def config_specs(self) -> List[ConfigurableFieldSpec]: return get_unique_config_specs( - super().config_specs + list(self.session_history_config_specs) + super().config_specs + list(self.history_factory_config) ) def get_input_schema( @@ -357,7 +357,7 @@ def _exit_history(self, run: Run, config: RunnableConfig) -> None: def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: config = super()._merge_configs(*configs) expected_keys = [ - field_spec.id for field_spec in self.session_history_config_specs + field_spec.id for field_spec in self.history_factory_config ] configurable = config.get("configurable", {}) diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index d932693c2a210..c257c6aad66be 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -282,7 +282,7 @@ def get_session_history(user_id: str, conversation_id: str) -> ChatMessageHistor get_session_history=get_session_history, input_messages_key="messages", history_messages_key="history", - session_history_config_specs=[ + history_factory_config=[ ConfigurableFieldSpec( id="user_id", annotation=str, From 6feb3b4ef735fa90d96f5703e171fe6aeba68598 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 8 Dec 2023 14:09:05 -0500 Subject: [PATCH 08/12] x --- libs/core/langchain_core/runnables/history.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 31c864ad271b3..f1b494981b86a 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -356,9 +356,7 @@ def _exit_history(self, run: Run, config: RunnableConfig) -> None: def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: config = super()._merge_configs(*configs) - expected_keys = [ - field_spec.id for field_spec in self.history_factory_config - ] + expected_keys = [field_spec.id for field_spec in self.history_factory_config] configurable = config.get("configurable", {}) From 743bf37b7eb1f039fbaea620d10de49d4f79fabb Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 8 Dec 2023 14:12:35 -0500 Subject: [PATCH 09/12] x --- libs/core/langchain_core/runnables/history.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index f1b494981b86a..6a45f002c394e 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -140,7 +140,7 @@ def get_session_history( input_messages_key: Optional[str] = None output_messages_key: Optional[str] = None history_messages_key: Optional[str] = None - history_factory_config: Sequence[ConfigurableFieldSpec] = None + history_factory_config: Sequence[ConfigurableFieldSpec] def __init__( self, From 8f8b8902bfbc41089efb04b9c22b88bf863a40ff Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 8 Dec 2023 16:16:37 -0500 Subject: [PATCH 10/12] x --- .../unit_tests/runnables/test_history.py | 21 ++++--------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index c257c6aad66be..74872bd5dc5c0 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from typing import Any, Callable, List, Sequence, Union from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.pydantic_v1 import BaseModel @@ -18,10 +18,8 @@ def test_interfaces() -> None: assert str(history) == "System: system\nHuman: human 1\nAI: ai\nHuman: human 2" -def _get_get_session_history( - *, store: Optional[Dict[str, Any]] = None -) -> Callable[..., ChatMessageHistory]: - chat_history_store = store if store is not None else {} +def _get_get_session_history() -> Callable[..., ChatMessageHistory]: + chat_history_store = {} def get_session_history(session_id: str, **kwargs: Any) -> ChatMessageHistory: if session_id not in chat_history_store: @@ -36,24 +34,13 @@ def test_input_messages() -> None: lambda messages: "you said: " + "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage)) ) - store = {} - get_session_history = _get_get_session_history(store=store) + get_session_history = _get_get_session_history() with_history = RunnableWithMessageHistory(runnable, get_session_history) config: RunnableConfig = {"configurable": {"session_id": "1"}} output = with_history.invoke([HumanMessage(content="hello")], config) assert output == "you said: hello" output = with_history.invoke([HumanMessage(content="good bye")], config) assert output == "you said: hello\ngood bye" - assert store == { - "1": ChatMessageHistory( - messages=[ - HumanMessage(content="hello"), - AIMessage(content="you said: hello"), - HumanMessage(content="good bye"), - AIMessage(content="you said: hello\ngood bye"), - ] - ) - } def test_input_dict() -> None: From 00ed122c0a0e7da9d456943e1585ed958c91c8fb Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 8 Dec 2023 22:15:23 -0500 Subject: [PATCH 11/12] x --- libs/core/tests/unit_tests/runnables/test_history.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index 74872bd5dc5c0..0f59b0577f336 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -245,7 +245,7 @@ class RunnableWithChatHistoryInput(BaseModel): def test_using_custom_config_specs() -> None: """Test that we can configure which keys should be passed to the session factory.""" - def _fake_llm(input) -> List[BaseMessage]: + def _fake_llm(input: dict) -> List[BaseMessage]: messages = input["messages"] return [ AIMessage( From a1dfd34d9d53ee32f9a0073c3e7abd23f0342d42 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Sun, 10 Dec 2023 22:25:46 -0500 Subject: [PATCH 12/12] fix mypy errors --- libs/core/tests/unit_tests/runnables/test_history.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index 0f59b0577f336..1c6923f3b18ac 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Sequence, Union +from typing import Any, Callable, Dict, List, Sequence, Union from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.pydantic_v1 import BaseModel @@ -131,7 +131,7 @@ def test_output_messages() -> None: ) get_session_history = _get_get_session_history() with_history = RunnableWithMessageHistory( - runnable, + runnable, # type: ignore get_session_history, input_messages_key="input", history_messages_key="history", @@ -245,7 +245,7 @@ class RunnableWithChatHistoryInput(BaseModel): def test_using_custom_config_specs() -> None: """Test that we can configure which keys should be passed to the session factory.""" - def _fake_llm(input: dict) -> List[BaseMessage]: + def _fake_llm(input: Dict[str, Any]) -> List[BaseMessage]: messages = input["messages"] return [ AIMessage( @@ -265,7 +265,7 @@ def get_session_history(user_id: str, conversation_id: str) -> ChatMessageHistor return store[(user_id, conversation_id)] with_message_history = RunnableWithMessageHistory( - runnable, + runnable, # type: ignore get_session_history=get_session_history, input_messages_key="messages", history_messages_key="history",