diff --git a/python/docs/create_api_rst.py b/python/docs/create_api_rst.py
index 253352767..7a2f75260 100644
--- a/python/docs/create_api_rst.py
+++ b/python/docs/create_api_rst.py
@@ -105,7 +105,9 @@ def _load_module_members(module_path: str, namespace: str) -> ModuleMembers:
                 else (
                     "enum"
                     if issubclass(type_, Enum)
-                    else "Pydantic" if issubclass(type_, BaseModel) else "Regular"
+                    else "Pydantic"
+                    if issubclass(type_, BaseModel)
+                    else "Regular"
                 )
             )
             if hasattr(type_, "__slots__"):
diff --git a/python/langsmith/pytest/__init__.py b/python/langsmith/pytest/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/python/langsmith/pytest/mark.py b/python/langsmith/pytest/mark.py
new file mode 100644
index 000000000..cbe077f19
--- /dev/null
+++ b/python/langsmith/pytest/mark.py
@@ -0,0 +1,93 @@
+from __future__ import annotations
+
+import inspect
+from typing import Any, Callable, Optional
+
+import pytest
+
+from langsmith import evaluate
+from langsmith.evaluation._runner import TARGET_T
+
+
+def parametrize(
+    dataset_name: str,
+    target_fn: TARGET_T,
+    *,
+    client: Optional[Any] = None,
+    max_concurrency: Optional[int] = None,
+) -> Callable:
+    """Decorator to parametrize a test function with LangSmith dataset examples.
+
+    Args:
+        dataset_name: Name of the LangSmith dataset to use
+        target_fn: Function to test that takes inputs dict and returns outputs dict
+        client: Optional LangSmith client to use
+        max_concurrency: Optional max number of concurrent evaluations
+
+    Returns:
+        Decorated test function that will be parametrized with dataset examples.
+    """
+
+    def decorator(test_fn: Callable) -> Callable:
+        # Verify test function signature
+        sig = inspect.signature(test_fn)
+        required_params = {"inputs", "outputs", "reference_outputs"}
+        if not all(param in sig.parameters for param in required_params):
+            raise ValueError(f"Test function must accept parameters: {required_params}")
+
+        def evaluator(run, example):
+            """Evaluator that runs the test function and returns pass/fail result."""
+            try:
+                results = test_fn(
+                    inputs=example.inputs,
+                    outputs=run.outputs,
+                    reference_outputs=example.outputs,
+                )
+            except AssertionError as e:
+                return {"score": 0.0, "key": "pass", "comment": str(e)}
+            except Exception as e:
+                return {
+                    "score": 0.0,
+                    "key": "pass",
+                    "comment": f"Unexpected error: {str(e)}",
+                }
+            else:
+                if not results:
+                    return {"score": 1.0, "key": "pass"}
+                elif "results" not in results:
+                    results = {"results": results}
+                else:
+                    pass
+                results["results"].append({"score": 1.0, "key": "pass"})
+                return results
+
+        @pytest.mark.parametrize(
+            "example_result",
+            evaluate(
+                target_fn,
+                data=dataset_name,
+                evaluators=[evaluator],
+                client=client,
+                max_concurrency=max_concurrency,
+                experiment_prefix=f"pytest_{test_fn.__name__}",
+                blocking=False,
+            ),
+        )
+        # @functools.wraps(test_fn)
+        def wrapped(example_result):
+            """Wrapped test function that gets parametrized with results."""
+            # Fail the test if the evaluation failed
+            eval_results = example_result["evaluation_results"]["results"]
+            if not eval_results:
+                pytest.fail("No evaluation results")
+
+            pass_result = [r for r in eval_results if r.key == "pass"][0]
+            if not pass_result.score:
+                error = pass_result.comment
+                pytest.fail(
+                    f"Test failed for example {example_result['example'].id}: {error}"
+                )
+
+        return wrapped
+
+    return decorator
diff --git a/python/tests/unit_tests/test_pytest.py b/python/tests/unit_tests/test_pytest.py
new file mode 100644
index 000000000..448600c86
--- /dev/null
+++ b/python/tests/unit_tests/test_pytest.py
@@ -0,0 +1,7 @@
+import langsmith as ls
+
+
+@ls.pytest.mark.parametrize("Sample Dataset 3", (lambda x: x))
+def test_parametrize(inputs, outputs, reference_outputs) -> list:
+    assert inputs == outputs
+    return [{"key": "foo", "value": "bar"}]
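
For context (not part of the diff): a minimal usage sketch mirroring the new unit test, assuming a LangSmith dataset named "My Dataset" exists and API credentials are configured in the environment. `my_target`, `test_my_target`, and the `exact_match` feedback key are illustrative placeholders, not names introduced by this change.

```python
import langsmith as ls


def my_target(inputs: dict) -> dict:
    # Hypothetical target under test: echoes the example inputs back as outputs.
    return inputs


@ls.pytest.mark.parametrize("My Dataset", my_target)
def test_my_target(inputs, outputs, reference_outputs):
    # One pytest case is generated per dataset example; a failing assertion
    # here is recorded as a 0-score "pass" feedback and fails the test.
    assert outputs == reference_outputs
    # Optionally return extra feedback to attach alongside the "pass" result.
    return [{"key": "exact_match", "score": 1.0}]
```

At collection time the decorator calls `evaluate()` with `blocking=False` and feeds each returned `example_result` to `pytest.mark.parametrize`, so every dataset example becomes its own pytest case; the wrapped test then fails whenever the evaluator's `pass` feedback has a falsy score, surfacing the assertion message and example id in the pytest failure.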