From 0ec86b51cc8ce7cbb318e2407fc1565a48a6bb0f Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 19 Apr 2024 18:11:25 -0400 Subject: [PATCH 1/4] x --- libs/core/langchain_core/runnables/base.py | 30 ++++++++++--------- .../unit_tests/runnables/test_runnable.py | 5 ++-- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index a29a0677ee990..3b256f08ddae2 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -4000,13 +4000,14 @@ def _transform( ) -> Iterator[Output]: final: Optional[Input] = None for ichunk in input: - if final is None: - final = adapt_first_streaming_chunk(ichunk) # type: ignore - else: - try: - final = final + ichunk # type: ignore[operator] - except TypeError: - final = ichunk + # By definitions, RunnableLambdas consume all input before emitting output. + # If the input is not addable, then we'll assume that we can + # only operate on the last chunk. + # So we'll iterate until we get to the last chunk! + try: + final = final + ichunk # type: ignore[operator] + except TypeError: + final = ichunk if inspect.isgeneratorfunction(self.func): output: Optional[Output] = None @@ -4084,13 +4085,14 @@ async def _atransform( ) -> AsyncIterator[Output]: final: Optional[Input] = None async for ichunk in input: - if final is None: - final = adapt_first_streaming_chunk(ichunk) - else: - try: - final = final + ichunk # type: ignore[operator] - except TypeError: - final = ichunk + # By definitions, RunnableLambdas consume all input before emitting output. + # If the input is not addable, then we'll assume that we can + # only operate on the last chunk. + # So we'll iterate until we get to the last chunk! + try: + final = final + ichunk # type: ignore[operator] + except TypeError: + final = ichunk if hasattr(self, "afunc"): afunc = self.afunc diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index ca6d2a3adafac..c7ff4c272aec6 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -5401,11 +5401,10 @@ def test_transform_of_runnable_lambda_with_dicts() -> None: runnable = RunnableLambda(lambda x: x) chunks = iter( [ - {"foo": "a"}, {"foo": "n"}, ] ) - assert list(runnable.transform(chunks)) == [{"foo": "an"}] + assert list(runnable.transform(chunks)) == [{"foo": "n"}] async def test_atransform_of_runnable_lambda_with_dicts() -> None: @@ -5420,7 +5419,7 @@ async def chunk_iterator() -> AsyncIterator[Dict[str, str]]: yield {"foo": "n"} chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())] - assert chunks == [{"foo": "an"}] + assert chunks == [{"foo": "n"}] def test_default_transform_with_dicts() -> None: From 7d795244d1159d1c37b068f1287b373d19150bad Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 22 Apr 2024 10:56:31 -0400 Subject: [PATCH 2/4] x --- libs/core/langchain_core/runnables/base.py | 69 +++++++++++-------- .../langchain_core/runnables/passthrough.py | 37 +++++++--- libs/core/langchain_core/runnables/utils.py | 8 --- .../unit_tests/runnables/test_runnable.py | 27 +++++++- 4 files changed, 92 insertions(+), 49 deletions(-) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 3b256f08ddae2..ee48234df3a80 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -69,7 +69,6 @@ accepts_config, accepts_context, accepts_run_manager, - adapt_first_streaming_chunk, create_model, gather_with_concurrency, get_function_first_arg_dict_keys, @@ -1280,21 +1279,22 @@ def transform( final: Input got_first_val = False - for chunk in input: + for ichunk in input: + # The default implementation of transform is to buffer input and + # then call stream. + # It'll attempt to gather all input into a single chunk using + # the `+` operator. + # If the input is not addable, then we'll assume that we can + # only operate on the last chunk, + # and we'll iterate until we get to the last chunk. if not got_first_val: - final = adapt_first_streaming_chunk(chunk) # type: ignore + final = ichunk got_first_val = True else: - # Make a best effort to gather, for any type that supports `+` - # This method should throw an error if gathering fails. try: - final = final + chunk # type: ignore[operator] + final = final + ichunk # type: ignore[operator] except TypeError: - raise TypeError( - f"Failed while trying to add together " - f"type {type(final)} and {type(chunk)}." - f"These types should be addable for transform to work." - ) + final = ichunk if got_first_val: yield from self.stream(final, config, **kwargs) @@ -1313,21 +1313,22 @@ async def atransform( final: Input got_first_val = False - async for chunk in input: + async for ichunk in input: + # The default implementation of transform is to buffer input and + # then call stream. + # It'll attempt to gather all input into a single chunk using + # the `+` operator. + # If the input is not addable, then we'll assume that we can + # only operate on the last chunk, + # and we'll iterate until we get to the last chunk. if not got_first_val: - final = adapt_first_streaming_chunk(chunk) # type: ignore + final = ichunk got_first_val = True else: - # Make a best effort to gather, for any type that supports `+` - # This method should throw an error if gathering fails. try: - final = final + chunk # type: ignore[operator] + final = final + ichunk # type: ignore[operator] except TypeError: - raise TypeError( - f"Failed while trying to add together " - f"type {type(final)} and {type(chunk)}." - f"These types should be addable for atransform to work." - ) + final = ichunk if got_first_val: async for output in self.astream(final, config, **kwargs): @@ -3998,16 +3999,21 @@ def _transform( config: RunnableConfig, **kwargs: Any, ) -> Iterator[Output]: - final: Optional[Input] = None + final: Input + got_first_val = False for ichunk in input: # By definitions, RunnableLambdas consume all input before emitting output. # If the input is not addable, then we'll assume that we can # only operate on the last chunk. # So we'll iterate until we get to the last chunk! - try: - final = final + ichunk # type: ignore[operator] - except TypeError: + if not got_first_val: final = ichunk + got_first_val = True + else: + try: + final = final + ichunk # type: ignore[operator] + except TypeError: + final = ichunk if inspect.isgeneratorfunction(self.func): output: Optional[Output] = None @@ -4083,16 +4089,21 @@ async def _atransform( config: RunnableConfig, **kwargs: Any, ) -> AsyncIterator[Output]: - final: Optional[Input] = None + final: Input + got_first_val = False async for ichunk in input: # By definitions, RunnableLambdas consume all input before emitting output. # If the input is not addable, then we'll assume that we can # only operate on the last chunk. # So we'll iterate until we get to the last chunk! - try: - final = final + ichunk # type: ignore[operator] - except TypeError: + if not got_first_val: final = ichunk + got_first_val = True + else: + try: + final = final + ichunk # type: ignore[operator] + except TypeError: + final = ichunk if hasattr(self, "afunc"): afunc = self.afunc diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index d2fbf30e4b9c8..fdb31c2f6349b 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -40,7 +40,6 @@ from langchain_core.runnables.utils import ( AddableDict, ConfigurableFieldSpec, - adapt_first_streaming_chunk, create_model, ) from langchain_core.utils.aiter import atee, py_anext @@ -243,16 +242,22 @@ def transform( for chunk in self._transform_stream_with_config(input, identity, config): yield chunk else: - final = None + final: Other + got_first_chunk = False for chunk in self._transform_stream_with_config(input, identity, config): yield chunk - if final is None: - final = adapt_first_streaming_chunk(chunk) + + if not got_first_chunk: + final = chunk + got_first_chunk = True else: - final = final + chunk + try: + final = final + chunk + except TypeError: + final = chunk - if final is not None: + if got_first_chunk: call_func_with_variable_args( self.func, final, ensure_config(config), **kwargs ) @@ -269,18 +274,28 @@ async def atransform( ): yield chunk else: - final = None + got_first_chunk = False async for chunk in self._atransform_stream_with_config( input, identity, config ): yield chunk - if final is None: - final = adapt_first_streaming_chunk(chunk) + + # By definitions, a function will operate on the aggregated + # input. So we'll aggregate the input until we get to the last + # chunk. + # If the input is not addable, then we'll assume that we can + # only operate on the last chunk. + if not got_first_chunk: + final = chunk + got_first_chunk = True else: - final = final + chunk + try: + final = final + chunk + except TypeError: + final = chunk - if final is not None: + if got_first_chunk: config = ensure_config(config) if self.afunc is not None: await acall_func_with_variable_args( diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index dff10ad04957a..d5553e786f519 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -524,11 +524,3 @@ def _create_model_cached( return _create_model_base( __model_name, __config__=_SchemaConfig, **field_definitions ) - - -def adapt_first_streaming_chunk(chunk: Any) -> Any: - """This might transform the first chunk of a stream into an AddableDict.""" - if isinstance(chunk, dict) and not isinstance(chunk, AddableDict): - return AddableDict(chunk) - else: - return chunk diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index c7ff4c272aec6..aea306e83ab9a 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -5406,6 +5406,17 @@ def test_transform_of_runnable_lambda_with_dicts() -> None: ) assert list(runnable.transform(chunks)) == [{"foo": "n"}] + # Test as part of a sequence + seq = runnable | runnable + chunks = iter( + [ + {"foo": "n"}, + ] + ) + assert list(seq.transform(chunks)) == [{"foo": "n"}] + # Test some other edge cases + assert list(seq.stream({"foo": "n"})) == [{"foo": "n"}] + async def test_atransform_of_runnable_lambda_with_dicts() -> None: async def identity(x: Dict[str, str]) -> Dict[str, str]: @@ -5421,6 +5432,10 @@ async def chunk_iterator() -> AsyncIterator[Dict[str, str]]: chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())] assert chunks == [{"foo": "n"}] + seq = runnable | runnable + chunks = [chunk async for chunk in seq.atransform(chunk_iterator())] + assert chunks == [{"foo": "n"}] + def test_default_transform_with_dicts() -> None: """Test that default transform works with dicts.""" @@ -5439,7 +5454,8 @@ def invoke( ] ) - assert list(runnable.transform(chunks)) == [{"foo": "an"}] + assert list(runnable.transform(chunks)) == [{"foo": "n"}] + assert list(runnable.stream({"foo": "n"})) == [{"foo": "n"}] async def test_default_atransform_with_dicts() -> None: @@ -5459,6 +5475,15 @@ async def chunk_iterator() -> AsyncIterator[Dict[str, str]]: chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())] + assert chunks == [{"foo": "n"}] + + # Test with addable dict + async def chunk_iterator() -> AsyncIterator[Dict[str, str]]: + yield AddableDict({"foo": "a"}) + yield AddableDict({"foo": "n"}) + + chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())] + assert chunks == [{"foo": "an"}] From 292e066b51fd31198cac80b6e5b64d1e32afa93e Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 22 Apr 2024 13:00:40 -0400 Subject: [PATCH 3/4] x --- libs/core/langchain_core/runnables/passthrough.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index fdb31c2f6349b..ec081aea97f5e 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -253,7 +253,7 @@ def transform( got_first_chunk = True else: try: - final = final + chunk + final = final + chunk # type: ignore[operator] except TypeError: final = chunk @@ -291,7 +291,7 @@ async def atransform( got_first_chunk = True else: try: - final = final + chunk + final = final + chunk # type: ignore[operator] except TypeError: final = chunk From f4602d60895c79c86dfe96fac522f5d57fd4b991 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 23 Apr 2024 10:21:04 -0400 Subject: [PATCH 4/4] x --- libs/core/tests/unit_tests/runnables/test_runnable.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index aea306e83ab9a..07ebd37abf2c1 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -5478,11 +5478,13 @@ async def chunk_iterator() -> AsyncIterator[Dict[str, str]]: assert chunks == [{"foo": "n"}] # Test with addable dict - async def chunk_iterator() -> AsyncIterator[Dict[str, str]]: + async def chunk_iterator_with_addable() -> AsyncIterator[Dict[str, str]]: yield AddableDict({"foo": "a"}) yield AddableDict({"foo": "n"}) - chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())] + chunks = [ + chunk async for chunk in runnable.atransform(chunk_iterator_with_addable()) + ] assert chunks == [{"foo": "an"}]