Skip to content

Commit

Permalink
Implement register hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Nov 15, 2023
1 parent b05991a commit 22f6fbe
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 12 deletions.
9 changes: 4 additions & 5 deletions libs/langchain/langchain/callbacks/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
)

from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.tracers import run_collector
from langchain.callbacks.tracers.wandb import WandbTracer
from langchain.schema.callbacks.manager import (
AsyncCallbackManager,
Expand All @@ -34,6 +33,7 @@
collect_runs,
env_var_is_set,
handle_event,
register_configure_hook,
trace_as_chain_group,
tracing_enabled,
tracing_v2_enabled,
Expand All @@ -48,10 +48,9 @@
"tracing_wandb_callback", default=None
)

run_collector_var: ContextVar[
Optional[run_collector.RunCollectorCallbackHandler]
] = ContextVar( # noqa: E501
"run_collector", default=None
register_configure_hook(openai_callback_var, True)
register_configure_hook(
wandb_tracing_callback_var, True, WandbTracer, "LANGCHAIN_WANDB_TRACING"
)


Expand Down
66 changes: 59 additions & 7 deletions libs/langchain/langchain/schema/callbacks/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -1897,6 +1898,43 @@ def _get_tracer_project() -> str:
)


_configure_hooks: List[
Tuple[
ContextVar[Optional[BaseCallbackHandler]],
bool,
Optional[Type[BaseCallbackHandler]],
Optional[str],
]
] = []

H = TypeVar("H", bound=BaseCallbackHandler, covariant=True)


def register_configure_hook(
context_var: ContextVar[Optional[Any]],
inheritable: bool,
handle_class: Optional[Type[BaseCallbackHandler]] = None,
env_var: Optional[str] = None,
) -> None:
if env_var is not None and handle_class is None:
raise ValueError(
"If env_var is set, handle_class must also be set to a non-None value."
)
_configure_hooks.append(
(
# the typings of ContextVar do not have the generic arg set as covariant
# so we have to cast it
cast(ContextVar[Optional[BaseCallbackHandler]], context_var),
inheritable,
handle_class,
env_var,
)
)


register_configure_hook(run_collector_var, False)


def _configure(
callback_manager_cls: Type[T],
inheritable_callbacks: Callbacks = None,
Expand Down Expand Up @@ -1972,7 +2010,6 @@ def _configure(
tracer_v2 = tracing_v2_callback_var.get()
tracing_v2_enabled_ = _tracing_v2_is_enabled()
tracer_project = _get_tracer_project()
run_collector_ = run_collector_var.get()
debug = _get_debug()
if verbose or debug or tracing_enabled_ or tracing_v2_enabled_:
if verbose and not any(
Expand Down Expand Up @@ -2012,12 +2049,27 @@ def _configure(
logger.warning(
"Unable to load requested LangChainTracer."
" To disable this warning,"
" unset the LANGCHAIN_TRACING_V2 environment variables.",
" unset the LANGCHAIN_TRACING_V2 environment variables.",
e,
)
if run_collector_ is not None and not any(
handler is run_collector_ # direct pointer comparison
for handler in callback_manager.handlers
):
callback_manager.add_handler(run_collector_, False)
for var, inheritable, handler_class, env_var in _configure_hooks:
create_one = (
env_var is not None
and env_var_is_set(env_var)
and handler_class is not None
)
if var.get() is not None or create_one:
var_handler = var.get() or cast(Type[BaseCallbackHandler], handler_class)()
if handler_class is None:
if not any(
handler is var_handler # direct pointer comparison
for handler in callback_manager.handlers
):
callback_manager.add_handler(var_handler, inheritable)
else:
if not any(
isinstance(handler, handler_class)
for handler in callback_manager.handlers
):
callback_manager.add_handler(var_handler, inheritable)
return callback_manager

0 comments on commit 22f6fbe

Please sign in to comment.