Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test caching #589

Merged
merged 30 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions .github/ISSUE_TEMPLATE/bug-report.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve LangSmith. To report a _security_ issue, please instead use the security option below.
description: "Submit a bug report to help us improve LangSmith. To report a _security_ issue, please instead use the security option below."
labels: ["01 Bug Report"]
body:
- type: markdown
Expand All @@ -15,15 +15,15 @@ body:
label: Tracing Method
description: "Select whether you are tracing using LangChain or some other method:"
options:
- label: "With LangChain"
- label: "SDK/Client"
- label: "REST API"
- label: "With LangChain"
- label: "Other"

- type: checkboxes
id: runtime-language
attributes:
label: Runtime Language
label: Language
description: ""
options:
- label: "Python"
Expand All @@ -33,11 +33,10 @@ body:
- type: checkboxes
id: platform-environment
attributes:
label: LangSmith Platform Environment
label: Host
description: "Indicate whether you are connected to the hosted LangSmith platform or running locally."
options:
- label: "Hosted (https://api.smith.langchain.com)"
- label: "Local (http://localhost:1984)"
- label: "Self-hosted"
- label: "Other"

Expand All @@ -48,7 +47,7 @@ body:
description: Please share any other system info with us. You can view this by running `langsmith env` in your terminal.
placeholder: LangSmith SDK version, client runtime information,
validations:
required: false
required: true

- type: textarea
id: reproduction
Expand Down
3 changes: 2 additions & 1 deletion .github/actions/python-integration-tests/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ runs:
- name: Install dependencies
run: |
poetry install --with dev
poetry run pip install -U langchain langchain_anthropic tiktoken rapidfuzz
poetry run pip install -U langchain langchain_anthropic tiktoken rapidfuzz vcrpy
shell: bash
working-directory: python

Expand All @@ -52,6 +52,7 @@ runs:
LANGCHAIN_API_KEY: ${{ inputs.langchain-api-key }}
OPENAI_API_KEY: ${{ inputs.openai-api-key }}
ANTHROPIC_API_KEY: ${{ inputs.anthropic-api-key }}
LANGCHAIN_TEST_CACHE: "tests/cassettes"
run: make doctest
shell: bash
working-directory: python
114 changes: 88 additions & 26 deletions python/langsmith/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import threading
import uuid
import warnings
from pathlib import Path
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, overload

from typing_extensions import TypedDict
Expand Down Expand Up @@ -69,6 +70,15 @@ def unit(*args: Any, **kwargs: Any) -> Callable:
Returns:
Callable: The decorated test function.

Environment:
- LANGCHAIN_TEST_CACHE: If set, API calls will be cached to disk to
save time and costs during testing. Recommended to commit the
cache files to your repository for faster CI/CD runs.
Requires the 'langsmith[vcr]' package to be installed.
- LANGCHAIN_TEST_TRACKING: Set this variable to the path of a directory
to enable caching of test results. This is useful for re-running tests
without re-executing the code. Requires the 'langsmith[vcr]' package.

Example:
For basic usage, simply decorate a test function with `@unit`:

Expand All @@ -81,6 +91,30 @@ def unit(*args: Any, **kwargs: Any) -> Callable:
or `wrap_*` functions) will be traced within the test case for
improved visibility and debugging.

>>> from langsmith import traceable
>>> @traceable
... def generate_numbers():
... return 3, 4

>>> @unit
... def test_nested():
... # Traced code will be included in the test case
... a, b = generate_numbers()
... assert a + b == 7

LLM calls are expensive! Cache requests by setting
`LANGCHAIN_TEST_CACHE=path/to/cache`. Check in these files to speed up
CI/CD pipelines, so your results only change when your prompt or requested
model changes.

Note that this will require that you install langsmith with the `vcr` extra:

`pip install -U "langsmith[vcr]"`

Caching is faster if you install libyaml. See
https://vcrpy.readthedocs.io/en/latest/installation.html#speed for more details.

