Skip to content

Commit

Permalink
Add trace_as_chain_group metadata (#17187)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Feb 7, 2024
1 parent 5ceaf78 commit 9fa0707
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions libs/core/langchain_core/callbacks/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def trace_as_chain_group(
example_id: Optional[Union[str, UUID]] = None,
run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Generator[CallbackManagerForChainGroup, None, None]:
"""Get a callback manager for a chain group in a context manager.
Useful for grouping different calls together as a single run even if
Expand All @@ -83,6 +84,8 @@ def trace_as_chain_group(
run_id (UUID, optional): The ID of the run.
tags (List[str], optional): The inheritable tags to apply to all runs.
Defaults to None.
metadata (Dict[str, Any], optional): The metadata to apply to all runs.
Defaults to None.
Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith.
Expand All @@ -95,7 +98,7 @@ def trace_as_chain_group(
llm_input = "Foo"
with trace_as_chain_group("group_name", inputs={"input": llm_input}) as manager:
# Use the callback manager for the chain group
res = llm.predict(llm_input, callbacks=manager)
res = llm.invoke(llm_input, {"callbacks": manager})
manager.on_chain_end({"output": res})
""" # noqa: E501
from langchain_core.tracers.context import _get_trace_callbacks
Expand All @@ -106,6 +109,7 @@ def trace_as_chain_group(
cm = CallbackManager.configure(
inheritable_callbacks=cb,
inheritable_tags=tags,
inheritable_metadata=metadata,
)

run_manager = cm.on_chain_start({"name": group_name}, inputs or {}, run_id=run_id)
Expand Down Expand Up @@ -141,6 +145,7 @@ async def atrace_as_chain_group(
example_id: Optional[Union[str, UUID]] = None,
run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> AsyncGenerator[AsyncCallbackManagerForChainGroup, None]:
"""Get an async callback manager for a chain group in a context manager.
Useful for grouping different async calls together as a single run even if
Expand All @@ -157,6 +162,8 @@ async def atrace_as_chain_group(
run_id (UUID, optional): The ID of the run.
tags (List[str], optional): The inheritable tags to apply to all runs.
Defaults to None.
metadata (Dict[str, Any], optional): The metadata to apply to all runs.
Defaults to None.
Returns:
AsyncCallbackManager: The async callback manager for the chain group.
Expand All @@ -168,15 +175,17 @@ async def atrace_as_chain_group(
llm_input = "Foo"
async with atrace_as_chain_group("group_name", inputs={"input": llm_input}) as manager:
# Use the async callback manager for the chain group
res = await llm.apredict(llm_input, callbacks=manager)
res = await llm.ainvoke(llm_input, {"callbacks": manager})
await manager.on_chain_end({"output": res})
""" # noqa: E501
from langchain_core.tracers.context import _get_trace_callbacks

cb = _get_trace_callbacks(
project_name, example_id, callback_manager=callback_manager
)
cm = AsyncCallbackManager.configure(inheritable_callbacks=cb, inheritable_tags=tags)
cm = AsyncCallbackManager.configure(
inheritable_callbacks=cb, inheritable_tags=tags, inheritable_metadata=metadata
)

run_manager = await cm.on_chain_start(
{"name": group_name}, inputs or {}, run_id=run_id
Expand Down

0 comments on commit 9fa0707

Please sign in to comment.