From 621fae6acf507a83401596239401563d7a163ff4 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 30 May 2024 02:44:26 -0700 Subject: [PATCH] [Python] 0.1.64 |Accept RunnableConfig|Customize OAI wrapper name|@traceable typing|Cache default RunTree Client MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Read parent run info from RunnableConfig if passed to function decorated with traceable - pass config into wrapped func only if signature declares it - modify signature of wrapper func to declare config as a kw arg, otherwise runnables don't pass it in - extract client and project_name from tracer, in addition to parent run info - update as_runnable to delegate run tree creation to new method ## Custom Run Name support in OpenAI Client ## Improve @traceable typing # Improves for python >= 3.10. Should in theory not change anything for users of 3.8 and 3.9 (at least, things pass in our 3.8 linting here... Apologies in advance to any mypy acolytes on 3.{8,9} who gain undesired linting errors after this change) Uses: 1. ParamSpec (available in 3.10 and onwards) 2. Protocols (already used) Even though python doesn't naturally support keyword-only concatenation, we can work around this with protocols and duck typing to communicate the actual returned type of "the same function signature + a keyword only langsmith_extra arg" Python's typing situation makes it hard to make everyone happy, but hopefully this strikes a better compromise than before (typing kwargs as Any in the wrapped function; too lenient) ## Cache default run tree client --------- Co-authored-by: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> --- python/langsmith/env/_git.py | 2 +- python/langsmith/env/_runtime_env.py | 18 ++- python/langsmith/evaluation/_arunner.py | 6 +- python/langsmith/evaluation/_runner.py | 6 +- python/langsmith/evaluation/evaluator.py | 16 +- python/langsmith/run_helpers.py | 148 ++++++++---------- python/langsmith/run_trees.py | 69 +++++++- python/langsmith/wrappers/_openai.py | 16 +- python/pyproject.toml | 2 +- .../integration_tests/wrappers/test_openai.py | 24 ++- python/tests/unit_tests/test_run_helpers.py | 81 +++++++++- 11 files changed, 277 insertions(+), 111 deletions(-) diff --git a/python/langsmith/env/_git.py b/python/langsmith/env/_git.py index 705f53c1e..ce598285f 100644 --- a/python/langsmith/env/_git.py +++ b/python/langsmith/env/_git.py @@ -47,7 +47,7 @@ def get_git_info(remote: str = "origin") -> GitInfo: dirty=None, tags=None, repo_name=None, - ) + ) return { "remote_url": exec_git(["remote", "get-url", remote]), diff --git a/python/langsmith/env/_runtime_env.py b/python/langsmith/env/_runtime_env.py index 5646263c2..5fe46a9c7 100644 --- a/python/langsmith/env/_runtime_env.py +++ b/python/langsmith/env/_runtime_env.py @@ -1,4 +1,5 @@ """Environment information.""" + import functools import logging import os @@ -77,6 +78,7 @@ def get_runtime_environment() -> dict: "py_implementation": platform.python_implementation(), "runtime_version": platform.python_version(), "langchain_version": get_langchain_environment(), + "langchain_core_version": get_langchain_core_version(), **shas, } @@ -91,6 +93,16 @@ def get_langchain_environment() -> Optional[str]: return None +@functools.lru_cache(maxsize=1) +def get_langchain_core_version() -> Optional[str]: + try: + import langchain_core # type: ignore + + return langchain_core.__version__ + except ImportError: + return None + + @functools.lru_cache(maxsize=1) def get_docker_version() -> Optional[str]: import subprocess @@ -138,9 +150,9 @@ def get_docker_environment() -> dict: compose_command = _get_compose_command() return { "docker_version": get_docker_version(), - "docker_compose_command": " ".join(compose_command) - if compose_command is not None - else None, + "docker_compose_command": ( + " ".join(compose_command) if compose_command is not None else None + ), "docker_compose_version": get_docker_compose_version(), } diff --git a/python/langsmith/evaluation/_arunner.py b/python/langsmith/evaluation/_arunner.py index ab1eaa3be..b5e1cc2ed 100644 --- a/python/langsmith/evaluation/_arunner.py +++ b/python/langsmith/evaluation/_arunner.py @@ -801,7 +801,7 @@ async def wait(self) -> None: async def _aforward( - fn: rh.SupportsLangsmithExtra[Awaitable], + fn: rh.SupportsLangsmithExtra[[dict], Awaitable], example: schemas.Example, experiment_name: str, metadata: dict, @@ -839,7 +839,9 @@ def _get_run(r: run_trees.RunTree) -> None: ) -def _ensure_async_traceable(target: ATARGET_T) -> rh.SupportsLangsmithExtra[Awaitable]: +def _ensure_async_traceable( + target: ATARGET_T, +) -> rh.SupportsLangsmithExtra[[dict], Awaitable]: if not asyncio.iscoroutinefunction(target): raise ValueError( "Target must be an async function. For sync functions, use evaluate." diff --git a/python/langsmith/evaluation/_runner.py b/python/langsmith/evaluation/_runner.py index 7854372fd..ae9d983c8 100644 --- a/python/langsmith/evaluation/_runner.py +++ b/python/langsmith/evaluation/_runner.py @@ -1466,12 +1466,14 @@ def _resolve_data( return data -def _ensure_traceable(target: TARGET_T) -> rh.SupportsLangsmithExtra: +def _ensure_traceable( + target: TARGET_T | rh.SupportsLangsmithExtra[[dict], dict], +) -> rh.SupportsLangsmithExtra[[dict], dict]: """Ensure the target function is traceable.""" if not callable(target): raise ValueError("Target must be a callable function.") if rh.is_traceable_function(target): - fn = cast(rh.SupportsLangsmithExtra, target) + fn = target else: fn = rh.traceable(name="Target")(target) return fn diff --git a/python/langsmith/evaluation/evaluator.py b/python/langsmith/evaluation/evaluator.py index 79c700feb..ee732a351 100644 --- a/python/langsmith/evaluation/evaluator.py +++ b/python/langsmith/evaluation/evaluator.py @@ -181,9 +181,8 @@ def __init__( self.afunc = run_helpers.ensure_traceable(func) self._name = getattr(func, "__name__", "DynamicRunEvaluator") else: - self.func = cast( - run_helpers.SupportsLangsmithExtra[_RUNNABLE_OUTPUT], - run_helpers.ensure_traceable(func), + self.func = run_helpers.ensure_traceable( + cast(Callable[[Run, Optional[Example]], _RUNNABLE_OUTPUT], func) ) self._name = getattr(func, "__name__", "DynamicRunEvaluator") @@ -383,9 +382,14 @@ def __init__( self.afunc = run_helpers.ensure_traceable(func) self._name = getattr(func, "__name__", "DynamicRunEvaluator") else: - self.func = cast( - run_helpers.SupportsLangsmithExtra[_COMPARISON_OUTPUT], - run_helpers.ensure_traceable(func), + self.func = run_helpers.ensure_traceable( + cast( + Callable[ + [Sequence[Run], Optional[Example]], + _COMPARISON_OUTPUT, + ], + func, + ) ) self._name = getattr(func, "__name__", "DynamicRunEvaluator") diff --git a/python/langsmith/run_helpers.py b/python/langsmith/run_helpers.py index ac5c5604c..7a0cc278a 100644 --- a/python/langsmith/run_helpers.py +++ b/python/langsmith/run_helpers.py @@ -35,12 +35,15 @@ runtime_checkable, ) +from typing_extensions import ParamSpec, TypeGuard + from langsmith import client as ls_client from langsmith import run_trees, utils from langsmith._internal import _aiter as aitertools +from langsmith.env import _runtime_env if TYPE_CHECKING: - from langchain.schema.runnable import Runnable + from langchain_core.runnables import Runnable LOGGER = logging.getLogger(__name__) _PARENT_RUN_TREE = contextvars.ContextVar[Optional[run_trees.RunTree]]( @@ -143,7 +146,9 @@ def tracing_context( get_run_tree_context = get_current_run_tree -def is_traceable_function(func: Callable) -> bool: +def is_traceable_function( + func: Callable[P, R], +) -> TypeGuard[SupportsLangsmithExtra[P, R]]: """Check if a function is @traceable decorated.""" return ( _is_traceable_function(func) @@ -152,12 +157,11 @@ def is_traceable_function(func: Callable) -> bool: ) -def ensure_traceable(func: Callable[..., R]) -> Callable[..., R]: +def ensure_traceable(func: Callable[P, R]) -> SupportsLangsmithExtra[P, R]: """Ensure that a function is traceable.""" - return cast( - SupportsLangsmithExtra, - (func if is_traceable_function(func) else traceable()(func)), - ) + if is_traceable_function(func): + return func + return traceable()(func) def is_async(func: Callable) -> bool: @@ -183,10 +187,11 @@ class LangSmithExtra(TypedDict, total=False): R = TypeVar("R", covariant=True) +P = ParamSpec("P") @runtime_checkable -class SupportsLangsmithExtra(Protocol, Generic[R]): +class SupportsLangsmithExtra(Protocol, Generic[P, R]): """Implementations of this Protoc accept an optional langsmith_extra parameter. Args: @@ -201,9 +206,9 @@ class SupportsLangsmithExtra(Protocol, Generic[R]): def __call__( self, - *args: Any, + *args: P.args, langsmith_extra: Optional[LangSmithExtra] = None, - **kwargs: Any, + **kwargs: P.kwargs, ) -> R: """Call the instance when it is called as a function. @@ -222,8 +227,8 @@ def __call__( @overload def traceable( - func: Callable[..., R], -) -> Callable[..., R]: ... + func: Callable[P, R], +) -> SupportsLangsmithExtra[P, R]: ... @overload @@ -238,7 +243,7 @@ def traceable( project_name: Optional[str] = None, process_inputs: Optional[Callable[[dict], dict]] = None, _invocation_params_fn: Optional[Callable[[dict], dict]] = None, -) -> Callable[[Callable[..., R]], SupportsLangsmithExtra[R]]: ... +) -> Callable[[Callable[P, R]], SupportsLangsmithExtra[P, R]]: ... def traceable( @@ -415,6 +420,10 @@ def manual_extra_function(x): ) def decorator(func: Callable): + func_sig = inspect.signature(func) + func_accepts_parent_run = func_sig.parameters.get("run_tree", None) is not None + func_accepts_config = func_sig.parameters.get("config", None) is not None + @functools.wraps(func) async def async_wrapper( *args: Any, @@ -429,15 +438,14 @@ async def async_wrapper( args=args, kwargs=kwargs, ) - func_accepts_parent_run = ( - inspect.signature(func).parameters.get("run_tree", None) is not None - ) + try: accepts_context = aitertools.accepts_context(asyncio.create_task) if func_accepts_parent_run: - fr_coro = func(*args, run_tree=run_container["new_run"], **kwargs) - else: - fr_coro = func(*args, **kwargs) + kwargs["run_tree"] = run_container["new_run"] + if not func_accepts_config: + kwargs.pop("config", None) + fr_coro = func(*args, **kwargs) if accepts_context: function_result = await asyncio.create_task( # type: ignore[call-arg] fr_coro, context=run_container["context"] @@ -465,20 +473,16 @@ async def async_generator_wrapper( args=args, kwargs=kwargs, ) - func_accepts_parent_run = ( - inspect.signature(func).parameters.get("run_tree", None) is not None - ) results: List[Any] = [] try: if func_accepts_parent_run: - async_gen_result = func( - *args, run_tree=run_container["new_run"], **kwargs - ) - else: + kwargs["run_tree"] = run_container["new_run"] # TODO: Nesting is ambiguous if a nested traceable function is only # called mid-generation. Need to explicitly accept run_tree to get # around this. - async_gen_result = func(*args, **kwargs) + if not func_accepts_config: + kwargs.pop("config", None) + async_gen_result = func(*args, **kwargs) # Can't iterate through if it's a coroutine accepts_context = aitertools.accepts_context(asyncio.create_task) if inspect.iscoroutine(async_gen_result): @@ -555,13 +559,10 @@ def wrapper( ) try: if func_accepts_parent_run: - function_result = run_container["context"].run( - func, *args, run_tree=run_container["new_run"], **kwargs - ) - else: - function_result = run_container["context"].run( - func, *args, **kwargs - ) + kwargs["run_tree"] = run_container["new_run"] + if not func_accepts_config: + kwargs.pop("config", None) + function_result = run_container["context"].run(func, *args, **kwargs) except BaseException as e: _container_end(run_container, error=e) raise e @@ -585,16 +586,13 @@ def generator_wrapper( results: List[Any] = [] try: if func_accepts_parent_run: - generator_result = run_container["context"].run( - func, *args, run_tree=run_container["new_run"], **kwargs - ) - else: + kwargs["run_tree"] = run_container["new_run"] # TODO: Nesting is ambiguous if a nested traceable function is only # called mid-generation. Need to explicitly accept run_tree to get # around this. - generator_result = run_container["context"].run( - func, *args, **kwargs - ) + if not func_accepts_config: + kwargs.pop("config", None) + generator_result = run_container["context"].run(func, *args, **kwargs) try: while True: item = run_container["context"].run(next, generator_result) @@ -645,6 +643,17 @@ def generator_wrapper( else: selected_wrapper = wrapper setattr(selected_wrapper, "__langsmith_traceable__", True) + sig = inspect.signature(selected_wrapper) + if not sig.parameters.get("config"): + sig = sig.replace( + parameters=[ + *sig.parameters.values(), + inspect.Parameter( + "config", inspect.Parameter.KEYWORD_ONLY, default=None + ), + ] + ) + selected_wrapper.__signature__ = sig # type: ignore[attr-defined] return selected_wrapper # If the decorator is called with no arguments, then it's being used as a @@ -762,17 +771,12 @@ def as_runnable(traceable_fn: Callable) -> Runnable: >>> runnable = as_runnable(my_function) """ try: - from langchain.callbacks.manager import ( - AsyncCallbackManager, - CallbackManager, - ) - from langchain.callbacks.tracers.langchain import LangChainTracer - from langchain.schema.runnable import RunnableConfig, RunnableLambda - from langchain.schema.runnable.utils import Input, Output + from langchain_core.runnables import RunnableConfig, RunnableLambda + from langchain_core.runnables.utils import Input, Output except ImportError as e: raise ImportError( - "as_runnable requires langchain to be installed. " - "You can install it with `pip install langchain`." + "as_runnable requires langchain-core to be installed. " + "You can install it with `pip install langchain-core`." ) from e if not is_traceable_function(traceable_fn): try: @@ -822,33 +826,6 @@ def __init__( ), ) - @staticmethod - def _configure_run_tree(callback_manager: Any) -> Optional[run_trees.RunTree]: - run_tree: Optional[run_trees.RunTree] = None - if isinstance(callback_manager, (CallbackManager, AsyncCallbackManager)): - lc_tracers = [ - handler - for handler in callback_manager.handlers - if isinstance(handler, LangChainTracer) - ] - if lc_tracers and callback_manager.parent_run_id: - lc_tracer = lc_tracers[0] - trace_id, dotted_order = lc_tracer.order_map[ - callback_manager.parent_run_id - ] - run_tree = run_trees.RunTree( - id=callback_manager.parent_run_id, - dotted_order=dotted_order, - trace_id=trace_id, - session_name=lc_tracer.project_name, - name="Wrapping", - run_type="chain", - inputs={}, - tags=callback_manager.tags, - extra={"metadata": callback_manager.metadata}, - ) - return run_tree - @staticmethod def _wrap_sync( func: Callable[..., Output], @@ -856,9 +833,7 @@ def _wrap_sync( """Wrap a synchronous function to make it asynchronous.""" def wrap_traceable(inputs: dict, config: RunnableConfig) -> Any: - run_tree = RunnableTraceable._configure_run_tree( - config.get("callbacks") - ) + run_tree = run_trees.RunTree.from_runnable_config(cast(dict, config)) return func(**inputs, langsmith_extra={"run_tree": run_tree}) return cast(Callable[[Input, RunnableConfig], Output], wrap_traceable) @@ -879,9 +854,7 @@ def _wrap_async( afunc_ = cast(Callable[..., Awaitable[Output]], afunc) async def awrap_traceable(inputs: dict, config: RunnableConfig) -> Any: - run_tree = RunnableTraceable._configure_run_tree( - config.get("callbacks") - ) + run_tree = run_trees.RunTree.from_runnable_config(cast(dict, config)) return await afunc_(**inputs, langsmith_extra={"run_tree": run_tree}) return cast( @@ -970,7 +943,9 @@ def _collect_extra(extra_outer: dict, langsmith_extra: LangSmithExtra) -> dict: return extra_inner -def _get_parent_run(langsmith_extra: LangSmithExtra) -> Optional[run_trees.RunTree]: +def _get_parent_run( + langsmith_extra: LangSmithExtra, config: Optional[dict] = None +) -> Optional[run_trees.RunTree]: parent = langsmith_extra.get("parent") if isinstance(parent, run_trees.RunTree): return parent @@ -981,6 +956,9 @@ def _get_parent_run(langsmith_extra: LangSmithExtra) -> Optional[run_trees.RunTr run_tree = langsmith_extra.get("run_tree") if run_tree: return run_tree + if _runtime_env.get_langchain_core_version() is not None: + if rt := run_trees.RunTree.from_runnable_config(config): + return rt return get_current_run_tree() @@ -1000,7 +978,7 @@ def _setup_run( run_type = container_input.get("run_type") or "chain" outer_project = _PROJECT_NAME.get() langsmith_extra = langsmith_extra or LangSmithExtra() - parent_run_ = _get_parent_run(langsmith_extra) + parent_run_ = _get_parent_run(langsmith_extra, kwargs.get("config")) project_cv = _PROJECT_NAME.get() selected_project = ( project_cv # From parent trace diff --git a/python/langsmith/run_trees.py b/python/langsmith/run_trees.py index 03b59ba7c..675aee4ec 100644 --- a/python/langsmith/run_trees.py +++ b/python/langsmith/run_trees.py @@ -13,6 +13,7 @@ except ImportError: from pydantic import Field, root_validator, validator +import threading import urllib.parse from langsmith import schemas as ls_schemas @@ -23,6 +24,17 @@ LANGSMITH_PREFIX = "langsmith-" LANGSMITH_DOTTED_ORDER = f"{LANGSMITH_PREFIX}trace" +_CLIENT: Optional[Client] = None +_LOCK = threading.Lock() + + +def _get_client() -> Client: + global _CLIENT + if _CLIENT is None: + with _LOCK: + if _CLIENT is None: + _CLIENT = Client() + return _CLIENT class RunTree(ls_schemas.RunBase): @@ -43,7 +55,7 @@ class RunTree(ls_schemas.RunBase): ) session_id: Optional[UUID] = Field(default=None, alias="project_id") extra: Dict = Field(default_factory=dict) - client: Client = Field(default_factory=Client, exclude=True) + client: Client = Field(default_factory=_get_client, exclude=True) dotted_order: str = Field( default="", description="The order of the run in the tree." ) @@ -60,7 +72,7 @@ class Config: def validate_client(cls, v: Optional[Client]) -> Client: """Ensure the client is specified.""" if v is None: - return Client() + return _get_client() return v @root_validator(pre=True) @@ -286,6 +298,59 @@ def from_dotted_order( } return cast(RunTree, cls.from_headers(headers, **kwargs)) + @classmethod + def from_runnable_config( + cls, + config: Optional[dict], + **kwargs: Any, + ) -> Optional[RunTree]: + """Create a new 'child' span from the provided runnable config. + + Requires langchain to be installed. + + Returns: + Optional[RunTree]: The new span or None if + no parent span information is found. + """ + try: + from langchain_core.callbacks.manager import ( + AsyncCallbackManager, + CallbackManager, + ) + from langchain_core.runnables import RunnableConfig, ensure_config + from langchain_core.tracers.langchain import LangChainTracer + except ImportError as e: + raise ImportError( + "RunTree.from_runnable_config requires langchain-core to be installed. " + "You can install it with `pip install langchain-core`." + ) from e + config_ = ensure_config( + cast(RunnableConfig, config) if isinstance(config, dict) else None + ) + if ( + (cb := config_.get("callbacks")) + and isinstance(cb, (CallbackManager, AsyncCallbackManager)) + and cb.parent_run_id + and ( + tracer := next( + (t for t in cb.handlers if isinstance(t, LangChainTracer)), + None, + ) + ) + ): + if hasattr(tracer, "order_map"): + dotted_order = tracer.order_map[cb.parent_run_id][1] + elif ( + run := tracer.run_map.get(str(cb.parent_run_id)) + ) and run.dotted_order: + dotted_order = run.dotted_order + else: + return None + kwargs["client"] = tracer.client + kwargs["project_name"] = tracer.project_name + return RunTree.from_dotted_order(dotted_order, **kwargs) + return None + @classmethod def from_headers(cls, headers: Dict[str, str], **kwargs: Any) -> Optional[RunTree]: """Create a new 'parent' span from the provided headers. diff --git a/python/langsmith/wrappers/_openai.py b/python/langsmith/wrappers/_openai.py index de6c77ded..5b6798e8d 100644 --- a/python/langsmith/wrappers/_openai.py +++ b/python/langsmith/wrappers/_openai.py @@ -211,13 +211,23 @@ class TracingExtra(TypedDict, total=False): client: Optional[ls_client.Client] -def wrap_openai(client: C, *, tracing_extra: Optional[TracingExtra] = None) -> C: +def wrap_openai( + client: C, + *, + tracing_extra: Optional[TracingExtra] = None, + chat_name: str = "ChatOpenAI", + completions_name: str = "OpenAI", +) -> C: """Patch the OpenAI client to make it traceable. Args: client (Union[OpenAI, AsyncOpenAI]): The client to patch. tracing_extra (Optional[TracingExtra], optional): Extra tracing information. Defaults to None. + chat_name (str, optional): The run name for the chat completions endpoint. + Defaults to "ChatOpenAI". + completions_name (str, optional): The run name for the completions endpoint. + Defaults to "OpenAI". Returns: Union[OpenAI, AsyncOpenAI]: The patched client. @@ -225,14 +235,14 @@ def wrap_openai(client: C, *, tracing_extra: Optional[TracingExtra] = None) -> C """ client.chat.completions.create = _get_wrapper( # type: ignore[method-assign] client.chat.completions.create, - "ChatOpenAI", + chat_name, _reduce_chat, tracing_extra=tracing_extra, invocation_params_fn=functools.partial(_infer_invocation_params, "chat"), ) client.completions.create = _get_wrapper( # type: ignore[method-assign] client.completions.create, - "OpenAI", + completions_name, _reduce_completions, tracing_extra=tracing_extra, invocation_params_fn=functools.partial(_infer_invocation_params, "text"), diff --git a/python/pyproject.toml b/python/pyproject.toml index 7f87cd28a..5fab25448 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langsmith" -version = "0.1.63" +version = "0.1.64" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." authors = ["LangChain "] license = "MIT" diff --git a/python/tests/integration_tests/wrappers/test_openai.py b/python/tests/integration_tests/wrappers/test_openai.py index 40804b96d..11bc7bf3f 100644 --- a/python/tests/integration_tests/wrappers/test_openai.py +++ b/python/tests/integration_tests/wrappers/test_openai.py @@ -4,6 +4,7 @@ import pytest +import langsmith from langsmith.wrappers import wrap_openai @@ -12,8 +13,9 @@ def test_chat_sync_api(mock_session: mock.MagicMock, stream: bool): import openai # noqa + client = langsmith.Client(session=mock_session()) original_client = openai.Client() - patched_client = wrap_openai(openai.Client()) + patched_client = wrap_openai(openai.Client(), tracing_extra={"client": client}) messages = [{"role": "user", "content": "Say 'foo'"}] original = original_client.chat.completions.create( messages=messages, # noqa: [arg-type] @@ -50,8 +52,9 @@ def test_chat_sync_api(mock_session: mock.MagicMock, stream: bool): async def test_chat_async_api(mock_session: mock.MagicMock, stream: bool): import openai # noqa + client = langsmith.Client(session=mock_session()) original_client = openai.AsyncClient() - patched_client = wrap_openai(openai.AsyncClient()) + patched_client = wrap_openai(openai.AsyncClient(), tracing_extra={"client": client}) messages = [{"role": "user", "content": "Say 'foo'"}] original = await original_client.chat.completions.create( messages=messages, stream=stream, temperature=0, seed=42, model="gpt-3.5-turbo" @@ -84,8 +87,9 @@ async def test_chat_async_api(mock_session: mock.MagicMock, stream: bool): def test_completions_sync_api(mock_session: mock.MagicMock, stream: bool): import openai + client = langsmith.Client(session=mock_session()) original_client = openai.Client() - patched_client = wrap_openai(openai.Client()) + patched_client = wrap_openai(openai.Client(), tracing_extra={"client": client}) prompt = ("Say 'Foo' then stop.",) original = original_client.completions.create( model="gpt-3.5-turbo-instruct", @@ -124,8 +128,15 @@ def test_completions_sync_api(mock_session: mock.MagicMock, stream: bool): async def test_completions_async_api(mock_session: mock.MagicMock, stream: bool): import openai + client = langsmith.Client(session=mock_session()) + original_client = openai.AsyncClient() - patched_client = wrap_openai(openai.AsyncClient()) + patched_client = wrap_openai( + openai.AsyncClient(), + tracing_extra={"client": client}, + chat_name="chattychat", + completions_name="incompletions", + ) prompt = ("Say 'Hi i'm ChatGPT' then stop.",) original = await original_client.completions.create( model="gpt-3.5-turbo-instruct", @@ -158,7 +169,10 @@ async def test_completions_async_api(mock_session: mock.MagicMock, stream: bool) assert type(original) == type(patched) assert original.choices == patched.choices # Give the thread a chance. - time.sleep(0.1) + for _ in range(10): + time.sleep(0.1) + if mock_session.return_value.request.call_count >= 1: + break assert mock_session.return_value.request.call_count >= 1 for call in mock_session.return_value.request.call_args_list: assert call[0][0].upper() == "POST" diff --git a/python/tests/unit_tests/test_run_helpers.py b/python/tests/unit_tests/test_run_helpers.py index 61a1d2004..4170b4f2e 100644 --- a/python/tests/unit_tests/test_run_helpers.py +++ b/python/tests/unit_tests/test_run_helpers.py @@ -322,6 +322,85 @@ async def my_function(a, b, d): assert result == [6, 7] +def test_traceable_parent_from_runnable_config() -> None: + try: + from langchain.callbacks.tracers import LangChainTracer + from langchain.schema.runnable import RunnableLambda + except ImportError: + pytest.skip("Skipping test that requires langchain") + with tracing_context(enabled=True): + mock_client_ = _get_mock_client() + + @traceable() + def my_function(a: int) -> int: + return a * 2 + + my_function_runnable = RunnableLambda(my_function) + + assert ( + my_function_runnable.invoke( + 1, {"callbacks": [LangChainTracer(client=mock_client_)]} + ) + == 2 + ) + time.sleep(1) + # Inspect the mock_calls and assert that 2 runs were created, + # one for the parent and one for the child + mock_calls = mock_client_.session.request.mock_calls # type: ignore + posts = [] + for call in mock_calls: + if call.args: + assert call.args[0] == "POST" + assert call.args[1].startswith("https://api.smith.langchain.com") + body = json.loads(call.kwargs["data"]) + assert body["post"] + posts.extend(body["post"]) + assert len(posts) == 2 + parent = next(p for p in posts if p["parent_run_id"] is None) + child = next(p for p in posts if p["parent_run_id"] is not None) + assert child["parent_run_id"] == parent["id"] + + +def test_traceable_parent_from_runnable_config_accepts_config() -> None: + try: + from langchain.callbacks.tracers import LangChainTracer + from langchain.schema.runnable import RunnableLambda + except ImportError: + pytest.skip("Skipping test that requires langchain") + with tracing_context(enabled=True): + mock_client_ = _get_mock_client() + + @traceable() + def my_function(a: int, config: dict) -> int: + assert isinstance(config, dict) + return a * 2 + + my_function_runnable = RunnableLambda(my_function) + + assert ( + my_function_runnable.invoke( + 1, {"callbacks": [LangChainTracer(client=mock_client_)]} + ) + == 2 + ) + time.sleep(1) + # Inspect the mock_calls and assert that 2 runs were created, + # one for the parent and one for the child + mock_calls = mock_client_.session.request.mock_calls # type: ignore + posts = [] + for call in mock_calls: + if call.args: + assert call.args[0] == "POST" + assert call.args[1].startswith("https://api.smith.langchain.com") + body = json.loads(call.kwargs["data"]) + assert body["post"] + posts.extend(body["post"]) + assert len(posts) == 2 + parent = next(p for p in posts if p["parent_run_id"] is None) + child = next(p for p in posts if p["parent_run_id"] is not None) + assert child["parent_run_id"] == parent["id"] + + def test_traceable_project_name() -> None: with tracing_context(enabled=True): mock_client_ = _get_mock_client() @@ -350,7 +429,7 @@ def my_function(a: int, b: int, d: int) -> int: def my_other_function(run_tree) -> int: return my_function(1, 2, 3) - my_other_function() + my_other_function() # type: ignore time.sleep(0.25) # Inspect the mock_calls and assert that "my bar project" is in # both all POST runs in the single request. We want to ensure