diff --git a/python/langsmith/run_helpers.py b/python/langsmith/run_helpers.py index 473082b58..c19d30491 100644 --- a/python/langsmith/run_helpers.py +++ b/python/langsmith/run_helpers.py @@ -74,20 +74,23 @@ def tracing_context( tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, parent_run: Optional[run_trees.Span] = None, - headers: Optional[Dict[str, str]] = None, + headers: Optional[Union[Dict[str, str], Any]] = None, ) -> Generator[None, None, None]: """Set the tracing context for a block of code.""" parent_run_ = get_current_span() _PROJECT_NAME.set(project_name) - _TAGS.set(tags) - _METADATA.set(metadata) if parent_run is not None: _PARENT_RUN_TREE.set(parent_run) elif headers is not None: parent_run = run_trees.Span.from_headers(headers) + if parent_run: + tags = sorted(set(tags or []) | set(parent_run.tags or [])) + metadata = {**parent_run.metadata, **(metadata or {})} _PARENT_RUN_TREE.set(parent_run) else: _PARENT_RUN_TREE.set(None) + _TAGS.set(tags) + _METADATA.set(metadata) try: yield finally: @@ -96,6 +99,7 @@ def tracing_context( _METADATA.set(None) _PARENT_RUN_TREE.set(parent_run_) + # Alias for backwards compatibility get_current_run_tree = get_current_span get_run_tree_context = get_current_run_tree diff --git a/python/langsmith/run_trees.py b/python/langsmith/run_trees.py index 56c2a75fa..957bbe5a5 100644 --- a/python/langsmith/run_trees.py +++ b/python/langsmith/run_trees.py @@ -284,7 +284,7 @@ def from_headers(cls, headers: Dict[str, str], **kwargs: Any) -> Optional[Span]: langsmith_trace = headers.get(f"{LANGSMITH_PREFIX}trace") if not langsmith_trace: - return None + return # type: ignore[return-value] parent_dotted_order = langsmith_trace.strip() parsed_dotted_order = _parse_dotted_order(parent_dotted_order) @@ -350,11 +350,11 @@ def from_header(cls, header_value: Optional[str]) -> _Baggage: try: for item in header_value.split(","): key, value = item.split("=", 1) - if key == f"{LANGSMITH_PREFIX}-metadata": + if key == f"{LANGSMITH_PREFIX}metadata": metadata = json.loads(urllib.parse.unquote(value)) - elif key == f"{LANGSMITH_PREFIX}-tags": + elif key == f"{LANGSMITH_PREFIX}tags": tags = urllib.parse.unquote(value).split(",") - elif key == f"{LANGSMITH_PREFIX}-id": + elif key == f"{LANGSMITH_PREFIX}id": id_ = UUID(value) except Exception as e: logger.warning(f"Error parsing baggage header: {e}") @@ -383,7 +383,7 @@ def _parse_dotted_order(dotted_order: str) -> List[Tuple[datetime, UUID]]: """Parse the dotted order string.""" parts = dotted_order.split(".") return [ - (datetime.strptime(part[:23], "%Y%m%dT%H%M%S%fZ"), UUID(part[23:])) + (datetime.strptime(part[:-36], "%Y%m%dT%H%M%S%fZ"), UUID(part[-36:])) for part in parts ] diff --git a/python/poetry.lock b/python/poetry.lock index c9d9f3cbd..75a67b5ba 100644 --- a/python/poetry.lock +++ b/python/poetry.lock @@ -343,6 +343,25 @@ files = [ [package.extras] testing = ["hatch", "pre-commit", "pytest", "tox"] +[[package]] +name = "fastapi" +version = "0.110.1" +description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fastapi-0.110.1-py3-none-any.whl", hash = "sha256:5df913203c482f820d31f48e635e022f8cbfe7350e4830ef05a3163925b1addc"}, + {file = "fastapi-0.110.1.tar.gz", hash = "sha256:6feac43ec359dfe4f45b2c18ec8c94edb8dc2dfc461d417d9e626590c071baad"}, +] + +[package.dependencies] +pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0" +starlette = ">=0.37.2,<0.38.0" +typing-extensions = ">=4.8.0" + +[package.extras] +all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] + [[package]] name = "freezegun" version = "1.4.0" @@ -1126,6 +1145,24 @@ files = [ {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] +[[package]] +name = "starlette" +version = "0.37.2" +description = "The little ASGI library that shines." +optional = false +python-versions = ">=3.8" +files = [ + {file = "starlette-0.37.2-py3-none-any.whl", hash = "sha256:6fe59f29268538e5d0d182f2791a479a0c64638e6935d1c6989e63fb2699c6ee"}, + {file = "starlette-0.37.2.tar.gz", hash = "sha256:9af890290133b79fc3db55474ade20f6220a364a0402e0b556e7cd5e1e093823"}, +] + +[package.dependencies] +anyio = ">=3.4.0,<5" +typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} + +[package.extras] +full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"] + [[package]] name = "tomli" version = "2.0.1" @@ -1299,6 +1336,25 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "uvicorn" +version = "0.29.0" +description = "The lightning-fast ASGI server." +optional = false +python-versions = ">=3.8" +files = [ + {file = "uvicorn-0.29.0-py3-none-any.whl", hash = "sha256:2c2aac7ff4f4365c206fd773a39bf4ebd1047c238f8b8268ad996829323473de"}, + {file = "uvicorn-0.29.0.tar.gz", hash = "sha256:6a69214c0b6a087462412670b3ef21224fa48cae0e452b5883e8e8bdfdd11dd0"}, +] + +[package.dependencies] +click = ">=7.0" +h11 = ">=0.8" +typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} + +[package.extras] +standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"] + [[package]] name = "vcrpy" version = "6.0.1" @@ -1547,4 +1603,4 @@ vcr = [] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "3724956de2f14ef761403e17016a212a7de110b726184cf2e685425cb60261d1" +content-hash = "a347c2aca058e8c1b9fbcdc7043a5ed53ba03485759adce207d249e4495e05c6" diff --git a/python/pyproject.toml b/python/pyproject.toml index 8c9fe1693..23a64e575 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -50,6 +50,8 @@ pytest-cov = "^4.1.0" dataclasses-json = "^0.6.4" types-tqdm = "^4.66.0.20240106" vcrpy = "^6.0.1" +fastapi = "^0.110.1" +uvicorn = "^0.29.0" [tool.poetry.group.lint.dependencies] openai = "^1.10" diff --git a/python/tests/integration_tests/fake_server.py b/python/tests/integration_tests/fake_server.py new file mode 100644 index 000000000..a0a411e97 --- /dev/null +++ b/python/tests/integration_tests/fake_server.py @@ -0,0 +1,24 @@ +from fastapi import FastAPI, Request + +from langsmith import traceable +from langsmith.run_helpers import get_current_span, tracing_context + +fake_app = FastAPI() + + +@traceable +def fake_function(): + span = get_current_span() + assert span is not None + parent_run = span.parent_run + assert parent_run is not None + assert "did-propagate" in span.tags + assert span.metadata["some-cool-value"] == 42 + return "Fake function response" + + +@fake_app.post("/fake-route") +async def fake_route(request: Request): + with tracing_context(headers=request.headers): + fake_function() + return {"message": "Fake route response"} diff --git a/python/tests/integration_tests/test_context_propagation.py b/python/tests/integration_tests/test_context_propagation.py new file mode 100644 index 000000000..5166a92ee --- /dev/null +++ b/python/tests/integration_tests/test_context_propagation.py @@ -0,0 +1,47 @@ +import asyncio + +import pytest +from httpx import AsyncClient + +from langsmith import traceable +from langsmith.run_helpers import get_current_span +from tests.integration_tests.fake_server import fake_app + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + + +@pytest.fixture(scope="module") +async def fake_server(): + from uvicorn import Config, Server + + config = Config(app=fake_app, loop="asyncio", port=8000, log_level="info") + server = Server(config=config) + + asyncio.create_task(server.serve()) + await asyncio.sleep(0.1) + + yield + + await server.shutdown() + + +@traceable +async def the_parent_function(): + async with AsyncClient(app=fake_app, base_url="http://localhost:8000") as client: + headers = {} + if span := get_current_span(): + headers.update(span.to_headers()) + return await client.post("/fake-route", headers=headers) + + +@pytest.mark.asyncio +async def test_tracing_fake_server(fake_server): + response = await the_parent_function( + langsmith_extra={"metadata": {"some-cool-value": 42}, "tags": ["did-propagate"]} + ) + assert response.status_code == 200 diff --git a/python/tests/unit_tests/test_run_trees.py b/python/tests/unit_tests/test_run_trees.py index fa88e623c..92c398b2d 100644 --- a/python/tests/unit_tests/test_run_trees.py +++ b/python/tests/unit_tests/test_run_trees.py @@ -1,5 +1,9 @@ from concurrent.futures import ThreadPoolExecutor +from datetime import datetime from unittest.mock import MagicMock +from uuid import UUID + +import pytest from langsmith import run_trees from langsmith.client import Client @@ -13,3 +17,45 @@ def test_run_tree_accepts_tpe() -> None: client=mock_client, executor=ThreadPoolExecutor(), ) + + +@pytest.mark.parametrize( + "inputs, expected", + [ + ( + "20240412T202937370454Z152ce25c-064e-4742-bf36-8bb0389f8805.20240412T202937627763Zfe8b541f-e75a-4ee6-b92d-732710897194.20240412T202937708023Z625b30ed-2fbb-4387-81b1-cb5d6221e5b4.20240412T202937775748Z448dc09f-ad54-4475-b3a4-fa43018ca621.20240412T202937981350Z4cd59ea4-491e-4ed9-923f-48cd93e03755.20240412T202938078862Zcd168cf7-ee72-48c2-8ec0-50ab09821973.20240412T202938152278Z32481c1a-b83c-4b53-a52e-1ea893ffba51", + [ + ( + datetime(2024, 4, 12, 20, 29, 37, 370454), + UUID("152ce25c-064e-4742-bf36-8bb0389f8805"), + ), + ( + datetime(2024, 4, 12, 20, 29, 37, 627763), + UUID("fe8b541f-e75a-4ee6-b92d-732710897194"), + ), + ( + datetime(2024, 4, 12, 20, 29, 37, 708023), + UUID("625b30ed-2fbb-4387-81b1-cb5d6221e5b4"), + ), + ( + datetime(2024, 4, 12, 20, 29, 37, 775748), + UUID("448dc09f-ad54-4475-b3a4-fa43018ca621"), + ), + ( + datetime(2024, 4, 12, 20, 29, 37, 981350), + UUID("4cd59ea4-491e-4ed9-923f-48cd93e03755"), + ), + ( + datetime(2024, 4, 12, 20, 29, 38, 78862), + UUID("cd168cf7-ee72-48c2-8ec0-50ab09821973"), + ), + ( + datetime(2024, 4, 12, 20, 29, 38, 152278), + UUID("32481c1a-b83c-4b53-a52e-1ea893ffba51"), + ), + ], + ), + ], +) +def test_parse_dotted_order(inputs, expected): + assert run_trees._parse_dotted_order(inputs) == expected