Skip to content

Commit

Permalink
Fix executor
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Dec 29, 2023
1 parent 9bb1fbc commit 4e4b119
Showing 1 changed file with 60 additions and 13 deletions.
73 changes: 60 additions & 13 deletions libs/core/langchain_core/runnables/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import asyncio
from concurrent.futures import Executor, ThreadPoolExecutor
from concurrent.futures import Executor, Future, ThreadPoolExecutor
from contextlib import contextmanager
from contextvars import Context, copy_context
from contextvars import ContextVar, copy_context
from functools import partial
from typing import (
TYPE_CHECKING,
Expand All @@ -12,6 +12,8 @@
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
TypeVar,
Expand Down Expand Up @@ -94,6 +96,11 @@ class RunnableConfig(TypedDict, total=False):
"""


var_child_runnable_config = ContextVar(
"child_runnable_config", default=RunnableConfig()
)


def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
"""Ensure that a config is a dict with all keys present.
Expand All @@ -110,6 +117,10 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
callbacks=None,
recursion_limit=25,
)
if var_config := var_child_runnable_config.get():
empty.update(
cast(RunnableConfig, {k: v for k, v in var_config.items() if v is not None})
)
if config is not None:
empty.update(
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None})
Expand Down Expand Up @@ -391,9 +402,51 @@ def get_async_callback_manager_for_config(
)


def _set_context(context: Context) -> None:
for var, value in context.items():
var.set(value)
P = ParamSpec("P")
T = TypeVar("T")


class ContextThreadPoolExecutor(ThreadPoolExecutor):
"""ThreadPoolExecutor that copies the context to the child thread."""

def submit( # type: ignore[override]
self,
func: Callable[P, T],
*args: P.args,
**kwargs: P.kwargs,
) -> Future[T]:
"""Submit a function to the executor.
Args:
func (Callable[..., T]): The function to submit.
*args (Any): The positional arguments to the function.
**kwargs (Any): The keyword arguments to the function.
Returns:
Future[T]: The future for the function.
"""
return super().submit(
cast(Callable[..., T], partial(copy_context().run, func, *args, **kwargs))
)

def map(
self,
fn: Callable[..., T],
*iterables: Iterable[Any],
timeout: float | None = None,
chunksize: int = 1,
) -> Iterator[T]:
contexts = [copy_context() for _ in range(len(iterables[0]))] # type: ignore[arg-type]

def _wrapped_fn(*args: Any) -> T:
return contexts.pop().run(fn, *args)

return super().map(
_wrapped_fn,
*iterables,
timeout=timeout,
chunksize=chunksize,
)


@contextmanager
Expand All @@ -409,18 +462,12 @@ def get_executor_for_config(
Generator[Executor, None, None]: The executor.
"""
config = config or {}
with ThreadPoolExecutor(
max_workers=config.get("max_concurrency"),
initializer=_set_context,
initargs=(copy_context(),),
with ContextThreadPoolExecutor(
max_workers=config.get("max_concurrency")
) as executor:
yield executor


P = ParamSpec("P")
T = TypeVar("T")


async def run_in_executor(
executor_or_config: Optional[Union[Executor, RunnableConfig]],
func: Callable[P, T],
Expand Down

0 comments on commit 4e4b119

Please sign in to comment.