Skip to content

Commit

Permalink
Send collected outputs for generators (#1082)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Oct 10, 2024
1 parent f562471 commit 81dce7a
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 29 deletions.
54 changes: 27 additions & 27 deletions python/langsmith/run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,22 +570,18 @@ async def async_generator_wrapper(
yield item
except BaseException as e:
await asyncio.shield(
aitertools.aio_to_thread(_on_run_end, run_container, error=e)
aitertools.aio_to_thread(
_on_run_end,
run_container,
error=e,
outputs=_get_function_result(results, reduce_fn),
)
)
raise e
if results:
if reduce_fn:
try:
function_result = reduce_fn(results)
except BaseException as e:
LOGGER.error(e)
function_result = results
else:
function_result = results
else:
function_result = None
await aitertools.aio_to_thread(
_on_run_end, run_container, outputs=function_result
_on_run_end,
run_container,
outputs=_get_function_result(results, reduce_fn),
)

@functools.wraps(func)
Expand Down Expand Up @@ -652,21 +648,13 @@ def generator_wrapper(
results.append(function_return)

except BaseException as e:
_on_run_end(run_container, error=e)
_on_run_end(
run_container,
error=e,
outputs=_get_function_result(results, reduce_fn),
)
raise e

if results:
if reduce_fn:
try:
function_result = reduce_fn(results)
except BaseException as e:
LOGGER.error(e)
function_result = results
else:
function_result = results
else:
function_result = None
_on_run_end(run_container, outputs=function_result)
_on_run_end(run_container, outputs=_get_function_result(results, reduce_fn))
return function_return

# "Stream" functions (used in methods like OpenAI/Anthropic's SDKs)
Expand Down Expand Up @@ -1709,3 +1697,15 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
return await self.__ls_stream__.__aexit__(exc_type, exc_val, exc_tb)
finally:
await self._aend_trace()


def _get_function_result(results: list, reduce_fn: Callable) -> Any:
if results:
if reduce_fn is not None:
try:
return reduce_fn(results)
except BaseException as e:
LOGGER.error(e)
return results
else:
return results
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langsmith"
version = "0.1.132"
version = "0.1.133"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
authors = ["LangChain <[email protected]>"]
license = "MIT"
Expand Down
59 changes: 58 additions & 1 deletion python/tests/unit_tests/test_run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time
import uuid
import warnings
from typing import Any, AsyncGenerator, Generator, Optional, Set, cast
from typing import Any, AsyncGenerator, Generator, List, Optional, Set, Tuple, cast
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -50,6 +50,17 @@ def _get_calls(
return calls


def _get_datas(mock_calls: List[Any]) -> List[Tuple[str, dict]]:
datas = []
for call_ in mock_calls:
data = json.loads(call_.kwargs["data"])
for verb in ("post", "patch"):
for payload in data.get(verb) or []:
datas.append((verb, payload))

return datas


def test__get_inputs_with_no_args() -> None:
def foo() -> None:
pass
Expand Down Expand Up @@ -1466,7 +1477,53 @@ async def my_function(a: int) -> AsyncGenerator[int, None]:
mock_calls = _get_calls(
mock_client, verbs={"POST", "PATCH", "GET"}, minimum=num_calls
)

assert len(mock_calls) == num_calls
if auto_batch_tracing:
datas = _get_datas(mock_calls)
outputs = [p["outputs"] for _, p in datas if p.get("outputs")]
assert len(outputs) == 1
assert outputs[0]["output"] == list(range(5))


@pytest.mark.parametrize("auto_batch_tracing", [True, False])
async def test_traceable_gen_exception(auto_batch_tracing: bool):
mock_client = _get_mock_client(
auto_batch_tracing=auto_batch_tracing,
info=ls_schemas.LangSmithInfo(
batch_ingest_config=ls_schemas.BatchIngestConfig(
size_limit_bytes=None, # Note this field is not used here
size_limit=100,
scale_up_nthreads_limit=16,
scale_up_qsize_trigger=1000,
scale_down_nempty_trigger=4,
)
),
)

@traceable
def my_function(a: int) -> Generator[int, None, None]:
for i in range(5):
yield i
raise ValueError("foo")

with tracing_context(enabled=True):
with pytest.raises(ValueError, match="foo"):
for _ in my_function(1, langsmith_extra={"client": mock_client}):
pass

# Get ALL the call args for the mock_client
num_calls = 1 if auto_batch_tracing else 2
mock_calls = _get_calls(
mock_client, verbs={"POST", "PATCH", "GET"}, minimum=num_calls
)

assert len(mock_calls) == num_calls
if auto_batch_tracing:
datas = _get_datas(mock_calls)
outputs = [p["outputs"] for _, p in datas if p.get("outputs")]
assert len(outputs) == 1
assert outputs[0]["output"] == list(range(5))


@pytest.mark.parametrize("env_var", [True, False])
Expand Down

0 comments on commit 81dce7a

Please sign in to comment.