[core/minor] Runnables: Implement a context api (langchain-ai#14046)

Co-authored-by: Brace Sproul <[email protected]>
2 people authored and aymeric-roucher committed Dec 11, 2023
1 parent ec7580e commit 99de0be
Showing 7 changed files with 811 additions and 16 deletions.
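For orientation, here is a minimal usage sketch of the context API this commit introduces. It is an illustration only: the "question" key, the lambda step, and the expected output are made up for the example, and it assumes (not shown in this diff) that `Context.setter(key)` passes its input through while recording it and that `Context.getter(key)` reads that value back later in the same sequence.

```python
from langchain_core.runnables import Context, RunnableLambda, RunnablePassthrough

# Remember the raw input under an arbitrary key ("question" is illustrative),
# transform it, then recover the original value later in the same sequence.
chain = (
    Context.setter("question")                   # store the incoming value, pass it through
    | RunnableLambda(lambda q: q.upper())        # any intermediate step
    | {
        "answer": RunnablePassthrough(),         # output of the previous step
        "question": Context.getter("question"),  # original input, read back from context
    }
)

print(chain.invoke("what is a context api?"))
# Expected, if the assumptions above hold:
# {'answer': 'WHAT IS A CONTEXT API?', 'question': 'what is a context api?'}
```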
2 changes: 2 additions & 0 deletions libs/core/langchain_core/runnables/__init__.py
@@ -30,6 +30,7 @@
get_config_list,
patch_config,
)
from langchain_core.runnables.context import Context
from langchain_core.runnables.fallbacks import RunnableWithFallbacks
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.router import RouterInput, RouterRunnable
@@ -47,6 +48,7 @@
"ConfigurableField",
"ConfigurableFieldSingleOption",
"ConfigurableFieldMultiOption",
"Context",
"patch_config",
"RouterInput",
"RouterRunnable",
82 changes: 68 additions & 14 deletions libs/core/langchain_core/runnables/base.py
@@ -7,7 +7,7 @@
from concurrent.futures import FIRST_COMPLETED, wait
from copy import deepcopy
from functools import partial, wraps
from itertools import tee
from itertools import groupby, tee
from operator import itemgetter
from typing import (
TYPE_CHECKING,
@@ -22,6 +22,7 @@
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
@@ -1401,9 +1402,46 @@ def get_output_schema(

@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec for step in self.steps for spec in step.config_specs
from langchain_core.runnables.context import CONTEXT_CONFIG_PREFIX, _key_from_id

# get all specs
all_specs = [
(spec, idx)
for idx, step in enumerate(self.steps)
for spec in step.config_specs
]
# calculate context dependencies
specs_by_pos = groupby(
[tup for tup in all_specs if tup[0].id.startswith(CONTEXT_CONFIG_PREFIX)],
lambda x: x[1],
)
next_deps: Set[str] = set()
deps_by_pos: Dict[int, Set[str]] = {}
for pos, specs in specs_by_pos:
deps_by_pos[pos] = next_deps
next_deps = next_deps | {spec[0].id for spec in specs}
# assign context dependencies
for pos, (spec, idx) in enumerate(all_specs):
if spec.id.startswith(CONTEXT_CONFIG_PREFIX):
all_specs[pos] = (
ConfigurableFieldSpec(
id=spec.id,
annotation=spec.annotation,
name=spec.name,
default=spec.default,
description=spec.description,
is_shared=spec.is_shared,
dependencies=[
d
for d in deps_by_pos[idx]
if _key_from_id(d) != _key_from_id(spec.id)
]
+ (spec.dependencies or []),
),
idx,
)

return get_unique_config_specs(spec for spec, _ in all_specs)

def __repr__(self) -> str:
return "\n| ".join(
@@ -1456,8 +1494,10 @@ def __ror__(
)

def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
# setup callbacks
config = ensure_config(config)
from langchain_core.runnables.context import config_with_context

# setup callbacks and context
config = config_with_context(ensure_config(config), self.steps)
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
@@ -1488,8 +1528,10 @@ async def ainvoke(
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
# setup callbacks
config = ensure_config(config)
from langchain_core.runnables.context import aconfig_with_context

# setup callbacks and context
config = aconfig_with_context(ensure_config(config), self.steps)
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
@@ -1523,12 +1565,16 @@ def batch(
**kwargs: Optional[Any],
) -> List[Output]:
from langchain_core.callbacks.manager import CallbackManager
from langchain_core.runnables.context import config_with_context

if not inputs:
return []

# setup callbacks
configs = get_config_list(config, len(inputs))
# setup callbacks and context
configs = [
config_with_context(c, self.steps)
for c in get_config_list(config, len(inputs))
]
callback_managers = [
CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
@@ -1641,15 +1687,17 @@ async def abatch(
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain_core.callbacks.manager import (
AsyncCallbackManager,
)
from langchain_core.callbacks.manager import AsyncCallbackManager
from langchain_core.runnables.context import aconfig_with_context

if not inputs:
return []

# setup callbacks
configs = get_config_list(config, len(inputs))
# setup callbacks and context
configs = [
aconfig_with_context(c, self.steps)
for c in get_config_list(config, len(inputs))
]
callback_managers = [
AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
@@ -1763,7 +1811,10 @@ def _transform(
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
) -> Iterator[Output]:
from langchain_core.runnables.context import config_with_context

steps = [self.first] + self.middle + [self.last]
config = config_with_context(config, self.steps)

# transform the input stream of each step with the next
# steps that don't natively support transforming an input stream will
@@ -1787,7 +1838,10 @@ async def _atransform(
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
) -> AsyncIterator[Output]:
from langchain_core.runnables.context import aconfig_with_context

steps = [self.first] + self.middle + [self.last]
config = aconfig_with_context(config, self.steps)

# stream the last steps
# transform the input stream of each step with the next
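The `config_specs` change above is the heart of the wiring: context specs contributed by a step may only depend on context specs contributed by earlier steps of the sequence. Below is a simplified, self-contained sketch of that bookkeeping; the spec ids, the example prefix value, and the step indices are made-up stand-ins for what the private helpers in `langchain_core.runnables.context` actually produce.

```python
from itertools import groupby

CONTEXT_CONFIG_PREFIX = "__context__/"  # assumed id prefix, mirroring the constant imported above

# (spec_id, step_index) pairs, in step order, as collected from a sequence's steps.
all_specs = [
    ("__context__/question/set", 0),
    ("__context__/question/get", 2),
]

# For each step position, record the context spec ids seen at *earlier* positions.
deps_by_pos: dict[int, set[str]] = {}
earlier: set[str] = set()
for pos, group in groupby(
    (t for t in all_specs if t[0].startswith(CONTEXT_CONFIG_PREFIX)),
    key=lambda t: t[1],
):
    deps_by_pos[pos] = set(earlier)
    earlier |= {spec_id for spec_id, _ in group}

print(deps_by_pos)  # {0: set(), 2: {'__context__/question/set'}}
```

The loop in `config_specs` above then attaches `deps_by_pos[idx]` (excluding ids for the spec's own key) to each context spec's `dependencies`, so the relative ordering of context setters and getters across the sequence is captured in the specs themselves.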
13 changes: 12 additions & 1 deletion libs/core/langchain_core/runnables/branch.py
@@ -26,6 +26,10 @@
get_callback_manager_for_config,
patch_config,
)
from langchain_core.runnables.context import (
CONTEXT_CONFIG_PREFIX,
CONTEXT_CONFIG_SUFFIX_SET,
)
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
Input,
@@ -148,7 +152,7 @@ def get_input_schema(

@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
specs = get_unique_config_specs(
spec
for step in (
[self.default]
@@ -157,6 +161,13 @@ def config_specs(self) -> List[ConfigurableFieldSpec]:
)
for spec in step.config_specs
)
if any(
s.id.startswith(CONTEXT_CONFIG_PREFIX)
and s.id.endswith(CONTEXT_CONFIG_SUFFIX_SET)
for s in specs
):
raise ValueError("RunnableBranch cannot contain context setters.")
return specs

def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
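The new guard in `RunnableBranch.config_specs` rejects context setters, presumably because only one branch actually runs, so a setter buried inside a branch could leave downstream getters without a value. A hypothetical example of tripping the check (the key name and lambdas are placeholders):

```python
from langchain_core.runnables import Context, RunnableBranch, RunnableLambda

branch = RunnableBranch(
    # (condition, runnable) pair: the context setter makes this branch invalid
    (lambda x: isinstance(x, str), Context.setter("seen") | RunnableLambda(str.upper)),
    RunnableLambda(lambda x: x),  # default branch
)

# Accessing config_specs (directly, or while composing the branch into a larger chain)
# is expected to raise: ValueError("RunnableBranch cannot contain context setters.")
branch.config_specs
```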
(Diffs for the remaining changed files were not loaded.)
