Skip to content

Commit

Permalink
Support Context Propagation (#599)
Browse files Browse the repository at this point in the history
Client Side:

```
async def the_parent_function():
    async with AsyncClient(app=fake_app, base_url="http://localhost:8000") as client:
        headers = {}
        if span := get_current_span():
            headers.update(span.to_headers())
        return await client.post("/fake-route", headers=headers)

```

Server Side:

```
@fake_app.post("/fake-route")
async def fake_route(request: Request):
    with tracing_context(headers=request.headers):
        fake_function()
    return {"message": "Fake route response"}

```


If people like, we could add some fun middleware, but probably not
necessary
  • Loading branch information
hinthornw authored Apr 13, 2024
1 parent b84766e commit 2066011
Show file tree
Hide file tree
Showing 9 changed files with 484 additions and 38 deletions.
86 changes: 66 additions & 20 deletions python/langsmith/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,31 +229,29 @@ def unit(*args: Any, **kwargs: Any) -> Callable:
" Skipping LangSmith test tracking."
)

if args and callable(args[0]):
func = args[0]
if disable_tracking:
return func

@functools.wraps(func)
def wrapper(*test_args, **test_kwargs):
_run_test(
func,
*test_args,
**test_kwargs,
langtest_extra=langtest_extra,
)
def decorator(func: Callable) -> Callable:
if inspect.iscoroutinefunction(func):

async def async_wrapper(*test_args: Any, **test_kwargs: Any):
if disable_tracking:
return await func(*test_args, **test_kwargs)
await _arun_test(
func, *test_args, **test_kwargs, langtest_extra=langtest_extra
)

return wrapper
return async_wrapper

def decorator(func):
@functools.wraps(func)
def wrapper(*test_args, **test_kwargs):
def wrapper(*test_args: Any, **test_kwargs: Any):
if disable_tracking:
return func(*test_args, **test_kwargs)
_run_test(func, *test_args, **test_kwargs, langtest_extra=langtest_extra)

return wrapper

if args and callable(args[0]):
return decorator(args[0])

return decorator


Expand Down Expand Up @@ -470,12 +468,9 @@ def _get_test_repr(func: Callable, sig: inspect.Signature) -> str:
def _ensure_example(
func: Callable, *args: Any, langtest_extra: _UTExtra, **kwargs: Any
) -> Tuple[_LangSmithTestSuite, uuid.UUID]:
# 1. check if the id exists.
# TODOs: Local cache + prefer a peek operation
client = langtest_extra["client"] or ls_client.Client()
output_keys = langtest_extra["output_keys"]
signature = inspect.signature(func)
# 2. Create the example
inputs: dict = rh._get_inputs_safe(signature, *args, **kwargs)
outputs = {}
if output_keys:
Expand All @@ -492,7 +487,9 @@ def _ensure_example(
return test_suite, example_id


def _run_test(func, *test_args, langtest_extra: _UTExtra, **test_kwargs):
def _run_test(
func: Callable, *test_args: Any, langtest_extra: _UTExtra, **test_kwargs: Any
) -> None:
test_suite, example_id = _ensure_example(
func, *test_args, **test_kwargs, langtest_extra=langtest_extra
)
Expand Down Expand Up @@ -537,3 +534,52 @@ def _test():
cache_path, ignore_hosts=[test_suite.client.api_url]
):
_test()


async def _arun_test(
func: Callable, *test_args: Any, langtest_extra: _UTExtra, **test_kwargs: Any
) -> None:
test_suite, example_id = _ensure_example(
func, *test_args, **test_kwargs, langtest_extra=langtest_extra
)
run_id = uuid.uuid4()

async def _test():
try:
func_ = func if rh.is_traceable_function(func) else rh.traceable(func)
await func_(
*test_args,
**test_kwargs,
langsmith_extra={
"run_id": run_id,
"reference_example_id": example_id,
"project_name": test_suite.name,
},
)
except BaseException as e:
test_suite.submit_result(run_id, error=repr(e))
raise e
try:
test_suite.submit_result(run_id, error=None)
except BaseException as e:
logger.warning(f"Failed to create feedback for run_id {run_id}: {e}")

