Add support for handling pytest.skip in @unit (#623)
When a test decorated with `@unit` calls `pytest.skip`:

1. The result should be recorded as `pass: None` so it does not factor into the aggregate score.
2. The run shouldn't be marked as an error.

The changes in this PR reflect these desiderata.
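For illustration, a minimal sketch of the intended usage (the test name and skip reason are hypothetical; the `@unit` decorator and `pytest.skip` call mirror the tests added in `python/tests/evaluation/test_evaluation.py` below):

```python
import pytest

from langsmith import unit


@unit
def test_optional_feature():
    # After this change, skipping inside a @unit test is recorded in LangSmith
    # with pass=None (excluded from the aggregate score) rather than as an error.
    pytest.skip("Feature not available in this environment")
```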
hinthornw authored Apr 23, 2024
1 parent aba2b9e commit bfe8643
Showing 7 changed files with 205 additions and 50 deletions.
138 changes: 96 additions & 42 deletions python/langsmith/_testing.py
@@ -14,6 +14,7 @@
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, overload

import orjson
import pytest
from typing_extensions import TypedDict

from langsmith import client as ls_client
@@ -22,6 +23,16 @@
from langsmith import schemas as ls_schemas
from langsmith import utils as ls_utils

try:
import pytest

SkipException = pytest.skip.Exception
except ImportError:

class SkipException(Exception): # type: ignore[no-redef]
pass


logger = logging.getLogger(__name__)


@@ -421,14 +432,27 @@ def get_version(self) -> Optional[datetime.datetime]:
with self._lock:
return self._version

def submit_result(self, run_id: uuid.UUID, error: Optional[str] = None) -> None:
self._executor.submit(self._submit_result, run_id, error)
def submit_result(
self, run_id: uuid.UUID, error: Optional[str] = None, skipped: bool = False
) -> None:
self._executor.submit(self._submit_result, run_id, error, skipped=skipped)

def _submit_result(self, run_id: uuid.UUID, error: Optional[str] = None) -> None:
def _submit_result(
self, run_id: uuid.UUID, error: Optional[str] = None, skipped: bool = False
) -> None:
if error:
self.client.create_feedback(
run_id, key="pass", score=0, comment=f"Error: {repr(error)}"
)
if skipped:
self.client.create_feedback(
run_id,
key="pass",
# Don't factor into aggregate score
score=None,
comment=f"Skipped: {repr(error)}",
)
else:
self.client.create_feedback(
run_id, key="pass", score=0, comment=f"Error: {repr(error)}"
)
else:
self.client.create_feedback(
run_id,
@@ -525,24 +549,39 @@ def _run_test(
run_id = uuid.uuid4()

def _test():
try:
func_ = func if rh.is_traceable_function(func) else rh.traceable(func)
func_(
*test_args,
**test_kwargs,
langsmith_extra={
"run_id": run_id,
"reference_example_id": example_id,
"project_name": test_suite.name,
},
)
except BaseException as e:
test_suite.submit_result(run_id, error=repr(e))
raise e
try:
test_suite.submit_result(run_id, error=None)
except BaseException as e:
logger.warning(f"Failed to create feedback for run_id {run_id}: {e}")
func_inputs = rh._get_inputs_safe(
inspect.signature(func), *test_args, **test_kwargs
)
with rh.trace(
name=getattr(func, "__name__", "Test"),
run_id=run_id,
reference_example_id=example_id,
inputs=func_inputs,
project_name=test_suite.name,
exceptions_to_handle=(SkipException,),
) as run_tree:
try:
result = func(*test_args, **test_kwargs)
run_tree.end(
outputs=(
result
if result is None or isinstance(result, dict)
else {"output": result}
)
)
except SkipException as e:
test_suite.submit_result(run_id, error=repr(e), skipped=True)
run_tree.end(
outputs={"skipped_reason": repr(e)},
)
raise e
except BaseException as e:
test_suite.submit_result(run_id, error=repr(e))
raise e
try:
test_suite.submit_result(run_id, error=None)
except BaseException as e:
logger.warning(f"Failed to create feedback for run_id {run_id}: {e}")

cache_path = (
Path(langtest_extra["cache"]) / f"{test_suite.id}.yaml"
@@ -574,24 +613,39 @@ async def _arun_test(
run_id = uuid.uuid4()

async def _test():
try:
func_ = func if rh.is_traceable_function(func) else rh.traceable(func)
await func_(
*test_args,
**test_kwargs,
langsmith_extra={
"run_id": run_id,
"reference_example_id": example_id,
"project_name": test_suite.name,
},
)
except BaseException as e:
test_suite.submit_result(run_id, error=repr(e))
raise e
try:
test_suite.submit_result(run_id, error=None)
except BaseException as e:
logger.warning(f"Failed to create feedback for run_id {run_id}: {e}")
func_inputs = rh._get_inputs_safe(
inspect.signature(func), *test_args, **test_kwargs
)
with rh.trace(
name=getattr(func, "__name__", "Test"),
run_id=run_id,
reference_example_id=example_id,
inputs=func_inputs,
project_name=test_suite.name,
exceptions_to_handle=(SkipException,),
) as run_tree:
try:
result = await func(*test_args, **test_kwargs)
run_tree.end(
outputs=(
result
if result is None or isinstance(result, dict)
else {"output": result}
)
)
except SkipException as e:
test_suite.submit_result(run_id, error=repr(e), skipped=True)
run_tree.end(
outputs={"skipped_reason": repr(e)},
)
raise e
except BaseException as e:
test_suite.submit_result(run_id, error=repr(e))
raise e
try:
test_suite.submit_result(run_id, error=None)
except BaseException as e:
logger.warning(f"Failed to create feedback for run_id {run_id}: {e}")

cache_path = (
Path(langtest_extra["cache"]) / f"{test_suite.id}.yaml"
2 changes: 2 additions & 0 deletions python/langsmith/client.py
@@ -2,6 +2,7 @@

from __future__ import annotations

import atexit
import collections
import datetime
import functools
@@ -548,6 +549,7 @@ def __init__(
else ls_schemas.LangSmithInfo(**info)
)
weakref.finalize(self, close_session, self.session)
atexit.register(close_session, self.session)
# Initialize auto batching
if auto_batch_tracing:
self.tracing_queue: Optional[PriorityQueue] = PriorityQueue()
25 changes: 18 additions & 7 deletions python/langsmith/run_helpers.py
@@ -26,6 +26,8 @@
Mapping,
Optional,
Protocol,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
@@ -425,7 +427,7 @@ async def async_wrapper(
**get_tracing_context(run_container["context"])
):
function_result = await fr_coro
except Exception as e:
except BaseException as e:
stacktrace = traceback.format_exc()
_container_end(run_container, error=stacktrace)
raise e
@@ -506,7 +508,7 @@ async def async_generator_wrapper(
if reduce_fn:
try:
function_result = reduce_fn(results)
except Exception as e:
except BaseException as e:
LOGGER.error(e)
function_result = results
else:
@@ -600,7 +602,7 @@ def generator_wrapper(
if reduce_fn:
try:
function_result = reduce_fn(results)
except Exception as e:
except BaseException as e:
LOGGER.error(e)
function_result = results
else:
@@ -643,6 +645,9 @@ def trace(
tags: Optional[List[str]] = None,
metadata: Optional[Mapping[str, Any]] = None,
client: Optional[ls_client.Client] = None,
run_id: Optional[ls_client.ID_TYPE] = None,
reference_example_id: Optional[ls_client.ID_TYPE] = None,
exceptions_to_handle: Optional[Tuple[Type[BaseException], ...]] = None,
**kwargs: Any,
) -> Generator[run_trees.RunTree, None, None]:
"""Context manager for creating a run tree."""
@@ -673,6 +678,7 @@
if parent_run_ is not None:
new_run = parent_run_.create_child(
name=name,
run_id=run_id,
run_type=run_type,
extra=extra_outer,
inputs=inputs,
@@ -681,6 +687,8 @@
else:
new_run = run_trees.RunTree(
name=name,
run_id=run_id,
reference_example_id=reference_example_id,
run_type=run_type,
extra=extra_outer,
project_name=project_name_,
@@ -694,7 +702,10 @@
try:
yield new_run
except (Exception, KeyboardInterrupt, BaseException) as e:
tb = traceback.format_exc()
if exceptions_to_handle and isinstance(e, exceptions_to_handle):
tb = None
else:
tb = traceback.format_exc()
new_run.end(error=tb)
new_run.patch()
raise e
@@ -913,7 +924,7 @@ def _container_end(
if on_end is not None and callable(on_end):
try:
on_end(run_tree)
except Exception as e:
except BaseException as e:
LOGGER.warning(f"Failed to run on_end function: {e}")


@@ -1005,7 +1016,7 @@ def _setup_run(
if process_inputs:
try:
inputs = process_inputs(inputs)
except Exception as e:
except BaseException as e:
LOGGER.error(f"Failed to filter inputs for {name_}: {e}")
tags_ = (langsmith_extra.get("tags") or []) + (outer_tags or [])
context.run(_TAGS.set, tags_)
@@ -1044,7 +1055,7 @@ def _setup_run(
)
try:
new_run.post()
except Exception as e:
except BaseException as e:
LOGGER.error(f"Failed to post run {new_run.id}: {e}")
response_container = _TraceableContainer(
new_run=new_run,
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langsmith"
version = "0.1.49"
version = "0.1.50"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
authors = ["LangChain <[email protected]>"]
license = "MIT"
10 changes: 10 additions & 0 deletions python/tests/evaluation/test_evaluation.py
@@ -132,3 +132,13 @@ async def test_bar_async_parametrized(x, y, z):
await asyncio.sleep(0.1)
expect(x + y).to_equal(z)
return {"z": x + y}


@unit
def test_pytest_skip():
pytest.skip("Skip this test")


@unit
async def test_async_pytest_skip():
pytest.skip("Skip this test")
Empty file.
78 changes: 78 additions & 0 deletions python/tests/external/test_instructor_evals.py
@@ -0,0 +1,78 @@
from enum import Enum
from itertools import product
from typing import Literal

import instructor # type: ignore
import pytest
from anthropic import AsyncAnthropic # type: ignore
from openai import AsyncOpenAI
from pydantic import BaseModel

from langsmith import unit


class Models(str, Enum):
GPT35TURBO = "gpt-3.5-turbo"
GPT4TURBO = "gpt-4-turbo"
CLAUDE3_SONNET = "claude-3-sonnet-20240229"
CLAUDE3_OPUS = "claude-3-opus-20240229"
CLAUDE3_HAIKU = "claude-3-haiku-20240307"


clients = (
instructor.from_openai(
AsyncOpenAI(),
model=Models.GPT35TURBO,
),
instructor.from_openai(
AsyncOpenAI(),
model=Models.GPT4TURBO,
),
instructor.from_anthropic(
AsyncAnthropic(),
model=Models.CLAUDE3_OPUS,
max_tokens=4000,
),
instructor.from_anthropic(
AsyncAnthropic(),
model=Models.CLAUDE3_SONNET,
max_tokens=4000,
),
instructor.from_anthropic(
AsyncAnthropic(),
model=Models.CLAUDE3_HAIKU,
max_tokens=4000,
),
)


class ClassifySpam(BaseModel):
label: Literal["spam", "not_spam"]


data = [
("I am a spammer who sends many emails every day", "spam"),
("I am a responsible person who does not spam", "not_spam"),
]
d = list(product(clients, data))


@pytest.mark.asyncio_cooperative
@unit()
@pytest.mark.parametrize("client, data", d[:3])
async def test_classification(client, data):
input, expected = data
prediction = await client.create(
response_model=ClassifySpam,
messages=[
{
"role": "system",
"content": "Classify this text as 'spam' or 'not_spam'.",
},
{
"role": "user",
"content": input,
},
],
)
assert prediction.label == expected
