diff --git a/python/langsmith/__init__.py b/python/langsmith/__init__.py
index cd3d6c4c9..52a9251c2 100644
--- a/python/langsmith/__init__.py
+++ b/python/langsmith/__init__.py
@@ -1,8 +1,9 @@
 """LangSmith Client."""
-from typing import TYPE_CHECKING, Any, List
+from typing import TYPE_CHECKING, Any
 
 if TYPE_CHECKING:
+    from langsmith._testing import unit
     from langsmith.client import Client
     from langsmith.evaluation.evaluator import EvaluationResult, RunEvaluator
     from langsmith.run_helpers import trace, traceable
@@ -41,6 +42,10 @@ def __getattr__(name: str) -> Any:
         from langsmith.run_helpers import traceable
 
         return traceable
+    elif name == "unit":
+        from langsmith._testing import unit
+
+        return unit
 
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
 
@@ -53,8 +58,5 @@ def __getattr__(name: str) -> Any:
     "RunEvaluator",
     "traceable",
     "trace",
+    "unit",
 ]
-
-
-def __dir__() -> List[str]:
-    return __all__
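
Aside on the `__init__.py` change: it relies on module-level `__getattr__` (PEP 562) so that `langsmith.unit` is imported lazily on first access, while the `TYPE_CHECKING` block keeps the symbol visible to type checkers at no runtime cost. A minimal sketch of the same pattern for a hypothetical package (`lazy_pkg` and `slow_helper` are illustrative names, not part of this PR):

# lazy_pkg/__init__.py -- hypothetical package using PEP 562 lazy exports
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Evaluated by type checkers only; never imported at runtime.
    from lazy_pkg._heavy import slow_helper

__all__ = ["slow_helper"]


def __getattr__(name: str) -> Any:
    # Called only when normal module attribute lookup fails.
    if name == "slow_helper":
        from lazy_pkg._heavy import slow_helper

        return slow_helper
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

One side effect worth noting: the PR also deletes `__dir__`, so `dir(langsmith)` will presumably no longer list the lazy exports until they have been accessed.
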
diff --git a/python/langsmith/_testing.py b/python/langsmith/_testing.py
new file mode 100644
index 000000000..acbb7593c
--- /dev/null
+++ b/python/langsmith/_testing.py
@@ -0,0 +1,387 @@
+from __future__ import annotations
+
+import atexit
+import concurrent.futures
+import datetime
+import functools
+import inspect
+import json
+import logging
+import os
+import threading
+import uuid
+import warnings
+from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, overload
+
+from typing_extensions import TypedDict
+
+from langsmith import client as ls_client
+from langsmith import env as ls_env
+from langsmith import run_helpers as rh
+from langsmith import schemas as ls_schemas
+from langsmith import utils as ls_utils
+
+logger = logging.getLogger(__name__)
+
+
+T = TypeVar("T")
+U = TypeVar("U")
+
+
+@overload
+def unit(
+    func: Callable,
+) -> Callable: ...
+
+
+@overload
+def unit(
+    *,
+    id: Optional[uuid.UUID] = None,
+    output_keys: Optional[Sequence[str]] = None,
+    client: Optional[ls_client.Client] = None,
+    test_suite_name: Optional[str] = None,
+) -> Callable[[Callable], Callable]: ...
+
+
+def unit(*args: Any, **kwargs: Any) -> Callable:
+    """Create a unit test case in LangSmith.
+
+    This decorator marks a function as a test case for LangSmith. It ensures
+    that the necessary example data is created and associated with the test
+    function. The decorated function will be executed as a test case, and the
+    results will be recorded and reported by LangSmith.
+
+    Args:
+        - id (Optional[uuid.UUID]): A unique identifier for the test case. If not
+            provided, an ID will be generated based on the test function's module
+            and name.
+        - output_keys (Optional[Sequence[str]]): A list of keys to be considered as
+            the output keys for the test case. These keys will be extracted from the
+            test function's inputs and stored as the expected outputs.
+        - client (Optional[ls_client.Client]): An instance of the LangSmith client
+            to be used for communication with the LangSmith service. If not provided,
+            a default client will be used.
+        - test_suite_name (Optional[str]): The name of the test suite to which the
+            test case belongs. If not provided, the test suite name will be determined
+            based on the environment or the package name.
+
+    Returns:
+        Callable: The decorated test function.
+
+    Example:
+        For basic usage, simply decorate a test function with `@unit`:
+
+        >>> @unit
+        ... def test_addition():
+        ...     assert 3 + 4 == 7
+
+
+        Any code that is traced (such as code traced with `@traceable`
+        or the `wrap_*` functions) will be traced within the test case for
+        improved visibility and debugging.
+
+        >>> import openai
+        >>> from langsmith.wrappers import wrap_openai
+        >>> @unit
+        ... def test_openai_says_hello():
+        ...     # Traced code will be included in the test case
+        ...     oai_client = wrap_openai(openai.Client())
+        ...     response = oai_client.chat.completions.create(
+        ...         model="gpt-3.5-turbo",
+        ...         messages=[
+        ...             {"role": "system", "content": "You are a helpful assistant."},
+        ...             {"role": "user", "content": "Say hello!"},
+        ...         ],
+        ...     )
+        ...     assert "hello" in response.choices[0].message.content.lower()
+
+        The `@unit` decorator works natively with pytest fixtures.
+        Fixture values will populate the "inputs" of the corresponding
+        example in LangSmith.
+
+        >>> import pytest
+        >>> @pytest.fixture
+        ... def some_input():
+        ...     return "Some input"
+        >>>
+        >>> @unit
+        ... def test_with_fixture(some_input: str):
+        ...     assert "input" in some_input
+        >>>
+
+        You can still use pytest.mark.parametrize() as usual to run multiple
+        test cases using the same test function.
+
+        >>> @unit(output_keys=["expected"])
+        ... @pytest.mark.parametrize(
+        ...     "a, b, expected",
+        ...     [
+        ...         (1, 2, 3),
+        ...         (3, 4, 7),
+        ...     ],
+        ... )
+        ... def test_addition_with_multiple_inputs(a: int, b: int, expected: int):
+        ...     assert a + b == expected
+
+        By default, each test case will be assigned a consistent, unique identifier
+        based on the function name and module. You can also provide a custom identifier
+        using the `id` argument:
+
+        >>> @unit(id=uuid.uuid4())
+        ... def test_multiplication():
+        ...     assert 3 * 4 == 12
+
+        By default, all unit test inputs are saved as "inputs" to a dataset.
+        You can specify the `output_keys` argument to persist those keys
+        within the dataset's "outputs" fields.
+
+        >>> @pytest.fixture
+        ... def expected_output():
+        ...     return "input"
+        >>> @unit(output_keys=["expected_output"])
+        ... def test_with_expected_output(some_input: str, expected_output: str):
+        ...     assert expected_output in some_input
+
+
+        To run these tests, use the pytest CLI, or run the test functions directly.
+
+        >>> test_addition()
+        >>> test_with_fixture("Some input")
+        >>> test_with_expected_output("Some input", "Some")
+        >>> test_multiplication()
+        >>> test_openai_says_hello()
+        >>> test_addition_with_multiple_inputs(1, 2, 3)
+    """
+    langtest_extra = _UTExtra(
+        id=kwargs.pop("id", None),
+        output_keys=kwargs.pop("output_keys", None),
+        client=kwargs.pop("client", None),
+        test_suite_name=kwargs.pop("test_suite_name", None),
+    )
+    if kwargs:
+        warnings.warn(f"Unexpected keyword arguments: {kwargs.keys()}")
+    if args and callable(args[0]):
+        func = args[0]
+
+        @functools.wraps(func)
+        def wrapper(*test_args, **test_kwargs):
+            _run_test(
+                func,
+                *test_args,
+                **test_kwargs,
+                langtest_extra=langtest_extra,
+            )
+
+        return wrapper
+
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*test_args, **test_kwargs):
+            _run_test(func, *test_args, **test_kwargs, langtest_extra=langtest_extra)
+
+        return wrapper
+
+    return decorator
+
+
+## Private functions
+
+
+def _get_experiment_name() -> str:
+    # TODO: Make this more easily configurable
+    prefix = ls_utils.get_tracer_project(False) or "TestSuite"
+    name = f"{prefix}:{uuid.uuid4().hex[:8]}"
+    return name
+
+
+def _get_test_suite_name() -> str:
+    # TODO: This naming stuff is inelegant
+    test_suite_name = os.environ.get("LANGCHAIN_TEST_SUITE")
+    if test_suite_name:
+        return test_suite_name
+    if __package__:
+        return __package__
+    git_info = ls_env.get_git_info()
+    if git_info:
+        if git_info["remote_url"]:
+            repo_name = git_info["remote_url"].split("/")[-1].split(".")[0]
+            if repo_name:
+                return repo_name
+    raise ValueError("Please set the LANGCHAIN_TEST_SUITE environment variable.")
+
+
+def _get_test_suite(client: ls_client.Client) -> ls_schemas.Dataset:
+    test_suite_name = _get_test_suite_name()
+
+    if client.has_dataset(dataset_name=test_suite_name):
+        return client.read_dataset(dataset_name=test_suite_name)
+    else:
+        return client.create_dataset(dataset_name=test_suite_name)
+
+
+def _start_experiment(
+    client: ls_client.Client,
+    test_suite: ls_schemas.Dataset,
+) -> ls_schemas.TracerSessionResult:
+    experiment_name = _get_experiment_name()
+    return client.create_project(experiment_name, reference_dataset_id=test_suite.id)
+
+
+def _get_id(func: Callable, inputs: dict) -> uuid.UUID:
+    input_json = json.dumps(inputs, sort_keys=True)
+    identifier = f"{func.__module__}.{func.__name__}_{input_json}"
+
+    # Generate a UUID based on the identifier
+    return uuid.uuid5(uuid.NAMESPACE_DNS, identifier)
+
+
+def _end_tests(
+    test_suite: _LangSmithTestSuite,
+):
+    git_info = ls_env.get_git_info() or {}
+    test_suite.client.update_project(
+        test_suite.experiment_id,
+        end_time=datetime.datetime.now(datetime.timezone.utc),
+        metadata={**git_info, "dataset_version": test_suite.get_version()},
+    )
+    test_suite.wait()
+
+
+class _LangSmithTestSuite:
+    _instance = None
+    _lock = threading.RLock()
+
+    def __init__(
+        self,
+        client: Optional[ls_client.Client],
+        experiment: ls_schemas.TracerSessionResult,
+        dataset: ls_schemas.Dataset,
+    ):
+        self.client = client or ls_client.Client()
+        self._experiment = experiment
+        self._dataset = dataset
+        self._version: Optional[datetime.datetime] = None
+        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+        atexit.register(_end_tests, self)
+
+    @property
+    def id(self):
+        return self._dataset.id
+
+    @property
+    def experiment_id(self):
+        return self._experiment.id
+
+    @classmethod
+    def get_singleton(cls, client: Optional[ls_client.Client]) -> _LangSmithTestSuite:
+        client = client or ls_client.Client()
+        with cls._lock:
+            if not cls._instance:
+                test_suite = _get_test_suite(client)
+                experiment = _start_experiment(client, test_suite)
+                cls._instance = cls(client, experiment, test_suite)
+            return cls._instance
+
+    @property
+    def name(self):
+        return self._experiment.name
+
+    def update_version(self, version: datetime.datetime) -> None:
+        if self._version is None or version > self._version:
+            self._version = version
+
+    def get_version(self) -> Optional[datetime.datetime]:
+        return self._version
+
+    def submit_result(self, run_id: uuid.UUID, error: Optional[str] = None) -> None:
+        self._executor.submit(self._submit_result, run_id, error)
+
+    def _submit_result(self, run_id: uuid.UUID, error: Optional[str] = None) -> None:
+        if error:
+            self.client.create_feedback(
+                run_id, key="pass", score=0, comment=f"Error: {repr(error)}"
+            )
+        else:
+            self.client.create_feedback(
+                run_id,
+                key="pass",
+                score=1,
+            )
+
+    def sync_example(self, example_id: uuid.UUID, inputs: dict, outputs: dict) -> None:
+        self._executor.submit(self._sync_example, example_id, inputs, outputs)
+
+    def _sync_example(self, example_id: uuid.UUID, inputs: dict, outputs: dict) -> None:
+        try:
+            example = self.client.read_example(example_id=example_id)
+            if inputs != example.inputs or outputs != example.outputs:
+                self.client.update_example(
+                    example_id=example.id,
+                    inputs=inputs,
+                    outputs=outputs,
+                )
+        except ls_utils.LangSmithNotFoundError:
+            example = self.client.create_example(
+                example_id=example_id,
+                inputs=inputs,
+                outputs=outputs,
+                dataset_id=self.id,
+            )
+        if example.modified_at:
+            self.update_version(example.modified_at)
+
+    def wait(self):
+        self._executor.shutdown(wait=True)
+
+
+class _UTExtra(TypedDict, total=False):
+    client: Optional[ls_client.Client]
+    id: Optional[uuid.UUID]
+    output_keys: Optional[Sequence[str]]
+    test_suite_name: Optional[str]
+
+
+def _ensure_example(
+    func: Callable, *args: Any, langtest_extra: _UTExtra, **kwargs: Any
+) -> Tuple[_LangSmithTestSuite, uuid.UUID]:
+    # 1. Check whether the example ID already exists.
+    # TODOs: Local cache + prefer a peek operation
+    client = langtest_extra["client"] or ls_client.Client()
+    output_keys = langtest_extra["output_keys"]
+    signature = inspect.signature(func)
+    # 2. Create the example
+    inputs: dict = rh._get_inputs_safe(signature, *args, **kwargs)
+    example_id = langtest_extra["id"] or _get_id(func, inputs)
+    outputs = {}
+    if output_keys:
+        for k in output_keys:
+            outputs[k] = inputs.pop(k, None)
+    # TODO: Support multiple test suites
+    test_suite = _LangSmithTestSuite.get_singleton(client)
+    test_suite.sync_example(example_id, inputs, outputs)
+    return test_suite, example_id
+
+
+def _run_test(func, *test_args, langtest_extra: _UTExtra, **test_kwargs):
+    test_suite, example_id = _ensure_example(
+        func, *test_args, **test_kwargs, langtest_extra=langtest_extra
+    )
+    run_id = uuid.uuid4()
+
+    try:
+        func_ = func if rh.is_traceable_function(func) else rh.traceable(func)
+        func_(
+            *test_args,
+            **test_kwargs,
+            langsmith_extra={
+                "run_id": run_id,
+                "reference_example_id": example_id,
+                "project_name": test_suite.name,
+            },
+        )
+    except BaseException as e:
+        test_suite.submit_result(run_id, error=repr(e))
+        raise e
+    try:
+        test_suite.submit_result(run_id, error=None)
+    except BaseException as e:
+        logger.warning(f"Failed to create feedback for run_id {run_id}: {e}")
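
Aside on `unit()`: it accepts both bare `@unit` and parameterized `@unit(...)` usage by checking whether its first positional argument is callable. A self-contained sketch of that dual-form decorator pattern (`tag` and `label` are illustrative names only, not this PR's API):

import functools
from typing import Any, Callable


def tag(*args: Any, **kwargs: Any) -> Callable:
    label = kwargs.pop("label", "default")

    def decorate(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*a: Any, **kw: Any) -> Any:
            print(f"[{label}] running {func.__name__}")
            return func(*a, **kw)

        return wrapper

    # Bare form: @tag passes the function itself as args[0].
    if args and callable(args[0]):
        return decorate(args[0])
    # Parameterized form: @tag(label=...) returns the real decorator.
    return decorate


@tag
def one() -> int:
    return 1


@tag(label="math")
def two() -> int:
    return 2


assert one() == 1 and two() == 2

Stable test identity comes from `_get_id` above: `uuid.uuid5(uuid.NAMESPACE_DNS, ...)` hashes the module-qualified function name plus the JSON-serialized inputs (with `sort_keys=True` for a canonical key order), so re-running the same test upserts the same dataset example instead of creating duplicates.
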
diff --git a/python/langsmith/client.py b/python/langsmith/client.py
index 4835234d0..2908f1ecc 100644
--- a/python/langsmith/client.py
+++ b/python/langsmith/client.py
@@ -3079,7 +3079,7 @@ def list_examples(
 
     def update_example(
         self,
-        example_id: str,
+        example_id: ID_TYPE,
         *,
         inputs: Optional[Dict[str, Any]] = None,
         outputs: Optional[Mapping[str, Any]] = None,
diff --git a/python/langsmith/run_helpers.py b/python/langsmith/run_helpers.py
index 17194ca7e..f231f3525 100644
--- a/python/langsmith/run_helpers.py
+++ b/python/langsmith/run_helpers.py
@@ -128,6 +128,16 @@ def _get_inputs(
     return arguments
 
 
+def _get_inputs_safe(
+    signature: inspect.Signature, *args: Any, **kwargs: Any
+) -> Dict[str, Any]:
+    try:
+        return _get_inputs(signature, *args, **kwargs)
+    except BaseException as e:
+        logger.debug(f"Failed to get inputs for {signature}: {e}")
+        return {"args": args, "kwargs": kwargs}
+
+
 class LangSmithExtra(TypedDict, total=False):
     """Any additional info to be injected into the run dynamically."""
 
@@ -259,11 +269,7 @@ def _setup_run(
     metadata_.update(metadata or {})
     metadata_["ls_method"] = "traceable"
     extra_inner["metadata"] = metadata_
-    try:
-        inputs = _get_inputs(signature, *args, **kwargs)
-    except TypeError as e:
-        logger.debug(f"Failed to infer inputs for {name_}: {e}")
-        inputs = {"args": args, "kwargs": kwargs}
+    inputs = _get_inputs_safe(signature, *args, **kwargs)
     process_inputs = container_input.get("process_inputs")
     if process_inputs:
         try:
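
Aside on `_get_inputs_safe`: it guarantees that input capture can never take down the traced call itself. A standalone sketch of the same bind-with-fallback idea, using `inspect.Signature.bind` directly in place of the library's private `_get_inputs` (`bind_inputs_safe` and `greet` are illustrative names):

import inspect
import logging
from typing import Any, Dict

logger = logging.getLogger(__name__)


def bind_inputs_safe(
    signature: inspect.Signature, *args: Any, **kwargs: Any
) -> Dict[str, Any]:
    try:
        bound = signature.bind(*args, **kwargs)
        bound.apply_defaults()
        return dict(bound.arguments)
    except BaseException as e:
        # Degrade to a raw capture rather than raising.
        logger.debug(f"Failed to bind inputs for {signature}: {e}")
        return {"args": args, "kwargs": kwargs}


def greet(name: str, punctuation: str = "!") -> str:
    return f"Hello, {name}{punctuation}"


sig = inspect.signature(greet)
assert bind_inputs_safe(sig, "Ada") == {"name": "Ada", "punctuation": "!"}
assert bind_inputs_safe(sig, 1, 2, 3) == {"args": (1, 2, 3), "kwargs": {}}

Note the diff also widens the caught exception from `TypeError` to `BaseException`, so any binding failure (not just arity mismatches) falls back to the raw `args`/`kwargs` capture.
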
diff --git a/python/langsmith/run_trees.py b/python/langsmith/run_trees.py
index a2d49dadb..15dc57c93 100644
--- a/python/langsmith/run_trees.py
+++ b/python/langsmith/run_trees.py
@@ -106,6 +106,19 @@ def add_metadata(self, metadata: Dict[str, Any]) -> None:
         metadata_: dict = self.extra.setdefault("metadata", {})
         metadata_.update(metadata)
 
+    def add_outputs(self, outputs: Dict[str, Any]) -> None:
+        """Upsert the given outputs into the run.
+
+        Args:
+            outputs (Dict[str, Any]): A dictionary containing the outputs to be added.
+
+        Returns:
+            None
+        """
+        if self.outputs is None:
+            self.outputs = {}
+        self.outputs.update(outputs)
+
     def add_event(
         self,
         events: Union[
@@ -154,7 +167,10 @@ def end(
         """Set the end time of the run and all child runs."""
         self.end_time = end_time or datetime.now(timezone.utc)
         if outputs is not None:
-            self.outputs = outputs
+            if not self.outputs:
+                self.outputs = outputs
+            else:
+                self.outputs.update(outputs)
         if error is not None:
             self.error = error
         if events is not None:
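
Aside on the run_trees.py change: both `add_outputs` and the updated `end()` now merge into existing outputs rather than overwriting them. A small self-contained sketch of the upsert semantics (`MiniRun` is a toy stand-in, not the real `RunTree`):

from typing import Any, Dict, Optional


class MiniRun:
    """Toy stand-in mirroring the merge-not-clobber behavior of RunTree."""

    def __init__(self) -> None:
        self.outputs: Optional[Dict[str, Any]] = None

    def add_outputs(self, outputs: Dict[str, Any]) -> None:
        # Upsert: initialize on first call, merge on later calls.
        if self.outputs is None:
            self.outputs = {}
        self.outputs.update(outputs)

    def end(self, outputs: Optional[Dict[str, Any]] = None) -> None:
        if outputs is not None:
            if not self.outputs:
                self.outputs = outputs
            else:
                self.outputs.update(outputs)


run = MiniRun()
run.end(outputs={"answer": 3})
run.end(outputs={"latency_ms": 12})  # second call merges instead of replacing
assert run.outputs == {"answer": 3, "latency_ms": 12}

One subtlety a reviewer might flag: when outputs were previously empty, `end()` stores the caller's dict by reference rather than copying it, so a later `add_outputs` call would mutate the caller's object.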