>>> os.environ["LANGCHAIN_TEST_CACHE"] = "tests/cassettes"
>>> import openai
>>> from langsmith.wrappers import wrap_openai
>>> @unit
Expand Down Expand Up @@ -145,6 +179,7 @@ def unit(*args: Any, **kwargs: Any) -> Callable:

To run these tests, use the pytest CLI. Or directly run the test functions.
>>> test_addition()
>>> test_nested()
>>> test_with_fixture("Some input")
>>> test_with_expected_output("Some input", "Some")
>>> test_multiplication()
Expand All @@ -156,11 +191,21 @@ def unit(*args: Any, **kwargs: Any) -> Callable:
output_keys=kwargs.pop("output_keys", None),
client=kwargs.pop("client", None),
test_suite_name=kwargs.pop("test_suite_name", None),
cache=ls_utils.get_cache_dir(kwargs.pop("cache", None)),
)
if kwargs:
warnings.warn(f"Unexpected keyword arguments: {kwargs.keys()}")
disable_tracking = os.environ.get("LANGCHAIN_TEST_TRACKING") == "false"
if disable_tracking:
warnings.warn(
"LANGCHAIN_TEST_TRACKING is set to 'false'."
" Skipping LangSmith test tracking."
)

if args and callable(args[0]):
func = args[0]
if disable_tracking:
return func

@functools.wraps(func)
def wrapper(*test_args, **test_kwargs):
Expand All @@ -176,6 +221,8 @@ def wrapper(*test_args, **test_kwargs):
def decorator(func):
@functools.wraps(func)
def wrapper(*test_args, **test_kwargs):
if disable_tracking:
return func(*test_args, **test_kwargs)
_run_test(func, *test_args, **test_kwargs, langtest_extra=langtest_extra)

return wrapper
Expand All @@ -188,7 +235,7 @@ def wrapper(*test_args, **test_kwargs):

def _get_experiment_name() -> str:
# TODO Make more easily configurable
prefix = ls_utils.get_tracer_project(False) or "TestSuite"
prefix = ls_utils.get_tracer_project(False) or "TestSuiteResult"
hinthornw marked this conversation as resolved.
Show resolved Hide resolved
name = f"{prefix}:{uuid.uuid4().hex[:8]}"
return name

Expand All @@ -199,13 +246,13 @@ def _get_test_suite_name() -> str:
if test_suite_name:
return test_suite_name
if __package__:
return __package__
return __package__ + " Test Suite"
git_info = ls_env.get_git_info()
if git_info:
if git_info["remote_url"]:
repo_name = git_info["remote_url"].split("/")[-1].split(".")[0]
if repo_name:
return repo_name
return repo_name + " Test Suite"
raise ValueError("Please set the LANGCHAIN_TEST_SUITE environment variable.")


Expand All @@ -221,16 +268,19 @@ def _get_test_suite(client: ls_client.Client) -> ls_schemas.Dataset:
def _start_experiment(
client: ls_client.Client,
test_suite: ls_schemas.Dataset,
) -> ls_schemas.TracerSessionResult:
) -> ls_schemas.TracerSession:
experiment_name = _get_experiment_name()
return client.create_project(experiment_name, reference_dataset_id=test_suite.id)


def _get_id(func: Callable, inputs: dict) -> uuid.UUID:
try:
file_path = str(Path(inspect.getfile(func)).relative_to(Path.cwd()))
except ValueError:
# Fall back to module name if file path is not available
file_path = func.__module__
input_json = json.dumps(inputs, sort_keys=True)
identifier = f"{func.__module__}.{func.__name__}_{input_json}"

# Generate a UUID based on the identifier
identifier = f"{file_path}::{func.__name__}{input_json}"
return uuid.uuid5(uuid.NAMESPACE_DNS, identifier)


