From 70c6dc9c08f558b39b6afaf692a5eee68a6e43d0 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Mon, 11 Nov 2024 15:04:01 -0800 Subject: [PATCH] python[patch]: pass Runnable to evaluate --- python/langsmith/evaluation/_arunner.py | 8 +++- python/langsmith/evaluation/_runner.py | 39 ++++++++++++++----- python/langsmith/run_helpers.py | 6 +-- .../unit_tests/evaluation/test_runner.py | 39 +++++++++++++------ 4 files changed, 66 insertions(+), 26 deletions(-) diff --git a/python/langsmith/evaluation/_arunner.py b/python/langsmith/evaluation/_arunner.py index 7c791ac09..ceb1e4443 100644 --- a/python/langsmith/evaluation/_arunner.py +++ b/python/langsmith/evaluation/_arunner.py @@ -40,6 +40,7 @@ _ExperimentManagerMixin, _extract_feedback_keys, _ForwardResults, + _is_langchain_runnable, _load_examples_map, _load_experiment, _load_tqdm, @@ -940,7 +941,7 @@ def _get_run(r: run_trees.RunTree) -> None: def _ensure_async_traceable( target: ATARGET_T, ) -> rh.SupportsLangsmithExtra[[dict], Awaitable]: - if not asyncio.iscoroutinefunction(target): + if not asyncio.iscoroutinefunction(target) and not _is_langchain_runnable(target): if callable(target): raise ValueError( "Target must be an async function. For sync functions, use evaluate." @@ -961,7 +962,10 @@ def _ensure_async_traceable( ) if rh.is_traceable_function(target): return target # type: ignore - return rh.traceable(name="AsyncTarget")(target) + else: + if _is_langchain_runnable(target): + target = target.ainvoke # type: ignore[attr-defined] + return rh.traceable(name="AsyncTarget")(target) def _aresolve_data( diff --git a/python/langsmith/evaluation/_runner.py b/python/langsmith/evaluation/_runner.py index 111986b76..d0d823bbc 100644 --- a/python/langsmith/evaluation/_runner.py +++ b/python/langsmith/evaluation/_runner.py @@ -58,6 +58,7 @@ if TYPE_CHECKING: import pandas as pd + from langchain_core.runnables import Runnable DataFrame = pd.DataFrame else: @@ -96,7 +97,7 @@ def evaluate( - target: TARGET_T, + target: Union[TARGET_T, Runnable], /, data: DATA_T, evaluators: Optional[Sequence[EVALUATOR_T]] = None, @@ -878,12 +879,12 @@ def _print_comparative_experiment_start( ) -def _is_callable(target: Union[TARGET_T, Iterable[schemas.Run]]) -> bool: - return callable(target) or (hasattr(target, "invoke") and callable(target.invoke)) +def _is_callable(target: Union[TARGET_T, Iterable[schemas.Run], Runnable]) -> bool: + return callable(target) or _is_langchain_runnable(target) def _evaluate( - target: Union[TARGET_T, Iterable[schemas.Run]], + target: Union[TARGET_T, Iterable[schemas.Run], Runnable], /, data: DATA_T, evaluators: Optional[Sequence[EVALUATOR_T]] = None, @@ -1664,12 +1665,13 @@ def _resolve_data( def _ensure_traceable( - target: TARGET_T | rh.SupportsLangsmithExtra[[dict], dict], + target: TARGET_T | rh.SupportsLangsmithExtra[[dict], dict] | Runnable, ) -> rh.SupportsLangsmithExtra[[dict], dict]: """Ensure the target function is traceable.""" - if not callable(target): + if not _is_callable(target): raise ValueError( - "Target must be a callable function. For example:\n\n" + "Target must be a callable function or a langchain/langgraph object. For " + "example:\n\n" "def predict(inputs: dict) -> dict:\n" " # do work, like chain.invoke(inputs)\n" " return {...}\n\n" @@ -1679,9 +1681,11 @@ def _ensure_traceable( ")" ) if rh.is_traceable_function(target): - fn = target + fn: rh.SupportsLangsmithExtra[[dict], dict] = target else: - fn = rh.traceable(name="Target")(target) + if _is_langchain_runnable(target): + target = target.invoke # type: ignore[union-attr] + fn = rh.traceable(name="Target")(cast(Callable, target)) return fn @@ -1923,3 +1927,20 @@ def _flatten_experiment_results( } for x in results[start:end] ] + + +@functools.lru_cache(maxsize=1) +def _import_langchain_runnable() -> Optional[type]: + try: + from langchain_core.runnables import Runnable + + return Runnable + except ImportError: + return None + + +def _is_langchain_runnable(o: Any) -> bool: + if (Runnable := _import_langchain_runnable()) and isinstance(o, Runnable): + return True + else: + return False diff --git a/python/langsmith/run_helpers.py b/python/langsmith/run_helpers.py index eaa838192..7510b75ee 100644 --- a/python/langsmith/run_helpers.py +++ b/python/langsmith/run_helpers.py @@ -151,9 +151,7 @@ def tracing_context( get_run_tree_context = get_current_run_tree -def is_traceable_function( - func: Callable[P, R], -) -> TypeGuard[SupportsLangsmithExtra[P, R]]: +def is_traceable_function(func: Any) -> TypeGuard[SupportsLangsmithExtra[P, R]]: """Check if a function is @traceable decorated.""" return ( _is_traceable_function(func) @@ -1445,7 +1443,7 @@ def _handle_container_end( LOGGER.warning(f"Unable to process trace outputs: {repr(e)}") -def _is_traceable_function(func: Callable) -> bool: +def _is_traceable_function(func: Any) -> bool: return getattr(func, "__langsmith_traceable__", False) diff --git a/python/tests/unit_tests/evaluation/test_runner.py b/python/tests/unit_tests/evaluation/test_runner.py index d20960d3e..2c787f437 100644 --- a/python/tests/unit_tests/evaluation/test_runner.py +++ b/python/tests/unit_tests/evaluation/test_runner.py @@ -120,7 +120,8 @@ def _wait_until(condition: Callable, timeout: int = 8): @pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher") @pytest.mark.parametrize("blocking", [False, True]) -def test_evaluate_results(blocking: bool) -> None: +@pytest.mark.parametrize("as_runnable", [False, True]) +def test_evaluate_results(blocking: bool, as_runnable: bool) -> None: session = mock.Mock() ds_name = "my-dataset" ds_id = "00886375-eb2a-4038-9032-efff60309896" @@ -180,6 +181,15 @@ def predict(inputs: dict) -> dict: ordering_of_stuff.append("predict") return {"output": inputs["in"] + 1} + if as_runnable: + try: + from langchain_core.runnables import RunnableLambda + except ImportError: + pytest.skip("langchain-core not installed.") + return + else: + predict = RunnableLambda(predict) + def score_value_first(run, example): ordering_of_stuff.append("evaluate") return {"score": 0.3} @@ -263,26 +273,24 @@ async def my_other_func(inputs: dict, other_val: int): with pytest.raises(ValueError, match=match): evaluate(functools.partial(my_other_func, other_val=3), data="foo") + if sys.version_info < (3, 10): + return try: from langchain_core.runnables import RunnableLambda except ImportError: pytest.skip("langchain-core not installed.") - - @RunnableLambda - def foo(inputs: dict): - return "bar" - - with pytest.raises(ValueError, match=match): - evaluate(foo.ainvoke, data="foo") - if sys.version_info < (3, 10): return with pytest.raises(ValueError, match=match): - evaluate(functools.partial(foo.ainvoke, inputs={"foo": "bar"}), data="foo") + evaluate( + functools.partial(RunnableLambda(my_func).ainvoke, inputs={"foo": "bar"}), + data="foo", + ) @pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher") @pytest.mark.parametrize("blocking", [False, True]) -async def test_aevaluate_results(blocking: bool) -> None: +@pytest.mark.parametrize("as_runnable", [False, True]) +async def test_aevaluate_results(blocking: bool, as_runnable: bool) -> None: session = mock.Mock() ds_name = "my-dataset" ds_id = "00886375-eb2a-4038-9032-efff60309896" @@ -343,6 +351,15 @@ async def predict(inputs: dict) -> dict: ordering_of_stuff.append("predict") return {"output": inputs["in"] + 1} + if as_runnable: + try: + from langchain_core.runnables import RunnableLambda + except ImportError: + pytest.skip("langchain-core not installed.") + return + else: + predict = RunnableLambda(predict) + async def score_value_first(run, example): ordering_of_stuff.append("evaluate") return {"score": 0.3}