Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update RunnableWithMessageHistory #14351

Merged
merged 14 commits into from
Dec 12, 2023
120 changes: 95 additions & 25 deletions libs/core/langchain_core/runnables/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Type,
Union,
)

from langchain_core.chat_history import BaseChatMessageHistory
Expand All @@ -28,8 +25,36 @@
from langchain_core.runnables.config import RunnableConfig
from langchain_core.tracers.schemas import Run

import inspect
from typing import Callable, Dict, Union

MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]]
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
GetSessionHistoryCallable = Union[Callable[..., BaseChatMessageHistory]]
eyurtsev marked this conversation as resolved.
Show resolved Hide resolved


def check_callable_argument_type(callable_: GetSessionHistoryCallable) -> str:
eyurtsev marked this conversation as resolved.
Show resolved Hide resolved
"""Check if the callable accepts a single argument of type 'str' or 'Dict'."""
sig = inspect.signature(callable_)

if len(sig.parameters) == 1:
param_name, param = sig.parameters.popitem()
if param.annotation is str:
return (
f"The callable accepts a single argument of type 'str' "
f"named '{param_name}'."
)
elif param.annotation is Dict:
return (
f"The callable accepts a single argument of type 'Dict' "
f"named '{param_name}'."
)
else:
return (
"The callable accepts a single argument, but its type is "
"not 'str' or 'Dict'."
)
else:
return "The callable does not accept a single argument."


class RunnableWithMessageHistory(RunnableBindingBase):
Expand All @@ -38,7 +63,9 @@ class RunnableWithMessageHistory(RunnableBindingBase):
Base runnable must have inputs and outputs that can be converted to a list of
BaseMessages.

RunnableWithMessageHistory must always be called with a config that contains session_id, e.g.:
RunnableWithMessageHistory must always be called with a config that contains
session_id, e.g.:

``{"configurable": {"session_id": "<SESSION_ID>"}}``

Example (dict input):
Expand Down Expand Up @@ -84,6 +111,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
input_messages_key: Optional[str] = None
output_messages_key: Optional[str] = None
history_messages_key: Optional[str] = None
session_history_config_specs: Sequence[ConfigurableFieldSpec] = None

def __init__(
self,
Expand All @@ -96,6 +124,7 @@ def __init__(
input_messages_key: Optional[str] = None,
output_messages_key: Optional[str] = None,
history_messages_key: Optional[str] = None,
session_history_config_specs: Optional[Sequence[ConfigurableFieldSpec]] = None,
eyurtsev marked this conversation as resolved.
Show resolved Hide resolved
**kwargs: Any,
) -> None:
"""Initialize RunnableWithMessageHistory.
Expand Down Expand Up @@ -135,6 +164,8 @@ def get_session_history(
as output.
history_messages_key: Must be specified if the base runnable accepts a dict
as input and expects a separate key for historical messages.
custom_config_specs: Configure fields that should be passed to the chat
eyurtsev marked this conversation as resolved.
Show resolved Hide resolved
history factory. See ``ConfigurableFieldSpec`` for more details.
**kwargs: Arbitrary additional kwargs to pass to parent class
``RunnableBindingBase`` init.
""" # noqa: E501
Expand All @@ -149,29 +180,36 @@ def get_session_history(
bound = (
history_chain | runnable.with_listeners(on_end=self._exit_history)
).with_config(run_name="RunnableWithMessageHistory")

if session_history_config_specs:
_config_specs = session_history_config_specs
else:
# If not provided, then we'll use the default session_id field
_config_specs = [
ConfigurableFieldSpec(
id="session_id",
annotation=str,
name="Session ID",
description="Unique identifier for a session.",
default="",
is_shared=True,
),
]

super().__init__(
get_session_history=get_session_history,
input_messages_key=input_messages_key,
output_messages_key=output_messages_key,
bound=bound,
history_messages_key=history_messages_key,
session_history_config_specs=_config_specs,
**kwargs,
)

@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
super().config_specs
+ [
ConfigurableFieldSpec(
id="session_id",
annotation=str,
name="Session ID",
description="Unique identifier for a session.",
default="",
is_shared=True,
),
]
super().config_specs + list(self.session_history_config_specs)
)

def get_input_schema(
Expand Down Expand Up @@ -274,16 +312,48 @@ def _exit_history(self, run: Run, config: RunnableConfig) -> None:

def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
config = super()._merge_configs(*configs)
# extract session_id
if "session_id" not in config.get("configurable", {}):
expected_keys = [
field_spec.id for field_spec in self.session_history_config_specs
]

configurable = config.get("configurable", {})

missing_keys = set(expected_keys) - set(configurable.keys())

if missing_keys:
example_input = {self.input_messages_key: "foo"}
example_config = {"configurable": {"session_id": "123"}}
example_configurable = {
missing_key: "[your-value-here]" for missing_key in missing_keys
}
example_config = {"configurable": example_configurable}
raise ValueError(
"session_id_id is required."
" Pass it in as part of the config argument to .invoke() or .stream()"
f"\neg. chain.invoke({example_input}, {example_config})"
f"Missing keys {sorted(missing_keys)} in config['configurable'] "
f"Expected keys are {sorted(expected_keys)}."
f"When using via .invoke() or .stream(), pass in a config; "
f"e.g., chain.invoke({example_input}, {example_config})"
)

parameter_names = _get_parameter_names(self.get_session_history)

if len(parameter_names) == 1 and len(expected_keys) == 1:
# If arity = 1, then invoke function by positional arguments
message_history = self.get_session_history(configurable[expected_keys[0]])
else:
# otherwise verify that names of keys patch and invoke by named arguments
if set(expected_keys) != set(parameter_names):
raise ValueError(
f"Expected keys {sorted(expected_keys)} do not match parameter "
f"names {sorted(parameter_names)} of get_session_history."
)

message_history = self.get_session_history(
**{key: configurable[key] for key in expected_keys}
)
# attach message_history
session_id = config["configurable"]["session_id"]
config["configurable"]["message_history"] = self.get_session_history(session_id)
config["configurable"]["message_history"] = message_history
return config


def _get_parameter_names(callable_: GetSessionHistoryCallable) -> List[str]:
"""Get the parameter names of the callable."""
sig = inspect.signature(callable_)
return list(sig.parameters.keys())
Loading