Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for handling pytest.skip in @unit #623

Merged
merged 5 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 96 additions & 42 deletions python/langsmith/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, overload

import orjson
import pytest
from typing_extensions import TypedDict

from langsmith import client as ls_client
Expand All @@ -22,6 +23,16 @@
from langsmith import schemas as ls_schemas
from langsmith import utils as ls_utils

# pytest is treated as optional here: when it can be imported, reuse its own
# skip exception type; when it cannot, define a local ``Exception`` subclass
# under the same name so later ``except SkipException`` clauses still resolve.
try:
    import pytest

    SkipException = pytest.skip.Exception
except ImportError:

    class SkipException(Exception):  # type: ignore[no-redef]
        """Stand-in for ``pytest.skip.Exception`` when pytest is unavailable."""


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -421,14 +432,27 @@ def get_version(self) -> Optional[datetime.datetime]:
with self._lock:
return self._version

def submit_result(self, run_id: uuid.UUID, error: Optional[str] = None) -> None:
self._executor.submit(self._submit_result, run_id, error)
def submit_result(
self, run_id: uuid.UUID, error: Optional[str] = None, skipped: bool = False
) -> None:
self._executor.submit(self._submit_result, run_id, error, skipped=skipped)

def _submit_result(self, run_id: uuid.UUID, error: Optional[str] = None) -> None:
def _submit_result(
self, run_id: uuid.UUID, error: Optional[str] = None, skipped: bool = False
) -> None:
if error:
self.client.create_feedback(
run_id, key="pass", score=0, comment=f"Error: {repr(error)}"
)
if skipped:
self.client.create_feedback(
run_id,
key="pass",
# Don't factor into aggregate score
score=None,
comment=f"Skipped: {repr(error)}",
)
else:
self.client.create_feedback(
run_id, key="pass", score=0, comment=f"Error: {repr(error)}"
)
else:
self.client.create_feedback(
run_id,
Expand Down Expand Up @@ -525,24 +549,39 @@ def _run_test(
run_id = uuid.uuid4()

def _test():
try:
func_ = func if rh.is_traceable_function(func) else rh.traceable(func)
func_(
*test_args,
**test_kwargs,
langsmith_extra={
"run_id": run_id,
"reference_example_id": example_id,
"project_name": test_suite.name,
},
)
except BaseException as e:
test_suite.submit_result(run_id, error=repr(e))
raise e
try:
test_suite.submit_result(run_id, error=None)
except BaseException as e:
logger.warning(f"Failed to create feedback for run_id {run_id}: {e}")
func_inputs = rh._get_inputs_safe(
inspect.signature(func), *test_args, **test_kwargs
)
with rh.trace(
name=getattr(func, "__name__", "Test"),
run_id=run_id,
reference_example_id=example_id,
inputs=func_inputs,
project_name=test_suite.name,
exceptions_to_handle=(SkipException,),
) as run_tree:
try:
result = func(*test_args, **test_kwargs)
run_tree.end(
outputs=(
result
if result is None or isinstance(result, dict)
else {"output": result}
)
)
except SkipException as e:
test_suite.submit_result(run_id, error=repr(e), skipped=True)
run_tree.end(
outputs={"skipped_reason": repr(e)},
)
raise e
except BaseException as e:
test_suite.submit_result(run_id, error=repr(e))
raise e
try:
test_suite.submit_result(run_id, error=None)
except BaseException as e:
logger.warning(f"Failed to create feedback for run_id {run_id}: {e}")

cache_path = (
Path(langtest_extra["cache"]) / f"{test_suite.id}.yaml"
Expand Down Expand Up @@ -574,24 +613,39 @@ async def _arun_test(
run_id = uuid.uuid4()

async def _test():
try:
func_ = func if rh.is_traceable_function(func) else rh.traceable(func)
await func_(
*test_args,
**test_kwargs,
langsmith_extra={
"run_id": run_id,
"reference_example_id": example_id,
"project_name": test_suite.name,
},
)
except BaseException as e:
test_suite.submit_result(run_id, error=repr(e))
raise e
try:
test_suite.submit_result(run_id, error=None)
except BaseException as e:
logger.warning(f"Failed to create feedback for run_id {run_id}: {e}")
func_inputs = rh._get_inputs_safe(
inspect.signature(func), *test_args, **test_kwargs
)
with rh.trace(
name=getattr(func, "__name__", "Test"),
run_id=run_id,
reference_example_id=example_id,
inputs=func_inputs,
project_name=test_suite.name,
exceptions_to_handle=(SkipException,),
) as run_tree:
try:
result = await func(*test_args, **test_kwargs)
run_tree.end(
outputs=(
result
if result is None or isinstance(result, dict)
else {"output": result}
)
)
except SkipException as e:
test_suite.submit_result(run_id, error=repr(e), skipped=True)
run_tree.end(
outputs={"skipped_reason": repr(e)},
)
raise e
except BaseException as e:
test_suite.submit_result(run_id, error=repr(e))
raise e
try:
test_suite.submit_result(run_id, error=None)
except BaseException as e:
logger.warning(f"Failed to create feedback for run_id {run_id}: {e}")

cache_path = (
Path(langtest_extra["cache"]) / f"{test_suite.id}.yaml"
Expand Down
2 changes: 2 additions & 0 deletions python/langsmith/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import atexit
import collections
import datetime
import functools
Expand Down Expand Up @@ -548,6 +549,7 @@ def __init__(
else ls_schemas.LangSmithInfo(**info)
)
weakref.finalize(self, close_session, self.session)
atexit.register(close_session, self.session)
# Initialize auto batching
if auto_batch_tracing:
self.tracing_queue: Optional[PriorityQueue] = PriorityQueue()
Expand Down
25 changes: 18 additions & 7 deletions python/langsmith/run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
Mapping,
Optional,
Protocol,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
Expand Down Expand Up @@ -425,7 +427,7 @@ async def async_wrapper(
**get_tracing_context(run_container["context"])
):
function_result = await fr_coro
except Exception as e:
except BaseException as e:
stacktrace = traceback.format_exc()
_container_end(run_container, error=stacktrace)
raise e
Expand Down Expand Up @@ -506,7 +508,7 @@ async def async_generator_wrapper(
if reduce_fn:
try:
function_result = reduce_fn(results)
except Exception as e:
except BaseException as e:
LOGGER.error(e)
function_result = results
else:
Expand Down Expand Up @@ -600,7 +602,7 @@ def generator_wrapper(
if reduce_fn:
try:
function_result = reduce_fn(results)
except Exception as e:
except BaseException as e:
LOGGER.error(e)
function_result = results
else:
Expand Down Expand Up @@ -643,6 +645,9 @@ def trace(
tags: Optional[List[str]] = None,
metadata: Optional[Mapping[str, Any]] = None,
client: Optional[ls_client.Client] = None,
run_id: Optional[ls_client.ID_TYPE] = None,
reference_example_id: Optional[ls_client.ID_TYPE] = None,
exceptions_to_handle: Optional[Tuple[Type[BaseException], ...]] = None,
**kwargs: Any,
) -> Generator[run_trees.RunTree, None, None]:
"""Context manager for creating a run tree."""
Expand Down Expand Up @@ -673,6 +678,7 @@ def trace(
if parent_run_ is not None:
new_run = parent_run_.create_child(
name=name,
run_id=run_id,
run_type=run_type,
extra=extra_outer,
inputs=inputs,
Expand All @@ -681,6 +687,8 @@ def trace(
else:
new_run = run_trees.RunTree(
name=name,
run_id=run_id,
reference_example_id=reference_example_id,
run_type=run_type,
extra=extra_outer,
project_name=project_name_,
Expand All @@ -694,7 +702,10 @@ def trace(
try:
yield new_run
except (Exception, KeyboardInterrupt, BaseException) as e:
tb = traceback.format_exc()
if exceptions_to_handle and isinstance(e, exceptions_to_handle):
tb = None
else:
tb = traceback.format_exc()
new_run.end(error=tb)
new_run.patch()
raise e
Expand Down Expand Up @@ -913,7 +924,7 @@ def _container_end(
if on_end is not None and callable(on_end):
try:
on_end(run_tree)
except Exception as e:
except BaseException as e:
LOGGER.warning(f"Failed to run on_end function: {e}")


Expand Down Expand Up @@ -1005,7 +1016,7 @@ def _setup_run(
if process_inputs:
try:
inputs = process_inputs(inputs)
except Exception as e:
except BaseException as e:
LOGGER.error(f"Failed to filter inputs for {name_}: {e}")
tags_ = (langsmith_extra.get("tags") or []) + (outer_tags or [])
context.run(_TAGS.set, tags_)
Expand Down Expand Up @@ -1044,7 +1055,7 @@ def _setup_run(
)
try:
new_run.post()
except Exception as e:
except BaseException as e:
LOGGER.error(f"Failed to post run {new_run.id}: {e}")
response_container = _TraceableContainer(
new_run=new_run,
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langsmith"
version = "0.1.49"
version = "0.1.50"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
authors = ["LangChain <support@langchain.dev>"]
license = "MIT"
Expand Down
10 changes: 10 additions & 0 deletions python/tests/evaluation/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,13 @@ async def test_bar_async_parametrized(x, y, z):
await asyncio.sleep(0.1)
expect(x + y).to_equal(z)
return {"z": x + y}


@unit
def test_pytest_skip():
    """Call ``pytest.skip`` from inside a synchronous ``@unit`` test."""
    pytest.skip("Skip this test")


@unit
async def test_async_pytest_skip():
    """Call ``pytest.skip`` from inside an async ``@unit`` test."""
    pytest.skip("Skip this test")
Empty file.
78 changes: 78 additions & 0 deletions python/tests/external/test_instructor_evals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from enum import Enum
from itertools import product
from typing import Literal

import instructor # type: ignore
import pytest
from anthropic import AsyncAnthropic # type: ignore
from openai import AsyncOpenAI
from pydantic import BaseModel

from langsmith import unit


class Models(str, Enum):
    """Model identifier strings; each member is also a plain ``str``."""

    GPT35TURBO = "gpt-3.5-turbo"
    GPT4TURBO = "gpt-4-turbo"
    CLAUDE3_SONNET = "claude-3-sonnet-20240229"
    CLAUDE3_OPUS = "claude-3-opus-20240229"
    CLAUDE3_HAIKU = "claude-3-haiku-20240307"


# One instructor-wrapped async client per model, built at import time.
# Order is preserved: the two OpenAI models first, then the three Anthropic
# models; every entry gets its own freshly constructed underlying SDK client.
clients = (
    *(
        instructor.from_openai(AsyncOpenAI(), model=model_name)
        for model_name in (Models.GPT35TURBO, Models.GPT4TURBO)
    ),
    *(
        instructor.from_anthropic(
            AsyncAnthropic(),
            model=model_name,
            max_tokens=4000,
        )
        for model_name in (
            Models.CLAUDE3_OPUS,
            Models.CLAUDE3_SONNET,
            Models.CLAUDE3_HAIKU,
        )
    ),
)


class ClassifySpam(BaseModel):
    """Response schema carrying a single spam / not-spam label."""

    label: Literal["spam", "not_spam"]


# Labelled samples as (message text, expected label) pairs.
data = [
    ("I am a spammer who sends many emails every day", "spam"),
    ("I am a responsible person who does not spam", "not_spam"),
]
# Every (client, sample) combination; samples vary fastest within a client.
d = [*product(clients, data)]


@pytest.mark.asyncio_cooperative
@unit()
@pytest.mark.parametrize("client, data", d[:3])
async def test_classification(client, data):
    """Assert that ``client`` assigns the expected label to one sample.

    Args:
        client: Object whose awaitable ``create`` method is called with
            ``response_model=ClassifySpam`` and the chat messages below.
        data: ``(text, expected_label)`` pair for this parametrized case.
    """
    # Named ``text`` rather than ``input`` so the builtin ``input()`` is not
    # shadowed inside the test body.
    text, expected = data
    prediction = await client.create(
        response_model=ClassifySpam,
        messages=[
            {
                "role": "system",
                "content": "Classify this text as 'spam' or 'not_spam'.",
            },
            {
                "role": "user",
                "content": text,
            },
        ],
    )
    assert prediction.label == expected
Loading