From 2066011023fab751f38c5452e059ea415a87630c Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Sat, 13 Apr 2024 00:17:33 -0700 Subject: [PATCH] Support Context Propagation (#599) Client Side: ``` async def the_parent_function(): async with AsyncClient(app=fake_app, base_url="http://localhost:8000") as client: headers = {} if span := get_current_span(): headers.update(span.to_headers()) return await client.post("/fake-route", headers=headers) ``` Server Side: ``` @fake_app.post("/fake-route") async def fake_route(request: Request): with tracing_context(headers=request.headers): fake_function() return {"message": "Fake route response"} ``` If people like, we could add some fun middleware, but probably not necessary --- python/langsmith/_testing.py | 86 +++++++--- python/langsmith/run_helpers.py | 53 ++++-- python/langsmith/run_trees.py | 155 +++++++++++++++++- python/poetry.lock | 58 ++++++- python/pyproject.toml | 4 +- python/tests/evaluation/test_evaluation.py | 7 + python/tests/integration_tests/fake_server.py | 54 ++++++ .../test_context_propagation.py | 59 +++++++ python/tests/unit_tests/test_run_trees.py | 46 ++++++ 9 files changed, 484 insertions(+), 38 deletions(-) create mode 100644 python/tests/integration_tests/fake_server.py create mode 100644 python/tests/integration_tests/test_context_propagation.py diff --git a/python/langsmith/_testing.py b/python/langsmith/_testing.py index d3d44a937..71de426d8 100644 --- a/python/langsmith/_testing.py +++ b/python/langsmith/_testing.py @@ -229,31 +229,29 @@ def unit(*args: Any, **kwargs: Any) -> Callable: " Skipping LangSmith test tracking." ) - if args and callable(args[0]): - func = args[0] - if disable_tracking: - return func - - @functools.wraps(func) - def wrapper(*test_args, **test_kwargs): - _run_test( - func, - *test_args, - **test_kwargs, - langtest_extra=langtest_extra, - ) + def decorator(func: Callable) -> Callable: + if inspect.iscoroutinefunction(func): + + async def async_wrapper(*test_args: Any, **test_kwargs: Any): + if disable_tracking: + return await func(*test_args, **test_kwargs) + await _arun_test( + func, *test_args, **test_kwargs, langtest_extra=langtest_extra + ) - return wrapper + return async_wrapper - def decorator(func): @functools.wraps(func) - def wrapper(*test_args, **test_kwargs): + def wrapper(*test_args: Any, **test_kwargs: Any): if disable_tracking: return func(*test_args, **test_kwargs) _run_test(func, *test_args, **test_kwargs, langtest_extra=langtest_extra) return wrapper + if args and callable(args[0]): + return decorator(args[0]) + return decorator @@ -470,12 +468,9 @@ def _get_test_repr(func: Callable, sig: inspect.Signature) -> str: def _ensure_example( func: Callable, *args: Any, langtest_extra: _UTExtra, **kwargs: Any ) -> Tuple[_LangSmithTestSuite, uuid.UUID]: - # 1. check if the id exists. - # TODOs: Local cache + prefer a peek operation client = langtest_extra["client"] or ls_client.Client() output_keys = langtest_extra["output_keys"] signature = inspect.signature(func) - # 2. Create the example inputs: dict = rh._get_inputs_safe(signature, *args, **kwargs) outputs = {} if output_keys: @@ -492,7 +487,9 @@ def _ensure_example( return test_suite, example_id -def _run_test(func, *test_args, langtest_extra: _UTExtra, **test_kwargs): +def _run_test( + func: Callable, *test_args: Any, langtest_extra: _UTExtra, **test_kwargs: Any +) -> None: test_suite, example_id = _ensure_example( func, *test_args, **test_kwargs, langtest_extra=langtest_extra ) @@ -537,3 +534,52 @@ def _test(): cache_path, ignore_hosts=[test_suite.client.api_url] ): _test() + + +async def _arun_test( + func: Callable, *test_args: Any, langtest_extra: _UTExtra, **test_kwargs: Any +) -> None: + test_suite, example_id = _ensure_example( + func, *test_args, **test_kwargs, langtest_extra=langtest_extra + ) + run_id = uuid.uuid4() + + async def _test(): + try: + func_ = func if rh.is_traceable_function(func) else rh.traceable(func) + await func_( + *test_args, + **test_kwargs, + langsmith_extra={ + "run_id": run_id, + "reference_example_id": example_id, + "project_name": test_suite.name, + }, + ) + except BaseException as e: + test_suite.submit_result(run_id, error=repr(e)) + raise e + try: + test_suite.submit_result(run_id, error=None) + except BaseException as e: + logger.warning(f"Failed to create feedback for run_id {run_id}: {e}") + + cache_path = ( + Path(langtest_extra["cache"]) / f"{test_suite.id}.yaml" + if langtest_extra["cache"] + else None + ) + current_context = rh.get_tracing_context() + metadata = { + **(current_context["metadata"] or {}), + **{ + "experiment": test_suite.experiment.name, + "reference_example_id": str(example_id), + }, + } + with rh.tracing_context( + **{**current_context, "metadata": metadata} + ), ls_utils.with_optional_cache( + cache_path, ignore_hosts=[test_suite.client.api_url] + ): + await _test() diff --git a/python/langsmith/run_helpers.py b/python/langsmith/run_helpers.py index f231f3525..c2d65bb5f 100644 --- a/python/langsmith/run_helpers.py +++ b/python/langsmith/run_helpers.py @@ -55,7 +55,7 @@ def get_current_run_tree() -> Optional[run_trees.RunTree]: def get_tracing_context() -> dict: """Get the current tracing context.""" return { - "parent_run": _PARENT_RUN_TREE.get(), + "parent": _PARENT_RUN_TREE.get(), "project_name": _PROJECT_NAME.get(), "tags": _TAGS.get(), "metadata": _METADATA.get(), @@ -68,14 +68,25 @@ def tracing_context( project_name: Optional[str] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, - parent_run: Optional[run_trees.RunTree] = None, + parent: Optional[Union[run_trees.RunTree, Mapping, str]] = None, + **kwargs: Any, ) -> Generator[None, None, None]: """Set the tracing context for a block of code.""" - parent_run_ = get_run_tree_context() + if kwargs: + # warn + warnings.warn( + f"Unrecognized keyword arguments: {kwargs}.", + DeprecationWarning, + ) + parent_run_ = get_current_run_tree() _PROJECT_NAME.set(project_name) + parent_run = _get_parent_run({"parent": parent or kwargs.get("parent_run")}) + if parent_run is not None: + _PARENT_RUN_TREE.set(parent_run) + tags = sorted(set(tags or []) | set(parent_run.tags or [])) + metadata = {**parent_run.metadata, **(metadata or {})} _TAGS.set(tags) _METADATA.set(metadata) - _PARENT_RUN_TREE.set(parent_run) try: yield finally: @@ -85,6 +96,7 @@ def tracing_context( _PARENT_RUN_TREE.set(parent_run_) +# Alias for backwards compatibility get_run_tree_context = get_current_run_tree @@ -143,7 +155,8 @@ class LangSmithExtra(TypedDict, total=False): reference_example_id: Optional[ls_client.ID_TYPE] run_extra: Optional[Dict] - run_tree: Optional[run_trees.RunTree] + parent: Optional[Union[run_trees.RunTree, str, Mapping]] + run_tree: Optional[run_trees.RunTree] # TODO: Deprecate project_name: Optional[str] metadata: Optional[Dict[str, Any]] tags: Optional[List[str]] @@ -212,6 +225,20 @@ def _collect_extra(extra_outer: dict, langsmith_extra: LangSmithExtra) -> dict: return extra_inner +def _get_parent_run(langsmith_extra: LangSmithExtra) -> Optional[run_trees.RunTree]: + parent = langsmith_extra.get("parent") + if isinstance(parent, run_trees.RunTree): + return parent + if isinstance(parent, dict): + return run_trees.RunTree.from_headers(parent) + if isinstance(parent, str): + return run_trees.RunTree.from_dotted_order(parent) + run_tree = langsmith_extra.get("run_tree") + if run_tree: + return run_tree + return get_current_run_tree() + + def _setup_run( func: Callable, container_input: _ContainerInput, @@ -228,7 +255,7 @@ def _setup_run( run_type = container_input.get("run_type") or "chain" outer_project = _PROJECT_NAME.get() langsmith_extra = langsmith_extra or LangSmithExtra() - parent_run_ = langsmith_extra.get("run_tree") or get_run_tree_context() + parent_run_ = _get_parent_run(langsmith_extra) project_cv = _PROJECT_NAME.get() selected_project = ( project_cv # From parent trace @@ -577,7 +604,7 @@ async def async_wrapper( **kwargs: Any, ) -> Any: """Async version of wrapper function.""" - context_run = get_run_tree_context() + context_run = get_current_run_tree() run_container = _setup_run( func, container_input=container_input, @@ -611,7 +638,7 @@ async def async_wrapper( async def async_generator_wrapper( *args: Any, langsmith_extra: Optional[LangSmithExtra] = None, **kwargs: Any ) -> AsyncGenerator: - context_run = get_run_tree_context() + context_run = get_current_run_tree() run_container = _setup_run( func, container_input=container_input, @@ -683,7 +710,7 @@ def wrapper( **kwargs: Any, ) -> Any: """Create a new run or create_child() if run is passed in kwargs.""" - context_run = get_run_tree_context() + context_run = get_current_run_tree() run_container = _setup_run( func, container_input=container_input, @@ -717,7 +744,7 @@ def wrapper( def generator_wrapper( *args: Any, langsmith_extra: Optional[LangSmithExtra] = None, **kwargs: Any ) -> Any: - context_run = get_run_tree_context() + context_run = get_current_run_tree() run_container = _setup_run( func, container_input=container_input, @@ -808,7 +835,7 @@ def trace( inputs: Optional[Dict] = None, extra: Optional[Dict] = None, project_name: Optional[str] = None, - run_tree: Optional[run_trees.RunTree] = None, + parent: Optional[Union[run_trees.RunTree, str, Mapping]] = None, tags: Optional[List[str]] = None, metadata: Optional[Mapping[str, Any]] = None, client: Optional[ls_client.Client] = None, @@ -825,7 +852,9 @@ def trace( outer_tags = _TAGS.get() outer_metadata = _METADATA.get() outer_project = _PROJECT_NAME.get() or utils.get_tracer_project() - parent_run_ = get_run_tree_context() if run_tree is None else run_tree + parent_run_ = _get_parent_run( + {"parent": parent, "run_tree": kwargs.get("run_tree")} + ) # Merge and set context variables tags_ = sorted(set((tags or []) + (outer_tags or []))) diff --git a/python/langsmith/run_trees.py b/python/langsmith/run_trees.py index 40ce3af1b..03b59ba7c 100644 --- a/python/langsmith/run_trees.py +++ b/python/langsmith/run_trees.py @@ -2,9 +2,10 @@ from __future__ import annotations +import json import logging from datetime import datetime, timezone -from typing import Any, Dict, List, Optional, Sequence, Union, cast +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast from uuid import UUID, uuid4 try: @@ -12,12 +13,17 @@ except ImportError: from pydantic import Field, root_validator, validator +import urllib.parse + from langsmith import schemas as ls_schemas from langsmith import utils -from langsmith.client import ID_TYPE, RUN_TYPE_T, Client +from langsmith.client import ID_TYPE, RUN_TYPE_T, Client, _dumps_json logger = logging.getLogger(__name__) +LANGSMITH_PREFIX = "langsmith-" +LANGSMITH_DOTTED_ORDER = f"{LANGSMITH_PREFIX}trace" + class RunTree(ls_schemas.RunBase): """Run Schema with back-references for posting runs.""" @@ -80,8 +86,8 @@ def ensure_dotted_order(cls, values: dict) -> dict: current_dotted_order = values.get("dotted_order") if current_dotted_order and current_dotted_order.strip(): return values - current_dotted_order = values["start_time"].strftime("%Y%m%dT%H%M%S%fZ") + str( - values["id"] + current_dotted_order = _create_current_dotted_order( + values["start_time"], values["id"] ) if values["parent_run"]: values["dotted_order"] = ( @@ -263,3 +269,144 @@ def wait(self) -> None: def get_url(self) -> str: """Return the URL of the run.""" return self.client.get_run_url(run=self) + + @classmethod + def from_dotted_order( + cls, + dotted_order: str, + **kwargs: Any, + ) -> RunTree: + """Create a new 'child' span from the provided dotted order. + + Returns: + RunTree: The new span. + """ + headers = { + f"{LANGSMITH_DOTTED_ORDER}": dotted_order, + } + return cast(RunTree, cls.from_headers(headers, **kwargs)) + + @classmethod + def from_headers(cls, headers: Dict[str, str], **kwargs: Any) -> Optional[RunTree]: + """Create a new 'parent' span from the provided headers. + + Extracts parent span information from the headers and creates a new span. + Metadata and tags are extracted from the baggage header. + The dotted order and trace id are extracted from the trace header. + + Returns: + Optional[RunTree]: The new span or None if + no parent span information is found. + """ + init_args = kwargs.copy() + + langsmith_trace = headers.get(f"{LANGSMITH_DOTTED_ORDER}") + if not langsmith_trace: + return # type: ignore[return-value] + + parent_dotted_order = langsmith_trace.strip() + parsed_dotted_order = _parse_dotted_order(parent_dotted_order) + trace_id = parsed_dotted_order[0][1] + init_args["trace_id"] = trace_id + init_args["id"] = parsed_dotted_order[-1][1] + init_args["dotted_order"] = parent_dotted_order + # All placeholders. We assume the source process + # handles the life-cycle of the run. + init_args["start_time"] = init_args.get("start_time") or datetime.now( + timezone.utc + ) + init_args["run_type"] = init_args.get("run_type") or "chain" + init_args["name"] = init_args.get("name") or "parent" + + baggage = _Baggage.from_header(headers.get("baggage")) + if baggage.metadata or baggage.tags: + init_args["extra"] = init_args.setdefault("extra", {}) + init_args["extra"]["metadata"] = init_args["extra"].setdefault( + "metadata", {} + ) + metadata = {**baggage.metadata, **init_args["extra"]["metadata"]} + init_args["extra"]["metadata"] = metadata + tags = sorted(set(baggage.tags + init_args.get("tags", []))) + init_args["tags"] = tags + + return RunTree(**init_args) + + def to_headers(self) -> Dict[str, str]: + """Return the RunTree as a dictionary of headers.""" + headers = {} + if self.trace_id: + headers[f"{LANGSMITH_DOTTED_ORDER}"] = self.dotted_order + baggage = _Baggage( + metadata=self.extra.get("metadata", {}), + tags=self.tags, + ) + headers["baggage"] = baggage.to_header() + return headers + + +class _Baggage: + """Baggage header information.""" + + def __init__( + self, + metadata: Optional[Dict[str, str]] = None, + tags: Optional[List[str]] = None, + ): + """Initialize the Baggage object.""" + self.metadata = metadata or {} + self.tags = tags or [] + + @classmethod + def from_header(cls, header_value: Optional[str]) -> _Baggage: + """Create a Baggage object from the given header value.""" + if not header_value: + return cls() + metadata = {} + tags = [] + try: + for item in header_value.split(","): + key, value = item.split("=", 1) + if key == f"{LANGSMITH_PREFIX}metadata": + metadata = json.loads(urllib.parse.unquote(value)) + elif key == f"{LANGSMITH_PREFIX}tags": + tags = urllib.parse.unquote(value).split(",") + except Exception as e: + logger.warning(f"Error parsing baggage header: {e}") + + return cls(metadata=metadata, tags=tags) + + def to_header(self) -> str: + """Return the Baggage object as a header value.""" + items = [] + if self.metadata: + serialized_metadata = _dumps_json(self.metadata) + items.append( + f"{LANGSMITH_PREFIX}metadata={urllib.parse.quote(serialized_metadata)}" + ) + if self.tags: + serialized_tags = ",".join(self.tags) + items.append( + f"{LANGSMITH_PREFIX}tags={urllib.parse.quote(serialized_tags)}" + ) + return ",".join(items) + + +def _parse_dotted_order(dotted_order: str) -> List[Tuple[datetime, UUID]]: + """Parse the dotted order string.""" + parts = dotted_order.split(".") + return [ + (datetime.strptime(part[:-36], "%Y%m%dT%H%M%S%fZ"), UUID(part[-36:])) + for part in parts + ] + + +def _create_current_dotted_order( + start_time: Optional[datetime], run_id: Optional[UUID] +) -> str: + """Create the current dotted order.""" + st = start_time or datetime.now(timezone.utc) + id_ = run_id or uuid4() + return st.strftime("%Y%m%dT%H%M%S%fZ") + str(id_) + + +__all__ = ["RunTree", "RunTree"] diff --git a/python/poetry.lock b/python/poetry.lock index bc6dae9e4..eb25cad70 100644 --- a/python/poetry.lock +++ b/python/poetry.lock @@ -343,6 +343,25 @@ files = [ [package.extras] testing = ["hatch", "pre-commit", "pytest", "tox"] +[[package]] +name = "fastapi" +version = "0.110.1" +description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fastapi-0.110.1-py3-none-any.whl", hash = "sha256:5df913203c482f820d31f48e635e022f8cbfe7350e4830ef05a3163925b1addc"}, + {file = "fastapi-0.110.1.tar.gz", hash = "sha256:6feac43ec359dfe4f45b2c18ec8c94edb8dc2dfc461d417d9e626590c071baad"}, +] + +[package.dependencies] +pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0" +starlette = ">=0.37.2,<0.38.0" +typing-extensions = ">=4.8.0" + +[package.extras] +all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] + [[package]] name = "freezegun" version = "1.4.0" @@ -1137,6 +1156,24 @@ files = [ {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] +[[package]] +name = "starlette" +version = "0.37.2" +description = "The little ASGI library that shines." +optional = false +python-versions = ">=3.8" +files = [ + {file = "starlette-0.37.2-py3-none-any.whl", hash = "sha256:6fe59f29268538e5d0d182f2791a479a0c64638e6935d1c6989e63fb2699c6ee"}, + {file = "starlette-0.37.2.tar.gz", hash = "sha256:9af890290133b79fc3db55474ade20f6220a364a0402e0b556e7cd5e1e093823"}, +] + +[package.dependencies] +anyio = ">=3.4.0,<5" +typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} + +[package.extras] +full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"] + [[package]] name = "tomli" version = "2.0.1" @@ -1310,6 +1347,25 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "uvicorn" +version = "0.29.0" +description = "The lightning-fast ASGI server." +optional = false +python-versions = ">=3.8" +files = [ + {file = "uvicorn-0.29.0-py3-none-any.whl", hash = "sha256:2c2aac7ff4f4365c206fd773a39bf4ebd1047c238f8b8268ad996829323473de"}, + {file = "uvicorn-0.29.0.tar.gz", hash = "sha256:6a69214c0b6a087462412670b3ef21224fa48cae0e452b5883e8e8bdfdd11dd0"}, +] + +[package.dependencies] +click = ">=7.0" +h11 = ">=0.8" +typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} + +[package.extras] +standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"] + [[package]] name = "vcrpy" version = "6.0.1" @@ -1558,4 +1614,4 @@ vcr = [] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "3724956de2f14ef761403e17016a212a7de110b726184cf2e685425cb60261d1" +content-hash = "a347c2aca058e8c1b9fbcdc7043a5ed53ba03485759adce207d249e4495e05c6" diff --git a/python/pyproject.toml b/python/pyproject.toml index 8c9fe1693..6253b0428 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langsmith" -version = "0.1.46" +version = "0.1.47" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." authors = ["LangChain "] license = "MIT" @@ -50,6 +50,8 @@ pytest-cov = "^4.1.0" dataclasses-json = "^0.6.4" types-tqdm = "^4.66.0.20240106" vcrpy = "^6.0.1" +fastapi = "^0.110.1" +uvicorn = "^0.29.0" [tool.poetry.group.lint.dependencies] openai = "^1.10" diff --git a/python/tests/evaluation/test_evaluation.py b/python/tests/evaluation/test_evaluation.py index dff223ac7..80d1d2a6c 100644 --- a/python/tests/evaluation/test_evaluation.py +++ b/python/tests/evaluation/test_evaluation.py @@ -95,3 +95,10 @@ def expected_output(): @unit(output_keys=["expected_output"]) def test_bar(some_input: str, expected_output: str): expect(some_input).to_contain(expected_output) + + +@unit +async def test_baz(): + await asyncio.sleep(0.1) + expect(3 + 4).to_equal(7) + return 7 diff --git a/python/tests/integration_tests/fake_server.py b/python/tests/integration_tests/fake_server.py new file mode 100644 index 000000000..93850d9da --- /dev/null +++ b/python/tests/integration_tests/fake_server.py @@ -0,0 +1,54 @@ +from fastapi import FastAPI, Request + +from langsmith import traceable +from langsmith.run_helpers import get_current_run_tree, trace, tracing_context + +fake_app = FastAPI() + + +@traceable +def fake_function(): + span = get_current_run_tree() + assert span is not None + parent_run = span.parent_run + assert parent_run is not None + assert "did-propagate" in span.tags or [] + assert span.metadata["some-cool-value"] == 42 + return "Fake function response" + + +@traceable +def fake_function_two(foo: str): + span = get_current_run_tree() + assert span is not None + parent_run = span.parent_run + assert parent_run is not None + assert "did-propagate" in (span.tags or []) + assert span.metadata["some-cool-value"] == 42 + return "Fake function response" + + +@traceable +def fake_function_three(foo: str): + span = get_current_run_tree() + assert span is not None + parent_run = span.parent_run + assert parent_run is not None + assert "did-propagate" in (span.tags or []) + assert span.metadata["some-cool-value"] == 42 + return "Fake function response" + + +@fake_app.post("/fake-route") +async def fake_route(request: Request): + with trace( + "Trace", + project_name="Definitely-not-your-grandpas-project", + parent=request.headers, + ): + fake_function() + fake_function_two("foo", langsmith_extra={"parent": request.headers}) + + with tracing_context(parent=request.headers): + fake_function_three("foo") + return {"message": "Fake route response"} diff --git a/python/tests/integration_tests/test_context_propagation.py b/python/tests/integration_tests/test_context_propagation.py new file mode 100644 index 000000000..32cd1f74d --- /dev/null +++ b/python/tests/integration_tests/test_context_propagation.py @@ -0,0 +1,59 @@ +import asyncio + +import pytest +from httpx import AsyncClient +from uvicorn import Config, Server + +from langsmith import traceable +from langsmith.run_helpers import get_current_run_tree +from tests.integration_tests.fake_server import fake_app + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + + +@pytest.fixture(scope="module") +async def fake_server(): + config = Config(app=fake_app, loop="asyncio", port=8000, log_level="info") + server = Server(config=config) + + asyncio.create_task(server.serve()) + await asyncio.sleep(0.1) + + yield + try: + await server.shutdown() + except RuntimeError: + pass + + +@traceable +async def the_parent_function(): + async with AsyncClient(app=fake_app, base_url="http://localhost:8000") as client: + headers = {} + if span := get_current_run_tree(): + headers.update(span.to_headers()) + response = await client.post("/fake-route", headers=headers) + assert response.status_code == 200 + return response.json() + + +@traceable +async def the_root_function(foo: str): + return await the_parent_function() + + +@pytest.mark.asyncio +async def test_tracing_fake_server(fake_server): + result = await the_root_function( + "test input", + langsmith_extra={ + "metadata": {"some-cool-value": 42}, + "tags": ["did-propagate"], + }, + ) + assert result["message"] == "Fake route response" diff --git a/python/tests/unit_tests/test_run_trees.py b/python/tests/unit_tests/test_run_trees.py index fa88e623c..92c398b2d 100644 --- a/python/tests/unit_tests/test_run_trees.py +++ b/python/tests/unit_tests/test_run_trees.py @@ -1,5 +1,9 @@ from concurrent.futures import ThreadPoolExecutor +from datetime import datetime from unittest.mock import MagicMock +from uuid import UUID + +import pytest from langsmith import run_trees from langsmith.client import Client @@ -13,3 +17,45 @@ def test_run_tree_accepts_tpe() -> None: client=mock_client, executor=ThreadPoolExecutor(), ) + + +@pytest.mark.parametrize( + "inputs, expected", + [ + ( + "20240412T202937370454Z152ce25c-064e-4742-bf36-8bb0389f8805.20240412T202937627763Zfe8b541f-e75a-4ee6-b92d-732710897194.20240412T202937708023Z625b30ed-2fbb-4387-81b1-cb5d6221e5b4.20240412T202937775748Z448dc09f-ad54-4475-b3a4-fa43018ca621.20240412T202937981350Z4cd59ea4-491e-4ed9-923f-48cd93e03755.20240412T202938078862Zcd168cf7-ee72-48c2-8ec0-50ab09821973.20240412T202938152278Z32481c1a-b83c-4b53-a52e-1ea893ffba51", + [ + ( + datetime(2024, 4, 12, 20, 29, 37, 370454), + UUID("152ce25c-064e-4742-bf36-8bb0389f8805"), + ), + ( + datetime(2024, 4, 12, 20, 29, 37, 627763), + UUID("fe8b541f-e75a-4ee6-b92d-732710897194"), + ), + ( + datetime(2024, 4, 12, 20, 29, 37, 708023), + UUID("625b30ed-2fbb-4387-81b1-cb5d6221e5b4"), + ), + ( + datetime(2024, 4, 12, 20, 29, 37, 775748), + UUID("448dc09f-ad54-4475-b3a4-fa43018ca621"), + ), + ( + datetime(2024, 4, 12, 20, 29, 37, 981350), + UUID("4cd59ea4-491e-4ed9-923f-48cd93e03755"), + ), + ( + datetime(2024, 4, 12, 20, 29, 38, 78862), + UUID("cd168cf7-ee72-48c2-8ec0-50ab09821973"), + ), + ( + datetime(2024, 4, 12, 20, 29, 38, 152278), + UUID("32481c1a-b83c-4b53-a52e-1ea893ffba51"), + ), + ], + ), + ], +) +def test_parse_dotted_order(inputs, expected): + assert run_trees._parse_dotted_order(inputs) == expected