From 515242d0bac763cded351272437ed7db130de719 Mon Sep 17 00:00:00 2001 From: vbarda Date: Tue, 3 Dec 2024 15:40:18 -0500 Subject: [PATCH] langgraph: allow passing kwargs to SDK methods in RemoteGraph's invoke/stream --- libs/langgraph/langgraph/pregel/remote.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/remote.py b/libs/langgraph/langgraph/pregel/remote.py index d45cdb310..45b837dd7 100644 --- a/libs/langgraph/langgraph/pregel/remote.py +++ b/libs/langgraph/langgraph/pregel/remote.py @@ -575,6 +575,7 @@ def stream( interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, subgraphs: bool = False, + **kwargs: Any, ) -> Iterator[Union[dict[str, Any], Any]]: """Create a run and stream the results. @@ -589,6 +590,7 @@ def stream( interrupt_before: Interrupt the graph before these nodes. interrupt_after: Interrupt the graph after these nodes. subgraphs: Stream from subgraphs. + **kwargs: Additional params to pass to client.runs.stream. Yields: The output of the graph. @@ -616,6 +618,7 @@ def stream( interrupt_after=interrupt_after, stream_subgraphs=subgraphs or stream is not None, if_not_exists="create", + **kwargs, ): # split mode and ns if NS_SEP in chunk.event: @@ -664,6 +667,7 @@ async def astream( interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, subgraphs: bool = False, + **kwargs: Any, ) -> AsyncIterator[Union[dict[str, Any], Any]]: """Create a run and stream the results. @@ -678,6 +682,7 @@ async def astream( interrupt_before: Interrupt the graph before these nodes. interrupt_after: Interrupt the graph after these nodes. subgraphs: Stream from subgraphs. + **kwargs: Additional params to pass to client.runs.stream. Yields: The output of the graph. @@ -705,6 +710,7 @@ async def astream( interrupt_after=interrupt_after, stream_subgraphs=subgraphs or stream is not None, if_not_exists="create", + **kwargs, ): # split mode and ns if NS_SEP in chunk.event: @@ -767,18 +773,16 @@ def invoke( *, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, + **kwargs: Any, ) -> Union[dict[str, Any], Any]: """Create a run, wait until it finishes and return the final state. - This method calls `POST /threads/{thread_id}/runs/wait` if a `thread_id` - is speciffed in the `configurable` field of the config or - `POST /runs/wait` otherwise. - Args: input: Input to the graph. config: A `RunnableConfig` for graph invocation. interrupt_before: Interrupt the graph before these nodes. interrupt_after: Interrupt the graph after these nodes. + **kwargs: Additional params to pass to RemoteGraph.stream. Returns: The output of the graph. @@ -789,6 +793,7 @@ def invoke( interrupt_before=interrupt_before, interrupt_after=interrupt_after, stream_mode="values", + **kwargs, ): pass try: @@ -803,18 +808,16 @@ async def ainvoke( *, interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, + **kwargs: Any, ) -> Union[dict[str, Any], Any]: """Create a run, wait until it finishes and return the final state. - This method calls `POST /threads/{thread_id}/runs/wait` if a `thread_id` - is speciffed in the `configurable` field of the config or - `POST /runs/wait` otherwise. - Args: input: Input to the graph. config: A `RunnableConfig` for graph invocation. interrupt_before: Interrupt the graph before these nodes. interrupt_after: Interrupt the graph after these nodes. + **kwargs: Additional params to pass to RemoteGraph.astream. Returns: The output of the graph. @@ -825,6 +828,7 @@ async def ainvoke( interrupt_before=interrupt_before, interrupt_after=interrupt_after, stream_mode="values", + **kwargs, ): pass try: