Skip to content

Commit

Permalink
Still support 3.9
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw committed Apr 18, 2024
1 parent 9db45db commit b61307c
Showing 1 changed file with 39 additions and 26 deletions.
65 changes: 39 additions & 26 deletions python/langsmith/run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,38 @@
_PROJECT_NAME = contextvars.ContextVar[Optional[str]]("_PROJECT_NAME", default=None)
_TAGS = contextvars.ContextVar[Optional[List[str]]]("_TAGS", default=None)
_METADATA = contextvars.ContextVar[Optional[Dict[str, Any]]]("_METADATA", default=None)
_CONTEXT_KEYS: Dict[str, contextvars.ContextVar] = {
"parent": _PARENT_RUN_TREE,
"project_name": _PROJECT_NAME,
"tags": _TAGS,
"metadata": _METADATA,
}


def get_current_run_tree() -> Optional[run_trees.RunTree]:
"""Get the current run tree."""
return _PARENT_RUN_TREE.get()


def get_tracing_context() -> dict:
def get_tracing_context(
context: Optional[contextvars.Context] = None,
) -> Dict[str, Any]:
"""Get the current tracing context."""
return {
"parent": _PARENT_RUN_TREE.get(),
"project_name": _PROJECT_NAME.get(),
"tags": _TAGS.get(),
"metadata": _METADATA.get(),
}
if context is None:
return {
"parent": _PARENT_RUN_TREE.get(),
"project_name": _PROJECT_NAME.get(),
"tags": _TAGS.get(),
"metadata": _METADATA.get(),
}
return {k: context.get(v) for k, v in _CONTEXT_KEYS.items()}


def _set_tracing_context(context: Dict[str, Any]):
"""Set the tracing context."""
for k, v in context.items():
var = _CONTEXT_KEYS[k]
var.set(v)


@contextlib.contextmanager
Expand All @@ -81,22 +98,24 @@ def tracing_context(
f"Unrecognized keyword arguments: {kwargs}.",
DeprecationWarning,
)
parent_run_ = get_current_run_tree()
_PROJECT_NAME.set(project_name)
current_context = get_tracing_context()
parent_run = _get_parent_run({"parent": parent or kwargs.get("parent_run")})
if parent_run is not None:
_PARENT_RUN_TREE.set(parent_run)
tags = sorted(set(tags or []) | set(parent_run.tags or []))
metadata = {**parent_run.metadata, **(metadata or {})}
_TAGS.set(tags)
_METADATA.set(metadata)

_set_tracing_context(
{
"parent": parent_run,
"project_name": project_name,
"tags": tags,
"metadata": metadata,
}
)
try:
yield
finally:
_PROJECT_NAME.set(None)
_TAGS.set(None)
_METADATA.set(None)
_PARENT_RUN_TREE.set(parent_run_)
_set_tracing_context(current_context)


# Alias for backwards compatibility
Expand Down Expand Up @@ -402,7 +421,10 @@ async def async_wrapper(
fr_coro, context=run_container["context"]
)
else:
function_result = await fr_coro
# Python < 3.11
copied_context = get_tracing_context(run_container["context"])
with tracing_context(**copied_context):
function_result = await fr_coro
except Exception as e:
stacktrace = traceback.format_exc()
_container_end(run_container, error=stacktrace)
Expand All @@ -416,7 +438,6 @@ async def async_wrapper(
async def async_generator_wrapper(
*args: Any, langsmith_extra: Optional[LangSmithExtra] = None, **kwargs: Any
) -> AsyncGenerator:
parent_run_tree = get_current_run_tree()
run_container = _setup_run(
func,
container_input=container_input,
Expand Down Expand Up @@ -475,8 +496,6 @@ async def async_generator_wrapper(
stacktrace = traceback.format_exc()
_container_end(run_container, error=stacktrace)
raise e
finally:
_unset_context_vars(run_container, parent_run_tree)
if results:
if reduce_fn:
try:
Expand All @@ -497,7 +516,6 @@ def wrapper(
**kwargs: Any,
) -> Any:
"""Create a new run or create_child() if run is passed in kwargs."""
parent_run_tree = get_current_run_tree()
run_container = _setup_run(
func,
container_input=container_input,
Expand All @@ -521,16 +539,13 @@ def wrapper(
stacktrace = traceback.format_exc()
_container_end(run_container, error=stacktrace)
raise e
finally:
_unset_context_vars(run_container, parent_run_tree)
_container_end(run_container, outputs=function_result)
return function_result

@functools.wraps(func)
def generator_wrapper(
*args: Any, langsmith_extra: Optional[LangSmithExtra] = None, **kwargs: Any
) -> Any:
parent_run_tree = get_current_run_tree()
run_container = _setup_run(
func,
container_input=container_input,
Expand Down Expand Up @@ -575,8 +590,6 @@ def generator_wrapper(
stacktrace = traceback.format_exc()
_container_end(run_container, error=stacktrace)
raise e
finally:
_unset_context_vars(run_container, parent_run_tree)
if results:
if reduce_fn:
try:
Expand Down

0 comments on commit b61307c

Please sign in to comment.