Skip to content

Commit

Permalink
python[patch]: pass Runnable to evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan committed Nov 11, 2024
1 parent 5d5cace commit 70c6dc9
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 26 deletions.
8 changes: 6 additions & 2 deletions python/langsmith/evaluation/_arunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
_ExperimentManagerMixin,
_extract_feedback_keys,
_ForwardResults,
_is_langchain_runnable,
_load_examples_map,
_load_experiment,
_load_tqdm,
Expand Down Expand Up @@ -940,7 +941,7 @@ def _get_run(r: run_trees.RunTree) -> None:
def _ensure_async_traceable(
target: ATARGET_T,
) -> rh.SupportsLangsmithExtra[[dict], Awaitable]:
if not asyncio.iscoroutinefunction(target):
if not asyncio.iscoroutinefunction(target) and not _is_langchain_runnable(target):
if callable(target):
raise ValueError(
"Target must be an async function. For sync functions, use evaluate."
Expand All @@ -961,7 +962,10 @@ def _ensure_async_traceable(
)
if rh.is_traceable_function(target):
return target # type: ignore
return rh.traceable(name="AsyncTarget")(target)
else:
if _is_langchain_runnable(target):
target = target.ainvoke # type: ignore[attr-defined]
return rh.traceable(name="AsyncTarget")(target)


def _aresolve_data(
Expand Down
39 changes: 30 additions & 9 deletions python/langsmith/evaluation/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@

if TYPE_CHECKING:
import pandas as pd
from langchain_core.runnables import Runnable

DataFrame = pd.DataFrame
else:
Expand Down Expand Up @@ -96,7 +97,7 @@


def evaluate(
target: TARGET_T,
target: Union[TARGET_T, Runnable],
/,
data: DATA_T,
evaluators: Optional[Sequence[EVALUATOR_T]] = None,
Expand Down Expand Up @@ -878,12 +879,12 @@ def _print_comparative_experiment_start(
)


def _is_callable(target: Union[TARGET_T, Iterable[schemas.Run]]) -> bool:
return callable(target) or (hasattr(target, "invoke") and callable(target.invoke))
def _is_callable(target: Union[TARGET_T, Iterable[schemas.Run], Runnable]) -> bool:
return callable(target) or _is_langchain_runnable(target)


def _evaluate(
target: Union[TARGET_T, Iterable[schemas.Run]],
target: Union[TARGET_T, Iterable[schemas.Run], Runnable],
/,
data: DATA_T,
evaluators: Optional[Sequence[EVALUATOR_T]] = None,
Expand Down Expand Up @@ -1664,12 +1665,13 @@ def _resolve_data(


def _ensure_traceable(
target: TARGET_T | rh.SupportsLangsmithExtra[[dict], dict],
target: TARGET_T | rh.SupportsLangsmithExtra[[dict], dict] | Runnable,
) -> rh.SupportsLangsmithExtra[[dict], dict]:
"""Ensure the target function is traceable."""
if not callable(target):
if not _is_callable(target):
raise ValueError(
"Target must be a callable function. For example:\n\n"
"Target must be a callable function or a langchain/langgraph object. For "
"example:\n\n"
"def predict(inputs: dict) -> dict:\n"
" # do work, like chain.invoke(inputs)\n"
" return {...}\n\n"
Expand All @@ -1679,9 +1681,11 @@ def _ensure_traceable(
")"
)
if rh.is_traceable_function(target):
fn = target
fn: rh.SupportsLangsmithExtra[[dict], dict] = target
else:
fn = rh.traceable(name="Target")(target)
if _is_langchain_runnable(target):
target = target.invoke # type: ignore[union-attr]
fn = rh.traceable(name="Target")(cast(Callable, target))
return fn


Expand Down Expand Up @@ -1923,3 +1927,20 @@ def _flatten_experiment_results(
}
for x in results[start:end]
]


@functools.lru_cache(maxsize=1)
def _import_langchain_runnable() -> Optional[type]:
try:
from langchain_core.runnables import Runnable

return Runnable
except ImportError:
return None


def _is_langchain_runnable(o: Any) -> bool:
    """Report whether ``o`` is an instance of langchain-core's ``Runnable``.

    Returns ``False`` whenever ``_import_langchain_runnable()`` yields a
    falsy value (i.e. the ``Runnable`` class could not be imported).
    """
    runnable_cls = _import_langchain_runnable()
    if not runnable_cls:
        return False
    return isinstance(o, runnable_cls)
6 changes: 2 additions & 4 deletions python/langsmith/run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,7 @@ def tracing_context(
get_run_tree_context = get_current_run_tree


def is_traceable_function(
func: Callable[P, R],
) -> TypeGuard[SupportsLangsmithExtra[P, R]]:
def is_traceable_function(func: Any) -> TypeGuard[SupportsLangsmithExtra[P, R]]:
"""Check if a function is @traceable decorated."""
return (
_is_traceable_function(func)
Expand Down Expand Up @@ -1445,7 +1443,7 @@ def _handle_container_end(
LOGGER.warning(f"Unable to process trace outputs: {repr(e)}")


def _is_traceable_function(func: Callable) -> bool:
def _is_traceable_function(func: Any) -> bool:
return getattr(func, "__langsmith_traceable__", False)


Expand Down
39 changes: 28 additions & 11 deletions python/tests/unit_tests/evaluation/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ def _wait_until(condition: Callable, timeout: int = 8):

@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher")
@pytest.mark.parametrize("blocking", [False, True])
def test_evaluate_results(blocking: bool) -> None:
@pytest.mark.parametrize("as_runnable", [False, True])
def test_evaluate_results(blocking: bool, as_runnable: bool) -> None:
session = mock.Mock()
ds_name = "my-dataset"
ds_id = "00886375-eb2a-4038-9032-efff60309896"
Expand Down Expand Up @@ -180,6 +181,15 @@ def predict(inputs: dict) -> dict:
ordering_of_stuff.append("predict")
return {"output": inputs["in"] + 1}

if as_runnable:
try:
from langchain_core.runnables import RunnableLambda
except ImportError:
pytest.skip("langchain-core not installed.")
return
else:
predict = RunnableLambda(predict)

def score_value_first(run, example):
ordering_of_stuff.append("evaluate")
return {"score": 0.3}
Expand Down Expand Up @@ -263,26 +273,24 @@ async def my_other_func(inputs: dict, other_val: int):
with pytest.raises(ValueError, match=match):
evaluate(functools.partial(my_other_func, other_val=3), data="foo")

if sys.version_info < (3, 10):
return
try:
from langchain_core.runnables import RunnableLambda
except ImportError:
pytest.skip("langchain-core not installed.")

@RunnableLambda
def foo(inputs: dict):
return "bar"

with pytest.raises(ValueError, match=match):
evaluate(foo.ainvoke, data="foo")
if sys.version_info < (3, 10):
return
with pytest.raises(ValueError, match=match):
evaluate(functools.partial(foo.ainvoke, inputs={"foo": "bar"}), data="foo")
evaluate(
functools.partial(RunnableLambda(my_func).ainvoke, inputs={"foo": "bar"}),
data="foo",
)


@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher")
@pytest.mark.parametrize("blocking", [False, True])
async def test_aevaluate_results(blocking: bool) -> None:
@pytest.mark.parametrize("as_runnable", [False, True])
async def test_aevaluate_results(blocking: bool, as_runnable: bool) -> None:
session = mock.Mock()
ds_name = "my-dataset"
ds_id = "00886375-eb2a-4038-9032-efff60309896"
Expand Down Expand Up @@ -343,6 +351,15 @@ async def predict(inputs: dict) -> dict:
ordering_of_stuff.append("predict")
return {"output": inputs["in"] + 1}

if as_runnable:
try:
from langchain_core.runnables import RunnableLambda
except ImportError:
pytest.skip("langchain-core not installed.")
return
else:
predict = RunnableLambda(predict)

async def score_value_first(run, example):
ordering_of_stuff.append("evaluate")
return {"score": 0.3}
Expand Down

0 comments on commit 70c6dc9

Please sign in to comment.