diff --git a/python/langsmith/run_helpers.py b/python/langsmith/run_helpers.py index 9dee2191d..f9e3f9bb5 100644 --- a/python/langsmith/run_helpers.py +++ b/python/langsmith/run_helpers.py @@ -698,6 +698,7 @@ def trace( f"{sorted(kwargs.keys())}.", DeprecationWarning, ) + old_ctx = get_tracing_context() outer_tags = _TAGS.get() outer_metadata = _METADATA.get() outer_project = _PROJECT_NAME.get() or utils.get_tracer_project() @@ -739,6 +740,7 @@ def trace( new_run.post() _PARENT_RUN_TREE.set(new_run) _PROJECT_NAME.set(project_name_) + try: yield new_run except (Exception, KeyboardInterrupt, BaseException) as e: @@ -751,10 +753,8 @@ def trace( new_run.patch() raise e finally: - _PARENT_RUN_TREE.set(parent_run_) - _PROJECT_NAME.set(outer_project) - _TAGS.set(outer_tags) - _METADATA.set(outer_metadata) + # Reset the old context + _set_tracing_context(old_ctx) new_run.patch() diff --git a/python/langsmith/utils.py b/python/langsmith/utils.py index 72607b8b1..2c0152e0f 100644 --- a/python/langsmith/utils.py +++ b/python/langsmith/utils.py @@ -67,11 +67,19 @@ class LangSmithConnectionError(LangSmithError): def tracing_is_enabled() -> bool: """Return True if tracing is enabled.""" - from langsmith.run_helpers import get_tracing_context + from langsmith.run_helpers import get_current_run_tree, get_tracing_context tc = get_tracing_context() + # You can manually override the environment using context vars. + # Check that first. + # Doing this before checking the run tree lets us + # disable a branch within a trace. if tc["enabled"] is not None: return tc["enabled"] + # Next check if we're mid-trace + if get_current_run_tree(): + return True + # Finally, check the global environment var_result = get_env_var("TRACING_V2", default=get_env_var("TRACING", default="")) return var_result == "true" diff --git a/python/pyproject.toml b/python/pyproject.toml index 9e7061a3a..9b8436ecd 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langsmith" -version = "0.1.74" +version = "0.1.75" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." authors = ["LangChain "] license = "MIT" diff --git a/python/tests/unit_tests/test_client.py b/python/tests/unit_tests/test_client.py index 5e07ac147..f08e17864 100644 --- a/python/tests/unit_tests/test_client.py +++ b/python/tests/unit_tests/test_client.py @@ -343,7 +343,7 @@ def test_create_run_mutate() -> None: trace_id=id_, dotted_order=run_dict["dotted_order"], ) - for _ in range(7): + for _ in range(10): time.sleep(0.1) # Give the background thread time to stop payloads = [ json.loads(call[2]["data"]) diff --git a/python/tests/unit_tests/test_utils.py b/python/tests/unit_tests/test_utils.py index 29e4b4720..8fa74992a 100644 --- a/python/tests/unit_tests/test_utils.py +++ b/python/tests/unit_tests/test_utils.py @@ -7,7 +7,7 @@ from datetime import datetime from enum import Enum from typing import Any, NamedTuple, Optional -from unittest.mock import patch +from unittest.mock import MagicMock, patch import attr import dataclasses_json @@ -15,6 +15,7 @@ from pydantic import BaseModel import langsmith.utils as ls_utils +from langsmith import Client, traceable from langsmith.run_helpers import tracing_context @@ -87,7 +88,9 @@ def test_correct_get_tracer_project(self): def test_tracing_enabled(): - with patch.dict("os.environ", {"LANGCHAIN_TRACING_V2": "false"}): + with patch.dict( + "os.environ", {"LANGCHAIN_TRACING_V2": "false", "LANGSMITH_TRACING": "false"} + ): assert not ls_utils.tracing_is_enabled() with tracing_context(enabled=True): assert ls_utils.tracing_is_enabled() @@ -97,9 +100,39 @@ def test_tracing_enabled(): assert not ls_utils.tracing_is_enabled() assert not ls_utils.tracing_is_enabled() + @traceable + def child_function(): + assert ls_utils.tracing_is_enabled() + return 1 + + @traceable + def untraced_child_function(): + assert not ls_utils.tracing_is_enabled() + return 1 + + @traceable + def parent_function(): + with patch.dict( + "os.environ", + {"LANGCHAIN_TRACING_V2": "false", "LANGSMITH_TRACING": "false"}, + ): + assert ls_utils.tracing_is_enabled() + child_function() + with tracing_context(enabled=False): + assert not ls_utils.tracing_is_enabled() + return untraced_child_function() + + with patch.dict( + "os.environ", {"LANGCHAIN_TRACING_V2": "true", "LANGSMITH_TRACING": "true"} + ): + mock_client = MagicMock(spec=Client) + parent_function(langsmith_extra={"client": mock_client}) + def test_tracing_disabled(): - with patch.dict("os.environ", {"LANGCHAIN_TRACING_V2": "true"}): + with patch.dict( + "os.environ", {"LANGCHAIN_TRACING_V2": "true", "LANGSMITH_TRACING": "true"} + ): assert ls_utils.tracing_is_enabled() with tracing_context(enabled=False): assert not ls_utils.tracing_is_enabled()