From b61307cb7db37c698a49d1d262d8fd16f2de8474 Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: Thu, 18 Apr 2024 15:20:27 -0700 Subject: [PATCH] Still support 3.9 --- python/langsmith/run_helpers.py | 65 ++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 26 deletions(-) diff --git a/python/langsmith/run_helpers.py b/python/langsmith/run_helpers.py index 28b354d20..6cd5e4c41 100644 --- a/python/langsmith/run_helpers.py +++ b/python/langsmith/run_helpers.py @@ -48,6 +48,12 @@ _PROJECT_NAME = contextvars.ContextVar[Optional[str]]("_PROJECT_NAME", default=None) _TAGS = contextvars.ContextVar[Optional[List[str]]]("_TAGS", default=None) _METADATA = contextvars.ContextVar[Optional[Dict[str, Any]]]("_METADATA", default=None) +_CONTEXT_KEYS: Dict[str, contextvars.ContextVar] = { + "parent": _PARENT_RUN_TREE, + "project_name": _PROJECT_NAME, + "tags": _TAGS, + "metadata": _METADATA, +} def get_current_run_tree() -> Optional[run_trees.RunTree]: @@ -55,14 +61,25 @@ def get_current_run_tree() -> Optional[run_trees.RunTree]: return _PARENT_RUN_TREE.get() -def get_tracing_context() -> dict: +def get_tracing_context( + context: Optional[contextvars.Context] = None, +) -> Dict[str, Any]: """Get the current tracing context.""" - return { - "parent": _PARENT_RUN_TREE.get(), - "project_name": _PROJECT_NAME.get(), - "tags": _TAGS.get(), - "metadata": _METADATA.get(), - } + if context is None: + return { + "parent": _PARENT_RUN_TREE.get(), + "project_name": _PROJECT_NAME.get(), + "tags": _TAGS.get(), + "metadata": _METADATA.get(), + } + return {k: context.get(v) for k, v in _CONTEXT_KEYS.items()} + + +def _set_tracing_context(context: Dict[str, Any]): + """Set the tracing context.""" + for k, v in context.items(): + var = _CONTEXT_KEYS[k] + var.set(v) @contextlib.contextmanager @@ -81,22 +98,24 @@ def tracing_context( f"Unrecognized keyword arguments: {kwargs}.", DeprecationWarning, ) - parent_run_ = get_current_run_tree() - _PROJECT_NAME.set(project_name) + current_context = get_tracing_context() parent_run = _get_parent_run({"parent": parent or kwargs.get("parent_run")}) if parent_run is not None: - _PARENT_RUN_TREE.set(parent_run) tags = sorted(set(tags or []) | set(parent_run.tags or [])) metadata = {**parent_run.metadata, **(metadata or {})} - _TAGS.set(tags) - _METADATA.set(metadata) + + _set_tracing_context( + { + "parent": parent_run, + "project_name": project_name, + "tags": tags, + "metadata": metadata, + } + ) try: yield finally: - _PROJECT_NAME.set(None) - _TAGS.set(None) - _METADATA.set(None) - _PARENT_RUN_TREE.set(parent_run_) + _set_tracing_context(current_context) # Alias for backwards compatibility @@ -402,7 +421,10 @@ async def async_wrapper( fr_coro, context=run_container["context"] ) else: - function_result = await fr_coro + # Python < 3.11 + copied_context = get_tracing_context(run_container["context"]) + with tracing_context(**copied_context): + function_result = await fr_coro except Exception as e: stacktrace = traceback.format_exc() _container_end(run_container, error=stacktrace) @@ -416,7 +438,6 @@ async def async_wrapper( async def async_generator_wrapper( *args: Any, langsmith_extra: Optional[LangSmithExtra] = None, **kwargs: Any ) -> AsyncGenerator: - parent_run_tree = get_current_run_tree() run_container = _setup_run( func, container_input=container_input, @@ -475,8 +496,6 @@ async def async_generator_wrapper( stacktrace = traceback.format_exc() _container_end(run_container, error=stacktrace) raise e - finally: - _unset_context_vars(run_container, parent_run_tree) if results: if reduce_fn: try: @@ -497,7 +516,6 @@ def wrapper( **kwargs: Any, ) -> Any: """Create a new run or create_child() if run is passed in kwargs.""" - parent_run_tree = get_current_run_tree() run_container = _setup_run( func, container_input=container_input, @@ -521,8 +539,6 @@ def wrapper( stacktrace = traceback.format_exc() _container_end(run_container, error=stacktrace) raise e - finally: - _unset_context_vars(run_container, parent_run_tree) _container_end(run_container, outputs=function_result) return function_result @@ -530,7 +546,6 @@ def wrapper( def generator_wrapper( *args: Any, langsmith_extra: Optional[LangSmithExtra] = None, **kwargs: Any ) -> Any: - parent_run_tree = get_current_run_tree() run_container = _setup_run( func, container_input=container_input, @@ -575,8 +590,6 @@ def generator_wrapper( stacktrace = traceback.format_exc() _container_end(run_container, error=stacktrace) raise e - finally: - _unset_context_vars(run_container, parent_run_tree) if results: if reduce_fn: try: