diff --git a/python/langsmith/run_helpers.py b/python/langsmith/run_helpers.py
index 6da368a06..391a256a0 100644
--- a/python/langsmith/run_helpers.py
+++ b/python/langsmith/run_helpers.py
@@ -570,22 +570,18 @@ async def async_generator_wrapper(
                     yield item
             except BaseException as e:
                 await asyncio.shield(
-                    aitertools.aio_to_thread(_on_run_end, run_container, error=e)
+                    aitertools.aio_to_thread(
+                        _on_run_end,
+                        run_container,
+                        error=e,
+                        outputs=_get_function_result(results, reduce_fn),
+                    )
                 )
                 raise e
-            if results:
-                if reduce_fn:
-                    try:
-                        function_result = reduce_fn(results)
-                    except BaseException as e:
-                        LOGGER.error(e)
-                        function_result = results
-                else:
-                    function_result = results
-            else:
-                function_result = None
             await aitertools.aio_to_thread(
-                _on_run_end, run_container, outputs=function_result
+                _on_run_end,
+                run_container,
+                outputs=_get_function_result(results, reduce_fn),
             )
 
         @functools.wraps(func)
@@ -652,21 +648,13 @@ def generator_wrapper(
                     results.append(function_return)
             except BaseException as e:
-                _on_run_end(run_container, error=e)
+                _on_run_end(
+                    run_container,
+                    error=e,
+                    outputs=_get_function_result(results, reduce_fn),
+                )
                 raise e
-
-            if results:
-                if reduce_fn:
-                    try:
-                        function_result = reduce_fn(results)
-                    except BaseException as e:
-                        LOGGER.error(e)
-                        function_result = results
-                else:
-                    function_result = results
-            else:
-                function_result = None
-            _on_run_end(run_container, outputs=function_result)
+            _on_run_end(run_container, outputs=_get_function_result(results, reduce_fn))
             return function_return
 
 
         # "Stream" functions (used in methods like OpenAI/Anthropic's SDKs)
@@ -1709,3 +1697,15 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
             return await self.__ls_stream__.__aexit__(exc_type, exc_val, exc_tb)
         finally:
             await self._aend_trace()
+
+
+def _get_function_result(results: list, reduce_fn: Callable) -> Any:
+    if results:
+        if reduce_fn is not None:
+            try:
+                return reduce_fn(results)
+            except BaseException as e:
+                LOGGER.error(e)
+                return results
+        else:
+            return results
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 62df69cef..2e8cd7499 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langsmith"
-version = "0.1.132"
+version = "0.1.133"
 description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 authors = ["LangChain <support@langchain.dev>"]
 license = "MIT"
diff --git a/python/tests/unit_tests/test_run_helpers.py b/python/tests/unit_tests/test_run_helpers.py
index 2f48dbff7..dc5291d5a 100644
--- a/python/tests/unit_tests/test_run_helpers.py
+++ b/python/tests/unit_tests/test_run_helpers.py
@@ -7,7 +7,7 @@
 import time
 import uuid
 import warnings
-from typing import Any, AsyncGenerator, Generator, Optional, Set, cast
+from typing import Any, AsyncGenerator, Generator, List, Optional, Set, Tuple, cast
 from unittest.mock import MagicMock, patch
 
 import pytest
@@ -50,6 +50,17 @@ def _get_calls(
     return calls
 
 
+def _get_datas(mock_calls: List[Any]) -> List[Tuple[str, dict]]:
+    datas = []
+    for call_ in mock_calls:
+        data = json.loads(call_.kwargs["data"])
+        for verb in ("post", "patch"):
+            for payload in data.get(verb) or []:
+                datas.append((verb, payload))
+
+    return datas
+
+
 def test__get_inputs_with_no_args() -> None:
     def foo() -> None:
         pass
@@ -1466,7 +1477,53 @@ async def my_function(a: int) -> AsyncGenerator[int, None]:
     mock_calls = _get_calls(
         mock_client, verbs={"POST", "PATCH", "GET"}, minimum=num_calls
     )
+
     assert len(mock_calls) == num_calls
+    if auto_batch_tracing:
+        datas = _get_datas(mock_calls)
+        outputs = [p["outputs"] for _, p in datas if p.get("outputs")]
+        assert len(outputs) == 1
+        assert outputs[0]["output"] == list(range(5))
+
+
+@pytest.mark.parametrize("auto_batch_tracing", [True, False])
+async def test_traceable_gen_exception(auto_batch_tracing: bool):
+    mock_client = _get_mock_client(
+        auto_batch_tracing=auto_batch_tracing,
+        info=ls_schemas.LangSmithInfo(
+            batch_ingest_config=ls_schemas.BatchIngestConfig(
+                size_limit_bytes=None,  # Note this field is not used here
+                size_limit=100,
+                scale_up_nthreads_limit=16,
+                scale_up_qsize_trigger=1000,
+                scale_down_nempty_trigger=4,
+            )
+        ),
+    )
+
+    @traceable
+    def my_function(a: int) -> Generator[int, None, None]:
+        for i in range(5):
+            yield i
+        raise ValueError("foo")
+
+    with tracing_context(enabled=True):
+        with pytest.raises(ValueError, match="foo"):
+            for _ in my_function(1, langsmith_extra={"client": mock_client}):
+                pass
+
+    # Get ALL the call args for the mock_client
+    num_calls = 1 if auto_batch_tracing else 2
+    mock_calls = _get_calls(
+        mock_client, verbs={"POST", "PATCH", "GET"}, minimum=num_calls
+    )
+    assert len(mock_calls) == num_calls
+    if auto_batch_tracing:
+        datas = _get_datas(mock_calls)
+        outputs = [p["outputs"] for _, p in datas if p.get("outputs")]
+        assert len(outputs) == 1
+        assert outputs[0]["output"] == list(range(5))
 
 
 @pytest.mark.parametrize("env_var", [True, False])