-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Implement abstract
EvaluationHarness
class (#5)
* feat: Implement abstract `EvaluationHarness` class * Add license header * docstring lint * Sort imports * Lint
- Loading branch information
Showing
2 changed files
with
90 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .eval_harness import EvalRunOverrides, EvaluationHarness | ||
|
||
# Public API of this package. The name must be the dunder ``__all__`` for
# Python (and tooling) to treat it as the export list.
__all__ = ["EvaluationHarness", "EvalRunOverrides"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from abc import ABC, abstractmethod | ||
from dataclasses import dataclass | ||
from typing import Any, Dict, Generic, Optional, Type, TypeVar | ||
|
||
from haystack import Pipeline | ||
from haystack.core.serialization import DeserializationCallbacks | ||
from haystack.evaluation.eval_run_result import BaseEvaluationRunResult | ||
|
||
|
||
# Shape shared by both override fields: component name -> init parameters to replace.
_ComponentInitOverrides = Dict[str, Dict[str, Any]]


@dataclass
class EvalRunOverrides:
    """
    Init-parameter overrides applied during an evaluation run.

    The overrides can target the evaluated pipeline, the evaluation
    pipeline, or both. In each mapping, a key is a component name and
    the associated value is a dictionary of init parameters to override
    for that component.

    :param evaluated_pipeline_overrides:
        Overrides for the evaluated pipeline.
    :param evaluation_pipeline_overrides:
        Overrides for the evaluation pipeline.
    """

    evaluated_pipeline_overrides: Optional[_ComponentInitOverrides] = None
    evaluation_pipeline_overrides: Optional[_ComponentInitOverrides] = None
|
||
|
||
# Type parameters of ``EvaluationHarness``.
# Input accepted by a harness run (unconstrained).
EvalRunInputT = TypeVar("EvalRunInputT")
# Result produced by a harness run; must derive from ``BaseEvaluationRunResult``.
EvalRunOutputT = TypeVar("EvalRunOutputT", bound=BaseEvaluationRunResult)
# Overrides accepted by a harness run (unconstrained).
EvalRunOverridesT = TypeVar("EvalRunOverridesT")
|
||
|
||
class EvaluationHarness(ABC, Generic[EvalRunInputT, EvalRunOverridesT, EvalRunOutputT]):
    """
    Abstract base class for evaluation harnesses.

    A harness executes a pipeline with a given set of parameters and inputs,
    and evaluates its outputs with an evaluation pipeline. Concrete harnesses
    implement `run`.
    """

    @staticmethod
    def _override_pipeline(pipeline: Pipeline, parameter_overrides: Optional[Dict[str, Any]]) -> Pipeline:
        """
        Return a pipeline whose components use the overridden init parameters.

        :param pipeline:
            The pipeline to apply the overrides to.
        :param parameter_overrides:
            Maps a component name to the init parameters to override for
            that component.
        :returns:
            The given pipeline itself when there is nothing to override
            (`None` or an empty mapping). Otherwise, a new pipeline built by
            serializing the given one with `dumps()` and deserializing it with
            `Pipeline.loads()`, updating the init parameters of the named
            components along the way.
        :raises ValueError:
            If a key of `parameter_overrides` is not among the component
            names returned by
            `pipeline.inputs(include_components_with_connected_inputs=True)`.
        """
        # Nothing to do: hand back the very same pipeline object.
        if not parameter_overrides:
            return pipeline

        # Typed local binding so the closure below works on a non-optional mapping.
        requested: Dict[str, Any] = parameter_overrides

        # Reject unknown component names before doing any (de)serialization work.
        known_components = pipeline.inputs(include_components_with_connected_inputs=True)
        for requested_name in requested:
            if requested_name not in known_components:
                raise ValueError(f"Cannot override non-existent component '{requested_name}'")

        def component_pre_init_callback(
            name: str, cls: Type, init_params: Dict[str, Any]
        ):  # pylint: disable=unused-argument
            # Mutates `init_params` in place for components that have overrides.
            replacement = requested.get(name)
            if replacement:
                init_params.update(replacement)

        # Round-trip through the serialized form so the callback can patch the
        # init parameters of each component as it is re-created.
        return Pipeline.loads(pipeline.dumps(), callbacks=DeserializationCallbacks(component_pre_init_callback))

    @abstractmethod
    def run(
        self, inputs: EvalRunInputT, *, overrides: Optional[EvalRunOverridesT] = None, run_name: Optional[str] = None
    ) -> EvalRunOutputT:
        """
        Launch an evaluation run.

        :param inputs:
            Inputs to the evaluated and evaluation pipelines.
        :param overrides:
            Overrides for the harness.
        :param run_name:
            A name for the evaluation run.
        :returns:
            The output of the evaluation pipeline.
        """