Skip to content

Commit

Permalink
Add tracing propagation test
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw committed Sep 11, 2024
1 parent 620a82b commit 709431b
Showing 1 changed file with 110 additions and 0 deletions.
110 changes: 110 additions & 0 deletions libs/langgraph/tests/test_tracing_interops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import json
import sys
import time
from typing import Any, Callable, Tuple, TypedDict, TypeVar
from unittest.mock import MagicMock

import langsmith as ls
import pytest
from langchain_core.runnables import RunnableConfig
from langchain_core.tracers import LangChainTracer

from langgraph.graph import StateGraph


def _get_mock_client(**kwargs: Any) -> ls.Client:
mock_session = MagicMock()
return ls.Client(session=mock_session, api_key="test", **kwargs)


def _get_calls(
mock_client: Any,
verbs: set[str] = {"POST"},
) -> list:
return [
c
for c in mock_client.session.request.mock_calls
if c.args and c.args[0] in verbs
]


T = TypeVar("T")


def wait_for(
condition: Callable[[], Tuple[T, bool]],
max_sleep_time: int = 10,
sleep_time: int = 3,
) -> T:
"""Wait for a condition to be true."""
start_time = time.time()
last_e = None
while time.time() - start_time < max_sleep_time:
try:
res, cond = condition()
if cond:
return res
except Exception as e:
last_e = e
time.sleep(sleep_time)
total_time = time.time() - start_time
if last_e is not None:
raise last_e
raise ValueError(f"Callable did not return within {total_time}")


@pytest.mark.asyncio
async def test_nested_tracing():
lt_py_311 = sys.version_info < (3, 11)
mock_client = _get_mock_client()

class State(TypedDict):
value: str

@ls.traceable
async def some_traceable(content: State):
return await child_graph.ainvoke(content)

async def parent_node(state: State, config: RunnableConfig) -> State:
if lt_py_311:
result = await some_traceable(state, langsmith_extra={"config": config})
else:
result = await some_traceable(state)
return {"value": f"parent_{result['value']}"}

async def child_node(state: State) -> State:
return {"value": f"child_{state['value']}"}

child_builder = StateGraph(State)
child_builder.add_node(child_node)
child_builder.add_edge("__start__", "child_node")
child_graph = child_builder.compile()

parent_builder = StateGraph(State)
parent_builder.add_node(parent_node)
parent_builder.add_edge("__start__", "parent_node")
parent_graph = parent_builder.compile()

tracer = LangChainTracer(client=mock_client)
result = await parent_graph.ainvoke({"value": "input"}, {"callbacks": [tracer]})

assert result == {"value": "parent_child_input"}

# Check POST calls
# Wait until we have a call with the name of "child_node"
def get_posts():
post_calls = _get_calls(mock_client, verbs={"POST"})

posts = [p for c in post_calls for p in json.loads(c.kwargs["data"])["post"]]
names = [p.get("name") for p in posts]
if "child_node" in names:
return posts, True
return None, False

posts = wait_for(get_posts)
parent_run = next(data for data in posts if data["name"] == "parent_node")
child_run = next(data for data in posts if data["name"] == "child_node")
traceable_run = next(data for data in posts if data["name"] == "some_traceable")

assert child_run["dotted_order"].startswith(traceable_run["dotted_order"])
assert traceable_run["dotted_order"].startswith(parent_run["dotted_order"])

0 comments on commit 709431b

Please sign in to comment.