Commit

update
hinthornw committed Dec 6, 2023
1 parent a74c03d commit 2e0828d
Showing 3 changed files with 159 additions and 157 deletions.
280 changes: 154 additions & 126 deletions libs/langchain/langchain/smith/evaluation/runner_utils.py
@@ -2,10 +2,12 @@

from __future__ import annotations

import dataclasses
import functools
import inspect
import logging
import uuid
from datetime import datetime
from enum import Enum
from typing import (
TYPE_CHECKING,
@@ -34,7 +36,7 @@
from langsmith.client import Client
from langsmith.evaluation import RunEvaluator
from langsmith.run_helpers import as_runnable, is_traceable_function
from langsmith.schemas import Dataset, DataType, Example
from langsmith.schemas import Dataset, DataType, Example, TracerSession
from langsmith.utils import LangSmithError
from requests import HTTPError

@@ -918,7 +920,7 @@ def _prepare_eval_run(
project_name: str,
project_metadata: Optional[Dict[str, Any]] = None,
tags: Optional[List[str]] = None,
) -> Tuple[MCF, str, Dataset, List[Example]]:
) -> Tuple[MCF, TracerSession, Dataset, List[Example]]:
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
dataset = client.read_dataset(dataset_name=dataset_name)

@@ -955,104 +957,145 @@ def _prepare_eval_run(
examples = list(client.list_examples(dataset_id=dataset.id))
if not examples:
raise ValueError(f"Dataset {dataset_name} has no example rows.")
return wrapped_model, project_name, dataset, examples

return wrapped_model, project, dataset, examples


@dataclasses.dataclass
class _DatasetRunContainer:
"""A container to help manage the state of a eval run."""

client: Client
project: TracerSession
wrapped_model: MCF
examples: List[Example]
configs: List[RunnableConfig]

def _merge_test_outputs(
self, batch_results: list, all_eval_results: dict, all_execution_time: dict
) -> dict:
results: dict = {}
for example, output in zip(self.examples, batch_results):
feedback = all_eval_results.get(str(example.id), [])
results[str(example.id)] = {
"input": example.inputs,
"feedback": feedback,
"execution_time": all_execution_time.get(str(example.id)),
}
if isinstance(output, EvalError):
results[str(example.id)]["Error"] = output.Error
else:
results[str(example.id)]["output"] = output
if example.outputs:
results[str(example.id)]["reference"] = example.outputs
return results

def _collect_metrics(self) -> Tuple[dict, dict]:
all_eval_results = {}
all_execution_time = {}
for c in self.configs:
for callback in cast(list, c["callbacks"]):
if isinstance(callback, EvaluatorCallbackHandler):
eval_results = callback.logged_eval_results
all_eval_results.update(
{example_id: v for (_, example_id), v in eval_results.items()}
)
elif isinstance(callback, LangChainTracer):
run = callback.latest_run
example_id = callback.example_id
execution_time = (
(run.end_time - run.start_time).total_seconds()
if run and run.end_time
else None
)
all_execution_time[str(example_id)] = execution_time
return all_eval_results, all_execution_time

def _prepare_run_on_dataset(
client: Client,
dataset_name: str,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
project_name: Optional[str],
evaluation: Optional[smith_eval.RunEvalConfig] = None,
tags: Optional[List[str]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
concurrency_level: int = 5,
project_metadata: Optional[Dict[str, Any]] = None,
) -> Tuple[MCF, str, List[Example], List[RunnableConfig]]:
project_name = project_name or name_generation.random_name()
wrapped_model, project_name, dataset, examples = _prepare_eval_run(
client,
dataset_name,
llm_or_chain_factory,
project_name,
project_metadata=project_metadata,
tags=tags,
)
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
run_evaluators = _setup_evaluation(
wrapped_model, examples, evaluation, dataset.data_type or DataType.kv
)
_validate_example_inputs(examples[0], wrapped_model, input_mapper)
progress_bar = progress.ProgressBarCallback(len(examples))
configs = [
RunnableConfig(
callbacks=[
LangChainTracer(
project_name=project_name,
client=client,
use_threading=False,
example_id=example.id,
),
EvaluatorCallbackHandler(
evaluators=run_evaluators or [],
client=client,
example_id=example.id,
max_concurrency=0,
),
progress_bar,
],
tags=tags or [],
max_concurrency=concurrency_level,
def _collect_test_results(
self,
batch_results: List[Union[dict, str, LLMResult, ChatResult]],
) -> TestResult:
wait_for_all_evaluators()
all_eval_results, all_execution_time = self._collect_metrics()
results = self._merge_test_outputs(
batch_results, all_eval_results, all_execution_time
)
return TestResult(
project_name=self.project.name,
results=results,
)
for example in examples
]
return wrapped_model, project_name, examples, configs


def _collect_test_results(
examples: List[Example],
batch_results: List[Union[dict, str, LLMResult, ChatResult]],
configs: List[RunnableConfig],
project_name: str,
) -> TestResult:
wait_for_all_evaluators()
all_eval_results = {}
all_execution_time = {}
for c in configs:
for callback in cast(list, c["callbacks"]):
if isinstance(callback, EvaluatorCallbackHandler):
eval_results = callback.logged_eval_results
all_eval_results.update(
{example_id: v for (_, example_id), v in eval_results.items()}
)
elif isinstance(callback, LangChainTracer):
run = callback.latest_run
example_id = callback.example_id
execution_time = (
(run.end_time - run.start_time).total_seconds()
if run and run.end_time
else None
)
all_execution_time[str(example_id)] = execution_time

results: dict = {}
for example, output in zip(examples, batch_results):
feedback = all_eval_results.get(str(example.id), [])
results[str(example.id)] = {
"input": example.inputs,
"feedback": feedback,
"execution_time": all_execution_time.get(str(example.id)),
}
if isinstance(output, EvalError):
results[str(example.id)]["Error"] = output.Error
else:
results[str(example.id)]["output"] = output
if example.outputs:
results[str(example.id)]["reference"] = example.outputs
return TestResult(
project_name=project_name,
results=results,
)
def finish(self, batch_results: list, verbose: bool = False) -> TestResult:
results = self._collect_test_results(batch_results)
if verbose:
try:
agg_feedback = results.get_aggregate_feedback()
_display_aggregate_results(agg_feedback)
except Exception as e:
logger.debug(f"Failed to print aggregate feedback: {repr(e)}")
try:
# Closing the project permits name changing and metric optimizations
self.client.update_project(self.project.id, end_time=datetime.utcnow())
except Exception as e:
logger.debug(f"Failed to close project: {repr(e)}")
return results

@classmethod
def prepare(
cls,
client: Client,
dataset_name: str,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
project_name: Optional[str],
evaluation: Optional[smith_eval.RunEvalConfig] = None,
tags: Optional[List[str]] = None,
input_mapper: Optional[Callable[[Dict], Any]] = None,
concurrency_level: int = 5,
project_metadata: Optional[Dict[str, Any]] = None,
) -> _DatasetRunContainer:
project_name = project_name or name_generation.random_name()
wrapped_model, project, dataset, examples = _prepare_eval_run(
client,
dataset_name,
llm_or_chain_factory,
project_name,
project_metadata=project_metadata,
tags=tags,
)
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
run_evaluators = _setup_evaluation(
wrapped_model, examples, evaluation, dataset.data_type or DataType.kv
)
_validate_example_inputs(examples[0], wrapped_model, input_mapper)
progress_bar = progress.ProgressBarCallback(len(examples))
configs = [
RunnableConfig(
callbacks=[
LangChainTracer(
project_name=project.name,
client=client,
use_threading=False,
example_id=example.id,
),
EvaluatorCallbackHandler(
evaluators=run_evaluators or [],
client=client,
example_id=example.id,
max_concurrency=0,
),
progress_bar,
],
tags=tags or [],
max_concurrency=concurrency_level,
)
for example in examples
]
return cls(
client=client,
project=project,
wrapped_model=wrapped_model,
examples=examples,
configs=configs,
)


def _is_jupyter_environment() -> bool:
@@ -1120,7 +1163,7 @@ async def arun_on_dataset(
removal="0.0.305",
)
client = client or Client()
wrapped_model, project_name, examples, configs = _prepare_run_on_dataset(
container = _DatasetRunContainer.prepare(
client,
dataset_name,
llm_or_chain_factory,
@@ -1132,26 +1175,18 @@ async def arun_on_dataset(
project_metadata=project_metadata,
)
batch_results = await runnable_utils.gather_with_concurrency(
configs[0].get("max_concurrency"),
container.configs[0].get("max_concurrency"),
*map(
functools.partial(
_arun_llm_or_chain,
llm_or_chain_factory=wrapped_model,
llm_or_chain_factory=container.wrapped_model,
input_mapper=input_mapper,
),
examples,
configs,
container.examples,
container.configs,
),
)
results = _collect_test_results(examples, batch_results, configs, project_name)
if verbose:
try:
agg_feedback = results.get_aggregate_feedback()
print("\n Eval quantiles:")
print(agg_feedback)
except Exception as e:
logger.debug(f"Failed to print aggregate feedback: {repr(e)}")
return results
return container.finish(batch_results, verbose=verbose)


def run_on_dataset(
@@ -1180,7 +1215,7 @@ def run_on_dataset(
removal="0.0.305",
)
client = client or Client()
wrapped_model, project_name, examples, configs = _prepare_run_on_dataset(
container = _DatasetRunContainer.prepare(
client,
dataset_name,
llm_or_chain_factory,
@@ -1196,33 +1231,26 @@ def run_on_dataset(
_run_llm_or_chain(
example,
config,
llm_or_chain_factory=wrapped_model,
llm_or_chain_factory=container.wrapped_model,
input_mapper=input_mapper,
)
for example, config in zip(examples, configs)
for example, config in zip(container.examples, container.configs)
]
else:
with runnable_config.get_executor_for_config(configs[0]) as executor:
with runnable_config.get_executor_for_config(container.configs[0]) as executor:
batch_results = list(
executor.map(
functools.partial(
_run_llm_or_chain,
llm_or_chain_factory=wrapped_model,
llm_or_chain_factory=container.wrapped_model,
input_mapper=input_mapper,
),
examples,
configs,
container.examples,
container.configs,
)
)

results = _collect_test_results(examples, batch_results, configs, project_name)
if verbose:
try:
agg_feedback = results.get_aggregate_feedback()
_display_aggregate_results(agg_feedback)
except Exception as e:
logger.debug(f"Failed to print aggregate feedback: {repr(e)}")
return results
return container.finish(batch_results, verbose=verbose)


_RUN_ON_DATASET_DOCSTRING = """
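For orientation: this commit folds the module-level _prepare_run_on_dataset and _collect_test_results helpers into the stateful _DatasetRunContainer. Below is a condensed, hypothetical sketch of the lifecycle that run_on_dataset now drives, written only with names that appear in the diff above; it is an illustration, not code from the commit.

# Hypothetical sketch (not part of the commit) of the container lifecycle:
# prepare() builds the project, loads the examples, and assembles one
# RunnableConfig per example; finish() waits for evaluators, merges feedback
# and execution times, and closes the project.
def sketch_run_on_dataset(client, dataset_name, llm_or_chain_factory):
    container = _DatasetRunContainer.prepare(
        client,
        dataset_name,
        llm_or_chain_factory,
        project_name=None,  # None -> name_generation.random_name()
    )
    batch_results = [
        _run_llm_or_chain(
            example,
            config,
            llm_or_chain_factory=container.wrapped_model,
            input_mapper=None,
        )
        for example, config in zip(container.examples, container.configs)
    ]
    return container.finish(batch_results, verbose=True)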

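And a minimal caller-side sketch of the public entry point. The imports, the dataset name, and the "qa" evaluator are assumptions about the surrounding library of this era, not part of the commit; it assumes a LangSmith dataset named "my-eval-dataset" exists and OpenAI credentials are configured.

from langchain.chat_models import ChatOpenAI  # assumed available in this version
from langchain.smith import RunEvalConfig, run_on_dataset
from langsmith import Client

client = Client()
results = run_on_dataset(
    client,
    dataset_name="my-eval-dataset",  # hypothetical dataset
    llm_or_chain_factory=lambda: ChatOpenAI(temperature=0),
    evaluation=RunEvalConfig(evaluators=["qa"]),
    verbose=True,  # finish() prints aggregate feedback and closes the project
)
print(results.get_aggregate_feedback())  # per-feedback-key summary stats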