cache_path = (
Path(langtest_extra["cache"]) / f"{test_suite.id}.yaml"
if langtest_extra["cache"]
else None
)
current_context = rh.get_tracing_context()
metadata = {
**(current_context["metadata"] or {}),
**{
"experiment": test_suite.experiment.name,
"reference_example_id": str(example_id),
},
}
with rh.tracing_context(
**{**current_context, "metadata": metadata}
), ls_utils.with_optional_cache(
cache_path, ignore_hosts=[test_suite.client.api_url]
):
await _test()
53 changes: 41 additions & 12 deletions python/langsmith/run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_current_run_tree() -> Optional[run_trees.RunTree]:
def get_tracing_context() -> dict:
"""Get the current tracing context."""
return {
"parent_run": _PARENT_RUN_TREE.get(),
"parent": _PARENT_RUN_TREE.get(),
"project_name": _PROJECT_NAME.get(),
"tags": _TAGS.get(),
"metadata": _METADATA.get(),
Expand All @@ -68,14 +68,25 @@ def tracing_context(
project_name: Optional[str] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
parent_run: Optional[run_trees.RunTree] = None,
parent: Optional[Union[run_trees.RunTree, Mapping, str]] = None,
**kwargs: Any,
) -> Generator[None, None, None]:
"""Set the tracing context for a block of code."""
parent_run_ = get_run_tree_context()
if kwargs:
# warn
warnings.warn(
f"Unrecognized keyword arguments: {kwargs}.",
DeprecationWarning,
)
parent_run_ = get_current_run_tree()
_PROJECT_NAME.set(project_name)
parent_run = _get_parent_run({"parent": parent or kwargs.get("parent_run")})
if parent_run is not None:
_PARENT_RUN_TREE.set(parent_run)
tags = sorted(set(tags or []) | set(parent_run.tags or []))
metadata = {**parent_run.metadata, **(metadata or {})}
_TAGS.set(tags)
_METADATA.set(metadata)
_PARENT_RUN_TREE.set(parent_run)
try:
yield
finally:
Expand All @@ -85,6 +96,7 @@ def tracing_context(
_PARENT_RUN_TREE.set(parent_run_)


# Alias for backwards compatibility
get_run_tree_context = get_current_run_tree


Expand Down Expand Up @@ -143,7 +155,8 @@ class LangSmithExtra(TypedDict, total=False):

reference_example_id: Optional[ls_client.ID_TYPE]
run_extra: Optional[Dict]
run_tree: Optional[run_trees.RunTree]
parent: Optional[Union[run_trees.RunTree, str, Mapping]]
run_tree: Optional[run_trees.RunTree] # TODO: Deprecate
project_name: Optional[str]
metadata: Optional[Dict[str, Any]]
tags: Optional[List[str]]
Expand Down Expand Up @@ -212,6 +225,20 @@ def _collect_extra(extra_outer: dict, langsmith_extra: LangSmithExtra) -> dict:
return extra_inner


def _get_parent_run(langsmith_extra: LangSmithExtra) -> Optional[run_trees.RunTree]:
parent = langsmith_extra.get("parent")
if isinstance(parent, run_trees.RunTree):
return parent
if isinstance(parent, dict):
return run_trees.RunTree.from_headers(parent)
if isinstance(parent, str):
return run_trees.RunTree.from_dotted_order(parent)
run_tree = langsmith_extra.get("run_tree")
if run_tree:
return run_tree
return get_current_run_tree()


def _setup_run(
func: Callable,
container_input: _ContainerInput,
Expand All @@ -228,7 +255,7 @@ def _setup_run(
run_type = container_input.get("run_type") or "chain"
outer_project = _PROJECT_NAME.get()
langsmith_extra = langsmith_extra or LangSmithExtra()
parent_run_ = langsmith_extra.get("run_tree") or get_run_tree_context()
parent_run_ = _get_parent_run(langsmith_extra)
project_cv = _PROJECT_NAME.get()
selected_project = (
project_cv # From parent trace
Expand Down Expand Up @@ -577,7 +604,7 @@ async def async_wrapper(
**kwargs: Any,
) -> Any:
"""Async version of wrapper function."""
context_run = get_run_tree_context()
context_run = get_current_run_tree()
run_container = _setup_run(
func,
container_input=container_input,
Expand Down Expand Up @@ -611,7 +638,7 @@ async def async_wrapper(
async def async_generator_wrapper(
*args: Any, langsmith_extra: Optional[LangSmithExtra] = None, **kwargs: Any
) -> AsyncGenerator:
context_run = get_run_tree_context()
context_run = get_current_run_tree()
run_container = _setup_run(
func,
container_input=container_input,
Expand Down Expand Up @@ -683,7 +710,7 @@ def wrapper(
**kwargs: Any,
) -> Any:
"""Create a new run or create_child() if run is passed in kwargs."""
context_run = get_run_tree_context()
context_run = get_current_run_tree()
run_container = _setup_run(
func,
container_input=container_input,
Expand Down Expand Up @@ -717,7 +744,7 @@ def wrapper(
def generator_wrapper(
*args: Any, langsmith_extra: Optional[LangSmithExtra] = None, **kwargs: Any
) -> Any:
context_run = get_run_tree_context()
context_run = get_current_run_tree()
run_container = _setup_run(
func,
container_input=container_input,
Expand Down Expand Up @@ -808,7 +835,7 @@ def trace(
inputs: Optional[Dict] = None,
extra: Optional[Dict] = None,
project_name: Optional[str] = None,
run_tree: Optional[run_trees.RunTree] = None,
parent: Optional[Union[run_trees.RunTree, str, Mapping]] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Mapping[str, Any]] = None,
client: Optional[ls_client.Client] = None,
Expand All @@ -825,7 +852,9 @@ def trace(
outer_tags = _TAGS.get()
outer_metadata = _METADATA.get()
outer_project = _PROJECT_NAME.get() or utils.get_tracer_project()
parent_run_ = get_run_tree_context() if run_tree is None else run_tree
parent_run_ = _get_parent_run(
{"parent": parent, "run_tree": kwargs.get("run_tree")}
)

# Merge and set context variables
tags_ = sorted(set((tags or []) + (outer_tags or [])))
Expand Down
Loading

0 comments on commit 2066011

Please sign in to comment.