From 99de0be8d33a92bfd7ba0b84a0c1426921adfcf9 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 6 Dec 2023 15:02:29 -0800 Subject: [PATCH] [core/minor] Runnables: Implement a context api (#14046) --------- Co-authored-by: Brace Sproul --- .../core/langchain_core/runnables/__init__.py | 2 + libs/core/langchain_core/runnables/base.py | 82 +++- libs/core/langchain_core/runnables/branch.py | 13 +- libs/core/langchain_core/runnables/context.py | 313 +++++++++++++ libs/core/langchain_core/runnables/utils.py | 5 +- .../unit_tests/runnables/test_context.py | 411 ++++++++++++++++++ .../unit_tests/runnables/test_imports.py | 1 + 7 files changed, 811 insertions(+), 16 deletions(-) create mode 100644 libs/core/langchain_core/runnables/context.py create mode 100644 libs/core/tests/unit_tests/runnables/test_context.py diff --git a/libs/core/langchain_core/runnables/__init__.py b/libs/core/langchain_core/runnables/__init__.py index b51a94eea3f23..e6ace8ceb2404 100644 --- a/libs/core/langchain_core/runnables/__init__.py +++ b/libs/core/langchain_core/runnables/__init__.py @@ -30,6 +30,7 @@ get_config_list, patch_config, ) +from langchain_core.runnables.context import Context from langchain_core.runnables.fallbacks import RunnableWithFallbacks from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.runnables.router import RouterInput, RouterRunnable @@ -47,6 +48,7 @@ "ConfigurableField", "ConfigurableFieldSingleOption", "ConfigurableFieldMultiOption", + "Context", "patch_config", "RouterInput", "RouterRunnable", diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index d0d1d8d06f0f3..de28dbb83f594 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -7,7 +7,7 @@ from concurrent.futures import FIRST_COMPLETED, wait from copy import deepcopy from functools import partial, wraps -from itertools import tee +from itertools import groupby, tee from operator import itemgetter from typing import ( TYPE_CHECKING, @@ -22,6 +22,7 @@ Mapping, Optional, Sequence, + Set, Tuple, Type, TypeVar, @@ -1401,9 +1402,46 @@ def get_output_schema( @property def config_specs(self) -> List[ConfigurableFieldSpec]: - return get_unique_config_specs( - spec for step in self.steps for spec in step.config_specs + from langchain_core.runnables.context import CONTEXT_CONFIG_PREFIX, _key_from_id + + # get all specs + all_specs = [ + (spec, idx) + for idx, step in enumerate(self.steps) + for spec in step.config_specs + ] + # calculate context dependencies + specs_by_pos = groupby( + [tup for tup in all_specs if tup[0].id.startswith(CONTEXT_CONFIG_PREFIX)], + lambda x: x[1], ) + next_deps: Set[str] = set() + deps_by_pos: Dict[int, Set[str]] = {} + for pos, specs in specs_by_pos: + deps_by_pos[pos] = next_deps + next_deps = next_deps | {spec[0].id for spec in specs} + # assign context dependencies + for pos, (spec, idx) in enumerate(all_specs): + if spec.id.startswith(CONTEXT_CONFIG_PREFIX): + all_specs[pos] = ( + ConfigurableFieldSpec( + id=spec.id, + annotation=spec.annotation, + name=spec.name, + default=spec.default, + description=spec.description, + is_shared=spec.is_shared, + dependencies=[ + d + for d in deps_by_pos[idx] + if _key_from_id(d) != _key_from_id(spec.id) + ] + + (spec.dependencies or []), + ), + idx, + ) + + return get_unique_config_specs(spec for spec, _ in all_specs) def __repr__(self) -> str: return "\n| ".join( @@ -1456,8 +1494,10 @@ def __ror__( ) def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: - # setup callbacks - config = ensure_config(config) + from langchain_core.runnables.context import config_with_context + + # setup callbacks and context + config = config_with_context(ensure_config(config), self.steps) callback_manager = get_callback_manager_for_config(config) # start the root run run_manager = callback_manager.on_chain_start( @@ -1488,8 +1528,10 @@ async def ainvoke( config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Output: - # setup callbacks - config = ensure_config(config) + from langchain_core.runnables.context import aconfig_with_context + + # setup callbacks and context + config = aconfig_with_context(ensure_config(config), self.steps) callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( @@ -1523,12 +1565,16 @@ def batch( **kwargs: Optional[Any], ) -> List[Output]: from langchain_core.callbacks.manager import CallbackManager + from langchain_core.runnables.context import config_with_context if not inputs: return [] - # setup callbacks - configs = get_config_list(config, len(inputs)) + # setup callbacks and context + configs = [ + config_with_context(c, self.steps) + for c in get_config_list(config, len(inputs)) + ] callback_managers = [ CallbackManager.configure( inheritable_callbacks=config.get("callbacks"), @@ -1641,15 +1687,17 @@ async def abatch( return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: - from langchain_core.callbacks.manager import ( - AsyncCallbackManager, - ) + from langchain_core.callbacks.manager import AsyncCallbackManager + from langchain_core.runnables.context import aconfig_with_context if not inputs: return [] - # setup callbacks - configs = get_config_list(config, len(inputs)) + # setup callbacks and context + configs = [ + aconfig_with_context(c, self.steps) + for c in get_config_list(config, len(inputs)) + ] callback_managers = [ AsyncCallbackManager.configure( inheritable_callbacks=config.get("callbacks"), @@ -1763,7 +1811,10 @@ def _transform( run_manager: CallbackManagerForChainRun, config: RunnableConfig, ) -> Iterator[Output]: + from langchain_core.runnables.context import config_with_context + steps = [self.first] + self.middle + [self.last] + config = config_with_context(config, self.steps) # transform the input stream of each step with the next # steps that don't natively support transforming an input stream will @@ -1787,7 +1838,10 @@ async def _atransform( run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, ) -> AsyncIterator[Output]: + from langchain_core.runnables.context import aconfig_with_context + steps = [self.first] + self.middle + [self.last] + config = aconfig_with_context(config, self.steps) # stream the last steps # transform the input stream of each step with the next diff --git a/libs/core/langchain_core/runnables/branch.py b/libs/core/langchain_core/runnables/branch.py index 96e9685dc6d5e..5f7c1b009dc10 100644 --- a/libs/core/langchain_core/runnables/branch.py +++ b/libs/core/langchain_core/runnables/branch.py @@ -26,6 +26,10 @@ get_callback_manager_for_config, patch_config, ) +from langchain_core.runnables.context import ( + CONTEXT_CONFIG_PREFIX, + CONTEXT_CONFIG_SUFFIX_SET, +) from langchain_core.runnables.utils import ( ConfigurableFieldSpec, Input, @@ -148,7 +152,7 @@ def get_input_schema( @property def config_specs(self) -> List[ConfigurableFieldSpec]: - return get_unique_config_specs( + specs = get_unique_config_specs( spec for step in ( [self.default] @@ -157,6 +161,13 @@ def config_specs(self) -> List[ConfigurableFieldSpec]: ) for spec in step.config_specs ) + if any( + s.id.startswith(CONTEXT_CONFIG_PREFIX) + and s.id.endswith(CONTEXT_CONFIG_SUFFIX_SET) + for s in specs + ): + raise ValueError("RunnableBranch cannot contain context setters.") + return specs def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any diff --git a/libs/core/langchain_core/runnables/context.py b/libs/core/langchain_core/runnables/context.py new file mode 100644 index 0000000000000..c29e88b2977b3 --- /dev/null +++ b/libs/core/langchain_core/runnables/context.py @@ -0,0 +1,313 @@ +import asyncio +import threading +from collections import defaultdict +from functools import partial +from itertools import groupby +from typing import ( + Any, + Awaitable, + Callable, + DefaultDict, + Dict, + List, + Mapping, + Optional, + Type, + TypeVar, + Union, +) + +from langchain_core.runnables.base import ( + Runnable, + RunnableSerializable, + coerce_to_runnable, +) +from langchain_core.runnables.config import RunnableConfig, patch_config +from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output + +T = TypeVar("T") +Values = Dict[Union[asyncio.Event, threading.Event], Any] +CONTEXT_CONFIG_PREFIX = "__context__/" +CONTEXT_CONFIG_SUFFIX_GET = "/get" +CONTEXT_CONFIG_SUFFIX_SET = "/set" + + +async def _asetter(done: asyncio.Event, values: Values, value: T) -> T: + values[done] = value + done.set() + return value + + +async def _agetter(done: asyncio.Event, values: Values) -> Any: + await done.wait() + return values[done] + + +def _setter(done: threading.Event, values: Values, value: T) -> T: + values[done] = value + done.set() + return value + + +def _getter(done: threading.Event, values: Values) -> Any: + done.wait() + return values[done] + + +def _key_from_id(id_: str) -> str: + wout_prefix = id_.split(CONTEXT_CONFIG_PREFIX, maxsplit=1)[1] + if wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_GET): + return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_GET)] + elif wout_prefix.endswith(CONTEXT_CONFIG_SUFFIX_SET): + return wout_prefix[: -len(CONTEXT_CONFIG_SUFFIX_SET)] + else: + raise ValueError(f"Invalid context config id {id_}") + + +def _config_with_context( + config: RunnableConfig, + steps: List[Runnable], + setter: Callable, + getter: Callable, + event_cls: Union[Type[threading.Event], Type[asyncio.Event]], +) -> RunnableConfig: + if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})): + return config + + context_specs = [ + (spec, i) + for i, step in enumerate(steps) + for spec in step.config_specs + if spec.id.startswith(CONTEXT_CONFIG_PREFIX) + ] + grouped_by_key = { + key: list(group) + for key, group in groupby( + sorted(context_specs, key=lambda s: s[0].id), + key=lambda s: _key_from_id(s[0].id), + ) + } + deps_by_key = { + key: set( + _key_from_id(dep) for spec in group for dep in (spec[0].dependencies or []) + ) + for key, group in grouped_by_key.items() + } + + values: Values = {} + events: DefaultDict[str, Union[asyncio.Event, threading.Event]] = defaultdict( + event_cls + ) + context_funcs: Dict[str, Callable[[], Any]] = {} + for key, group in grouped_by_key.items(): + getters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_GET)] + setters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_SET)] + + for dep in deps_by_key[key]: + if key in deps_by_key[dep]: + raise ValueError( + f"Deadlock detected between context keys {key} and {dep}" + ) + if len(getters) < 1: + raise ValueError(f"Expected at least one getter for context key {key}") + if len(setters) != 1: + raise ValueError(f"Expected exactly one setter for context key {key}") + setter_idx = setters[0][1] + if any(getter_idx < setter_idx for _, getter_idx in getters): + raise ValueError( + f"Context setter for key {key} must be defined after all getters." + ) + + context_funcs[getters[0][0].id] = partial(getter, events[key], values) + context_funcs[setters[0][0].id] = partial(setter, events[key], values) + + return patch_config(config, configurable=context_funcs) + + +def aconfig_with_context( + config: RunnableConfig, + steps: List[Runnable], +) -> RunnableConfig: + return _config_with_context(config, steps, _asetter, _agetter, asyncio.Event) + + +def config_with_context( + config: RunnableConfig, + steps: List[Runnable], +) -> RunnableConfig: + return _config_with_context(config, steps, _setter, _getter, threading.Event) + + +class ContextGet(RunnableSerializable): + prefix: str = "" + + key: Union[str, List[str]] + + @property + def ids(self) -> List[str]: + prefix = self.prefix + "/" if self.prefix else "" + keys = self.key if isinstance(self.key, list) else [self.key] + return [ + f"{CONTEXT_CONFIG_PREFIX}{prefix}{k}{CONTEXT_CONFIG_SUFFIX_GET}" + for k in keys + ] + + @property + def config_specs(self) -> List[ConfigurableFieldSpec]: + return super().config_specs + [ + ConfigurableFieldSpec( + id=id_, + annotation=Callable[[], Any], + ) + for id_ in self.ids + ] + + def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any: + config = config or {} + configurable = config.get("configurable", {}) + if isinstance(self.key, list): + return {key: configurable[id_]() for key, id_ in zip(self.key, self.ids)} + else: + return configurable[self.ids[0]]() + + async def ainvoke( + self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Any: + config = config or {} + configurable = config.get("configurable", {}) + if isinstance(self.key, list): + values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids)) + return {key: value for key, value in zip(self.key, values)} + else: + return await configurable[self.ids[0]]() + + +SetValue = Union[ + Runnable[Input, Output], + Callable[[Input], Output], + Callable[[Input], Awaitable[Output]], + Any, +] + + +def _coerce_set_value(value: SetValue) -> Runnable[Input, Output]: + if not isinstance(value, Runnable) and not callable(value): + return coerce_to_runnable(lambda _: value) + return coerce_to_runnable(value) + + +class ContextSet(RunnableSerializable): + prefix: str = "" + + keys: Mapping[str, Optional[Runnable]] + + class Config: + arbitrary_types_allowed = True + + def __init__( + self, + key: Optional[str] = None, + value: Optional[SetValue] = None, + prefix: str = "", + **kwargs: SetValue, + ): + if key is not None: + kwargs[key] = value + super().__init__( + keys={ + k: _coerce_set_value(v) if v is not None else None + for k, v in kwargs.items() + }, + prefix=prefix, + ) + + @property + def ids(self) -> List[str]: + prefix = self.prefix + "/" if self.prefix else "" + return [ + f"{CONTEXT_CONFIG_PREFIX}{prefix}{key}{CONTEXT_CONFIG_SUFFIX_SET}" + for key in self.keys + ] + + @property + def config_specs(self) -> List[ConfigurableFieldSpec]: + mapper_config_specs = [ + s + for mapper in self.keys.values() + if mapper is not None + for s in mapper.config_specs + ] + for spec in mapper_config_specs: + if spec.id.endswith(CONTEXT_CONFIG_SUFFIX_GET): + getter_key = spec.id.split("/")[1] + if getter_key in self.keys: + raise ValueError( + f"Circular reference in context setter for key {getter_key}" + ) + return super().config_specs + [ + ConfigurableFieldSpec( + id=id_, + annotation=Callable[[], Any], + ) + for id_ in self.ids + ] + + def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any: + config = config or {} + configurable = config.get("configurable", {}) + for id_, mapper in zip(self.ids, self.keys.values()): + if mapper is not None: + configurable[id_](mapper.invoke(input, config)) + else: + configurable[id_](input) + return input + + async def ainvoke( + self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Any: + config = config or {} + configurable = config.get("configurable", {}) + for id_, mapper in zip(self.ids, self.keys.values()): + if mapper is not None: + await configurable[id_](await mapper.ainvoke(input, config)) + else: + await configurable[id_](input) + return input + + +class Context: + @staticmethod + def create_scope(scope: str, /) -> "PrefixContext": + return PrefixContext(prefix=scope) + + @staticmethod + def getter(key: Union[str, List[str]], /) -> ContextGet: + return ContextGet(key=key) + + @staticmethod + def setter( + _key: Optional[str] = None, + _value: Optional[SetValue] = None, + /, + **kwargs: SetValue, + ) -> ContextSet: + return ContextSet(_key, _value, prefix="", **kwargs) + + +class PrefixContext: + prefix: str = "" + + def __init__(self, prefix: str = ""): + self.prefix = prefix + + def getter(self, key: Union[str, List[str]], /) -> ContextGet: + return ContextGet(key=key, prefix=self.prefix) + + def setter( + self, + _key: Optional[str] = None, + _value: Optional[SetValue] = None, + /, + **kwargs: SetValue, + ) -> ContextSet: + return ContextSet(_key, _value, prefix=self.prefix, **kwargs) diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index cd7652bff3319..7570caeca251f 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -308,13 +308,16 @@ class ConfigurableFieldSpec(NamedTuple): description: Optional[str] = None default: Any = None is_shared: bool = False + dependencies: Optional[List[str]] = None def get_unique_config_specs( specs: Iterable[ConfigurableFieldSpec], ) -> List[ConfigurableFieldSpec]: """Get the unique config specs from a sequence of config specs.""" - grouped = groupby(sorted(specs, key=lambda s: s.id), lambda s: s.id) + grouped = groupby( + sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id + ) unique: List[ConfigurableFieldSpec] = [] for id, dupes in grouped: first = next(dupes) diff --git a/libs/core/tests/unit_tests/runnables/test_context.py b/libs/core/tests/unit_tests/runnables/test_context.py new file mode 100644 index 0000000000000..21fe1992d768d --- /dev/null +++ b/libs/core/tests/unit_tests/runnables/test_context.py @@ -0,0 +1,411 @@ +from typing import Any, Callable, List, NamedTuple, Union + +import pytest + +from langchain_core.output_parsers.string import StrOutputParser +from langchain_core.prompt_values import StringPromptValue +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.runnables.base import Runnable, RunnableLambda +from langchain_core.runnables.context import Context +from langchain_core.runnables.passthrough import RunnablePassthrough +from langchain_core.runnables.utils import aadd, add +from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM + + +class TestCase(NamedTuple): + input: Any + output: Any + + +def seq_naive_rag() -> Runnable: + context = [ + "Hi there!", + "How are you?", + "What's your name?", + ] + + retriever = RunnableLambda(lambda x: context) + prompt = PromptTemplate.from_template("{context} {question}") + llm = FakeListLLM(responses=["hello"]) + + return ( + Context.setter("input") + | { + "context": retriever | Context.setter("context"), + "question": RunnablePassthrough(), + } + | prompt + | llm + | StrOutputParser() + | { + "result": RunnablePassthrough(), + "context": Context.getter("context"), + "input": Context.getter("input"), + } + ) + + +def seq_naive_rag_alt() -> Runnable: + context = [ + "Hi there!", + "How are you?", + "What's your name?", + ] + + retriever = RunnableLambda(lambda x: context) + prompt = PromptTemplate.from_template("{context} {question}") + llm = FakeListLLM(responses=["hello"]) + + return ( + Context.setter("input") + | { + "context": retriever | Context.setter("context"), + "question": RunnablePassthrough(), + } + | prompt + | llm + | StrOutputParser() + | Context.setter("result") + | Context.getter(["context", "input", "result"]) + ) + + +def seq_naive_rag_scoped() -> Runnable: + context = [ + "Hi there!", + "How are you?", + "What's your name?", + ] + + retriever = RunnableLambda(lambda x: context) + prompt = PromptTemplate.from_template("{context} {question}") + llm = FakeListLLM(responses=["hello"]) + + scoped = Context.create_scope("a_scope") + + return ( + Context.setter("input") + | { + "context": retriever | Context.setter("context"), + "question": RunnablePassthrough(), + "scoped": scoped.setter("context") | scoped.getter("context"), + } + | prompt + | llm + | StrOutputParser() + | Context.setter("result") + | Context.getter(["context", "input", "result"]) + ) + + +test_cases = [ + ( + Context.setter("foo") | Context.getter("foo"), + ( + TestCase("foo", "foo"), + TestCase("bar", "bar"), + ), + ), + ( + Context.setter("input") | {"bar": Context.getter("input")}, + ( + TestCase("foo", {"bar": "foo"}), + TestCase("bar", {"bar": "bar"}), + ), + ), + ( + {"bar": Context.setter("input")} | Context.getter("input"), + ( + TestCase("foo", "foo"), + TestCase("bar", "bar"), + ), + ), + ( + ( + PromptTemplate.from_template("{foo} {bar}") + | Context.setter("prompt") + | FakeListLLM(responses=["hello"]) + | StrOutputParser() + | { + "response": RunnablePassthrough(), + "prompt": Context.getter("prompt"), + } + ), + ( + TestCase( + {"foo": "foo", "bar": "bar"}, + {"response": "hello", "prompt": StringPromptValue(text="foo bar")}, + ), + TestCase( + {"foo": "bar", "bar": "foo"}, + {"response": "hello", "prompt": StringPromptValue(text="bar foo")}, + ), + ), + ), + ( + ( + PromptTemplate.from_template("{foo} {bar}") + | Context.setter("prompt", prompt_str=lambda x: x.to_string()) + | FakeListLLM(responses=["hello"]) + | StrOutputParser() + | { + "response": RunnablePassthrough(), + "prompt": Context.getter("prompt"), + "prompt_str": Context.getter("prompt_str"), + } + ), + ( + TestCase( + {"foo": "foo", "bar": "bar"}, + { + "response": "hello", + "prompt": StringPromptValue(text="foo bar"), + "prompt_str": "foo bar", + }, + ), + TestCase( + {"foo": "bar", "bar": "foo"}, + { + "response": "hello", + "prompt": StringPromptValue(text="bar foo"), + "prompt_str": "bar foo", + }, + ), + ), + ), + ( + ( + PromptTemplate.from_template("{foo} {bar}") + | Context.setter(prompt_str=lambda x: x.to_string()) + | FakeListLLM(responses=["hello"]) + | StrOutputParser() + | { + "response": RunnablePassthrough(), + "prompt_str": Context.getter("prompt_str"), + } + ), + ( + TestCase( + {"foo": "foo", "bar": "bar"}, + {"response": "hello", "prompt_str": "foo bar"}, + ), + TestCase( + {"foo": "bar", "bar": "foo"}, + {"response": "hello", "prompt_str": "bar foo"}, + ), + ), + ), + ( + ( + PromptTemplate.from_template("{foo} {bar}") + | Context.setter("prompt_str", lambda x: x.to_string()) + | FakeListLLM(responses=["hello"]) + | StrOutputParser() + | { + "response": RunnablePassthrough(), + "prompt_str": Context.getter("prompt_str"), + } + ), + ( + TestCase( + {"foo": "foo", "bar": "bar"}, + {"response": "hello", "prompt_str": "foo bar"}, + ), + TestCase( + {"foo": "bar", "bar": "foo"}, + {"response": "hello", "prompt_str": "bar foo"}, + ), + ), + ), + ( + ( + PromptTemplate.from_template("{foo} {bar}") + | Context.setter("prompt") + | FakeStreamingListLLM(responses=["hello"]) + | StrOutputParser() + | { + "response": RunnablePassthrough(), + "prompt": Context.getter("prompt"), + } + ), + ( + TestCase( + {"foo": "foo", "bar": "bar"}, + {"response": "hello", "prompt": StringPromptValue(text="foo bar")}, + ), + TestCase( + {"foo": "bar", "bar": "foo"}, + {"response": "hello", "prompt": StringPromptValue(text="bar foo")}, + ), + ), + ), + ( + seq_naive_rag, + ( + TestCase( + "What up", + { + "result": "hello", + "context": [ + "Hi there!", + "How are you?", + "What's your name?", + ], + "input": "What up", + }, + ), + TestCase( + "Howdy", + { + "result": "hello", + "context": [ + "Hi there!", + "How are you?", + "What's your name?", + ], + "input": "Howdy", + }, + ), + ), + ), + ( + seq_naive_rag_alt, + ( + TestCase( + "What up", + { + "result": "hello", + "context": [ + "Hi there!", + "How are you?", + "What's your name?", + ], + "input": "What up", + }, + ), + TestCase( + "Howdy", + { + "result": "hello", + "context": [ + "Hi there!", + "How are you?", + "What's your name?", + ], + "input": "Howdy", + }, + ), + ), + ), + ( + seq_naive_rag_scoped, + ( + TestCase( + "What up", + { + "result": "hello", + "context": [ + "Hi there!", + "How are you?", + "What's your name?", + ], + "input": "What up", + }, + ), + TestCase( + "Howdy", + { + "result": "hello", + "context": [ + "Hi there!", + "How are you?", + "What's your name?", + ], + "input": "Howdy", + }, + ), + ), + ), +] + + +@pytest.mark.parametrize("runnable, cases", test_cases) +async def test_context_runnables( + runnable: Union[Runnable, Callable[[], Runnable]], cases: List[TestCase] +) -> None: + runnable = runnable if isinstance(runnable, Runnable) else runnable() + assert runnable.invoke(cases[0].input) == cases[0].output + assert await runnable.ainvoke(cases[1].input) == cases[1].output + assert runnable.batch([case.input for case in cases]) == [ + case.output for case in cases + ] + assert await runnable.abatch([case.input for case in cases]) == [ + case.output for case in cases + ] + assert add(runnable.stream(cases[0].input)) == cases[0].output + assert await aadd(runnable.astream(cases[1].input)) == cases[1].output + + +def test_runnable_context_seq_key_not_found() -> None: + seq: Runnable = {"bar": Context.setter("input")} | Context.getter("foo") + + with pytest.raises(ValueError): + seq.invoke("foo") + + +def test_runnable_context_seq_key_order() -> None: + seq: Runnable = {"bar": Context.getter("foo")} | Context.setter("foo") + + with pytest.raises(ValueError): + seq.invoke("foo") + + +def test_runnable_context_deadlock() -> None: + seq: Runnable = { + "bar": Context.setter("input") | Context.getter("foo"), + "foo": Context.setter("foo") | Context.getter("input"), + } | RunnablePassthrough() + + with pytest.raises(ValueError): + seq.invoke("foo") + + +def test_runnable_context_seq_key_circular_ref() -> None: + seq: Runnable = { + "bar": Context.setter(input=Context.getter("input")) + } | Context.getter("foo") + + with pytest.raises(ValueError): + seq.invoke("foo") + + +async def test_runnable_seq_streaming_chunks() -> None: + chain: Runnable = ( + PromptTemplate.from_template("{foo} {bar}") + | Context.setter("prompt") + | FakeStreamingListLLM(responses=["hello"]) + | StrOutputParser() + | { + "response": RunnablePassthrough(), + "prompt": Context.getter("prompt"), + } + ) + + chunks = [c for c in chain.stream({"foo": "foo", "bar": "bar"})] + achunks = [c async for c in chain.astream({"foo": "foo", "bar": "bar"})] + for c in chunks: + assert c in achunks + for c in achunks: + assert c in chunks + + assert len(chunks) == 6 + assert [c for c in chunks if c.get("response")] == [ + {"response": "h"}, + {"response": "e"}, + {"response": "l"}, + {"response": "l"}, + {"response": "o"}, + ] + assert [c for c in chunks if c.get("prompt")] == [ + {"prompt": StringPromptValue(text="foo bar")}, + ] diff --git a/libs/core/tests/unit_tests/runnables/test_imports.py b/libs/core/tests/unit_tests/runnables/test_imports.py index 935571ed12a77..0a2eb92edc27c 100644 --- a/libs/core/tests/unit_tests/runnables/test_imports.py +++ b/libs/core/tests/unit_tests/runnables/test_imports.py @@ -2,6 +2,7 @@ EXPECTED_ALL = [ "AddableDict", + "Context", "ConfigurableField", "ConfigurableFieldSingleOption", "ConfigurableFieldMultiOption",