Expand All @@ -253,7 +303,7 @@ class _LangSmithTestSuite:
def __init__(
self,
client: Optional[ls_client.Client],
experiment: ls_schemas.TracerSessionResult,
experiment: ls_schemas.TracerSession,
dataset: ls_schemas.Dataset,
):
self.client = client or ls_client.Client()
Expand Down Expand Up @@ -338,6 +388,7 @@ class _UTExtra(TypedDict, total=False):
id: Optional[uuid.UUID]
output_keys: Optional[Sequence[str]]
test_suite_name: Optional[str]
cache: Optional[str]


def _ensure_example(
Expand Down Expand Up @@ -367,21 +418,32 @@ def _run_test(func, *test_args, langtest_extra: _UTExtra, **test_kwargs):
)
run_id = uuid.uuid4()

try:
func_ = func if rh.is_traceable_function(func) else rh.traceable(func)
func_(
*test_args,
**test_kwargs,
langsmith_extra={
"run_id": run_id,
"reference_example_id": example_id,
"project_name": test_suite.name,
},
)
except BaseException as e:
test_suite.submit_result(run_id, error=repr(e))
raise e
try:
test_suite.submit_result(run_id, error=None)
except BaseException as e:
logger.warning(f"Failed to create feedback for run_id {run_id}: {e}")
def _test():
try:
func_ = func if rh.is_traceable_function(func) else rh.traceable(func)
func_(
*test_args,
**test_kwargs,
langsmith_extra={
"run_id": run_id,
"reference_example_id": example_id,
"project_name": test_suite.name,
},
)
except BaseException as e:
test_suite.submit_result(run_id, error=repr(e))
raise e
try:
test_suite.submit_result(run_id, error=None)
except BaseException as e:
logger.warning(f"Failed to create feedback for run_id {run_id}: {e}")

cache_path = (
Path(langtest_extra["cache"]) / f"{test_suite.id}.yaml"
if langtest_extra["cache"]
else None
)
with ls_utils.with_optional_cache(
cache_path, ignore_hosts=[test_suite.client.api_url]
):
_test()
18 changes: 12 additions & 6 deletions python/langsmith/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,9 @@ def info(self) -> ls_schemas.LangSmithInfo:
ls_utils.raise_for_status_with_text(response)
self._info = ls_schemas.LangSmithInfo(**response.json())
except BaseException as e:
logger.warning(f"Failed to get info from {self.api_url}: {repr(e)}")
logger.warning(
f"Failed to get info from {self.api_url}: {repr(e)}",
)
self._info = ls_schemas.LangSmithInfo()
return self._info

Expand Down Expand Up @@ -810,7 +812,11 @@ def _get_paginated_list(
params_["limit"] = params_.get("limit", 100)
while True:
params_["offset"] = offset
response = self.request_with_retries("GET", path, params=params_)
response = self.request_with_retries(
"GET",
path,
params=params_,
)
items = response.json()

if not items:
Expand Down Expand Up @@ -1012,13 +1018,13 @@ def _run_transform(
dict: The transformed run object as a dictionary.
"""
if hasattr(run, "dict") and callable(getattr(run, "dict")):
run_create = run.dict() # type: ignore
run_create: dict = run.dict() # type: ignore
else:
run_create = cast(dict, run)
if "id" not in run_create:
run_create["id"] = uuid.uuid4()
elif isinstance(run["id"], str):
run["id"] = uuid.UUID(run["id"])
elif isinstance(run_create["id"], str):
run_create["id"] = uuid.UUID(run_create["id"])
if "inputs" in run_create and run_create["inputs"] is not None:
run_create["inputs"] = self._hide_run_inputs(run_create["inputs"])
if "outputs" in run_create and run_create["outputs"] is not None:
Expand Down Expand Up @@ -3161,7 +3167,7 @@ def _resolve_run_id(
if isinstance(run, (str, uuid.UUID)):
run_ = self.read_run(run, load_child_runs=load_child_runs)
else:
run_ = run
run_ = cast(ls_schemas.Run, run)
return run_

def _resolve_example_id(
Expand Down
Loading
Loading