diff --git a/.github/workflows/_test_langgraph.yml b/.github/workflows/_test_langgraph.yml index 2708d0f238..ba1cbd08e3 100644 --- a/.github/workflows/_test_langgraph.yml +++ b/.github/workflows/_test_langgraph.yml @@ -17,21 +17,11 @@ jobs: - "3.11" - "3.12" - "3.13" - core-version: - - "latest" - ff-send-v2: - - "false" - include: - - python-version: "3.11" - core-version: ">=0.2.42,<0.3.0" - - python-version: "3.11" - core-version: "latest" - ff-send-v2: "true" defaults: run: working-directory: libs/langgraph - name: "test #${{ matrix.python-version }} (langchain-core: ${{ matrix.core-version }}, ff-send-v2: ${{ matrix.ff-send-v2 }})" + name: "test #${{ matrix.python-version }}" steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }} @@ -51,14 +41,9 @@ jobs: shell: bash run: | poetry install --with dev - if [ "${{ matrix.core-version }}" != "latest" ]; then - poetry run pip install "langchain-core${{ matrix.core-version }}" - fi - name: Run tests shell: bash - env: - LANGGRAPH_FF_SEND_V2: ${{ matrix.ff-send-v2 }} run: | make test_parallel diff --git a/libs/langgraph/langgraph/func/__init__.py b/libs/langgraph/langgraph/func/__init__.py index 62bf138e64..d625be5605 100644 --- a/libs/langgraph/langgraph/func/__init__.py +++ b/libs/langgraph/langgraph/func/__init__.py @@ -19,7 +19,7 @@ from langgraph.channels.ephemeral_value import EphemeralValue from langgraph.channels.last_value import LastValue from langgraph.checkpoint.base import BaseCheckpointSaver -from langgraph.constants import END, START, TAG_HIDDEN +from langgraph.constants import CONF, END, START, TAG_HIDDEN from langgraph.pregel import Pregel from langgraph.pregel.call import get_runnable_for_func from langgraph.pregel.read import PregelNode @@ -39,11 +39,11 @@ def call( **kwargs: Any, ) -> concurrent.futures.Future[T]: from langgraph.constants import CONFIG_KEY_CALL - from langgraph.utils.config import get_configurable + from langgraph.utils.config import get_config - conf = get_configurable() - impl = conf[CONFIG_KEY_CALL] - fut = impl(func, (args, kwargs), retry=retry) + config = get_config() + impl = config[CONF][CONFIG_KEY_CALL] + fut = impl(func, (args, kwargs), retry=retry, callbacks=config["callbacks"]) return fut diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index 130d7e983d..4031516c8d 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -19,6 +19,7 @@ ) from uuid import UUID +from langchain_core.callbacks import Callbacks from langchain_core.callbacks.manager import AsyncParentRunManager, ParentRunManager from langchain_core.runnables.config import RunnableConfig @@ -107,18 +108,25 @@ class PregelTaskWrites(NamedTuple): class Call: - __slots__ = ("func", "input", "retry") + __slots__ = ("func", "input", "retry", "callbacks") func: Callable input: Any retry: Optional[RetryPolicy] + callbacks: Callbacks def __init__( - self, func: Callable, input: Any, *, retry: Optional[RetryPolicy] + self, + func: Callable, + input: Any, + *, + retry: Optional[RetryPolicy], + callbacks: Callbacks, ) -> None: self.func = func self.input = input self.retry = retry + self.callbacks = callbacks def should_interrupt( @@ -465,9 +473,8 @@ def prepare_single_task( patch_config( merge_configs(config, {"metadata": metadata}), run_name=name, - callbacks=( - manager.get_child(f"graph:step:{step}") if manager else None - ), + callbacks=call.callbacks + or (manager.get_child(f"graph:step:{step}") if manager else None), configurable={ CONFIG_KEY_TASK_ID: task_id, # deque.extend is thread-safe diff --git a/libs/langgraph/langgraph/pregel/executor.py b/libs/langgraph/langgraph/pregel/executor.py index 8ba4debdcc..8bd8671be8 100644 --- a/libs/langgraph/langgraph/pregel/executor.py +++ b/libs/langgraph/langgraph/pregel/executor.py @@ -64,10 +64,14 @@ def submit( # type: ignore[valid-type] __next_tick__: bool = False, **kwargs: P.kwargs, ) -> concurrent.futures.Future[T]: + ctx = copy_context() if __next_tick__: - task = self.executor.submit(next_tick, fn, *args, **kwargs) + task = cast( + concurrent.futures.Future[T], + self.executor.submit(next_tick, ctx.run, fn, *args, **kwargs), # type: ignore[arg-type] + ) else: - task = self.executor.submit(fn, *args, **kwargs) + task = self.executor.submit(ctx.run, fn, *args, **kwargs) self.tasks[task] = (__cancel_on_exit__, __reraise_on_exit__) # add a callback to remove the task from the tasks dict when it's done task.add_done_callback(self.done) diff --git a/libs/langgraph/langgraph/pregel/runner.py b/libs/langgraph/langgraph/pregel/runner.py index e680518a5b..4a7114e463 100644 --- a/libs/langgraph/langgraph/pregel/runner.py +++ b/libs/langgraph/langgraph/pregel/runner.py @@ -17,6 +17,8 @@ cast, ) +from langchain_core.callbacks import Callbacks + from langgraph.constants import ( CONF, CONFIG_KEY_CALL, @@ -148,9 +150,12 @@ def call( input: Any, *, retry: Optional[RetryPolicy] = None, + callbacks: Callbacks = None, ) -> concurrent.futures.Future[Any]: (fut,) = writer( - task, [(PUSH, None)], calls=[Call(func, input, retry=retry)] + task, + [(PUSH, None)], + calls=[Call(func, input, retry=retry, callbacks=callbacks)], ) assert fut is not None, "writer did not return a future for call" return fut @@ -337,9 +342,12 @@ def call( input: Any, *, retry: Optional[RetryPolicy] = None, + callbacks: Callbacks = None, ) -> Union[asyncio.Future[Any], concurrent.futures.Future[Any]]: (fut,) = writer( - task, [(PUSH, None)], calls=[Call(func, input, retry=retry)] + task, + [(PUSH, None)], + calls=[Call(func, input, retry=retry, callbacks=callbacks)], ) assert fut is not None, "writer did not return a future for call" if asyncio.iscoroutinefunction(func): diff --git a/libs/langgraph/langgraph/types.py b/libs/langgraph/langgraph/types.py index fa1adafe67..066d7d1060 100644 --- a/libs/langgraph/langgraph/types.py +++ b/libs/langgraph/langgraph/types.py @@ -453,9 +453,9 @@ def node(state: State): RESUME, ) from langgraph.errors import GraphInterrupt - from langgraph.utils.config import get_configurable + from langgraph.utils.config import get_config - conf = get_configurable() + conf = get_config()["configurable"] # track interrupt index scratchpad: PregelScratchpad = conf[CONFIG_KEY_SCRATCHPAD] if "interrupt_counter" not in scratchpad: diff --git a/libs/langgraph/langgraph/utils/config.py b/libs/langgraph/langgraph/utils/config.py index 5bff9e8487..372ea9616c 100644 --- a/libs/langgraph/langgraph/utils/config.py +++ b/libs/langgraph/langgraph/utils/config.py @@ -132,7 +132,7 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: def patch_config( config: Optional[RunnableConfig], *, - callbacks: Optional[Callbacks] = None, + callbacks: Callbacks = None, recursion_limit: Optional[int] = None, max_concurrency: Optional[int] = None, run_name: Optional[str] = None, @@ -297,7 +297,7 @@ def ensure_config(*configs: Optional[RunnableConfig]) -> RunnableConfig: return empty -def get_configurable() -> dict[str, Any]: +def get_config() -> RunnableConfig: if sys.version_info < (3, 11): try: if asyncio.current_task(): @@ -307,6 +307,6 @@ def get_configurable() -> dict[str, Any]: except RuntimeError: pass if var_config := var_child_runnable_config.get(): - return var_config[CONF] + return var_config else: raise RuntimeError("Called get_configurable outside of a runnable context") diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index c656430268..26352e5453 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -2456,6 +2456,7 @@ async def test_imp_task(checkpointer_name: str) -> None: async def mapper(input: int) -> str: nonlocal mapper_calls mapper_calls += 1 + await asyncio.sleep(0.1 * input) return str(input) * 2 @entrypoint(checkpointer=checkpointer) @@ -2465,7 +2466,8 @@ async def graph(input: list[int]) -> list[str]: answer = interrupt("question") return [m + answer for m in mapped] - thread1 = {"configurable": {"thread_id": "1"}} + tracer = FakeTracer() + thread1 = {"configurable": {"thread_id": "1"}, "callbacks": [tracer]} assert [c async for c in graph.astream([0, 1], thread1)] == [ {"mapper": "00"}, {"mapper": "11"}, @@ -2481,6 +2483,9 @@ async def graph(input: list[int]) -> list[str]: }, ] assert mapper_calls == 2 + assert len(tracer.runs) == 1 + assert len(tracer.runs[0].child_runs) == 1 + assert tracer.runs[0].child_runs[0].name == "graph" assert await graph.ainvoke(Command(resume="answer"), thread1) == [ "00answer",