refactor: Refactor StatisticalEvaluator #6999

Merged: 4 commits, Feb 15, 2024
Changes from 1 commit
4 changes: 2 additions & 2 deletions haystack/components/eval/__init__.py
@@ -1,4 +1,4 @@
from .sas_evaluator import SASEvaluator
from .statistical_evaluator import StatisticalEvaluator
from .statistical_evaluator import StatisticalEvaluator, StatisticalMetric

__all__ = ["SASEvaluator", "StatisticalEvaluator"]
__all__ = ["SASEvaluator", "StatisticalEvaluator", "StatisticalMetric"]
47 changes: 28 additions & 19 deletions haystack/components/eval/statistical_evaluator.py
@@ -1,6 +1,6 @@
import collections
from enum import Enum
from typing import Any, Dict, List
from typing import Any, Dict, List, Union

from numpy import array as np_array
from numpy import mean as np_mean
@@ -9,6 +9,19 @@
from haystack.core.component import component


class StatisticalMetric(Enum):
"""
Metrics supported by the StatisticalEvaluator.
"""

F1 = "f1"
EM = "exact_match"

@classmethod
def from_string(cls, metric: str) -> "StatisticalMetric":
return {"f1": cls.F1, "exact_match": cls.EM}[metric]


@component
class StatisticalEvaluator:
"""
@@ -20,49 +33,44 @@ class StatisticalEvaluator:
- Exact Match: Measures the proportion of cases where prediction is identical to the expected label.
"""

class Metric(Enum):
"""
Supported metrics
"""

F1 = "F1"
EM = "Exact Match"

def __init__(self, metric: Metric):
def __init__(self, metric: Union[str, StatisticalMetric]):
"""
Creates a new instance of StatisticalEvaluator.

:param metric: Metric to use for evaluation in this component. Supported metrics are F1 and Exact Match.
:type metric: Metric
"""
if isinstance(metric, str):
metric = StatisticalMetric.from_string(metric)
self._metric = metric

self._metric_function = {
StatisticalEvaluator.Metric.F1: self._f1,
StatisticalEvaluator.Metric.EM: self._exact_match,
}[self._metric]
self._metric_function = {StatisticalMetric.F1: self._f1, StatisticalMetric.EM: self._exact_match}[self._metric]

def to_dict(self) -> Dict[str, Any]:
return default_to_dict(self, metric=self._metric.value)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "StatisticalEvaluator":
data["init_parameters"]["metric"] = StatisticalEvaluator.Metric(data["init_parameters"]["metric"])
data["init_parameters"]["metric"] = StatisticalMetric(data["init_parameters"]["metric"])
return default_from_dict(cls, data)

@component.output_types(result=float)
def run(self, labels: List[str], predictions: List[str]) -> Dict[str, Any]:
"""
Run the StatisticalEvaluator to compute the metric between a list of predictions and a list of labels.
Both must be lists of strings of the same length.
Returns a dictionary containing the result of the chosen metric.

:param predictions: List of predictions.
:param labels: List of labels against which the predictions are compared.
:returns: A dictionary with the following outputs:
* `result` - Calculated result of the chosen metric.
"""
if len(labels) != len(predictions):
raise ValueError("The number of predictions and labels must be the same.")

return {"result": self._metric_function(labels, predictions)}

def _f1(self, labels: List[str], predictions: List[str]):
@staticmethod
def _f1(labels: List[str], predictions: List[str]):
"""
Measure word overlap between predictions and labels.
"""
@@ -88,7 +96,8 @@ def _f1(self, labels: List[str], predictions: List[str]):

return np_mean(scores)

def _exact_match(self, labels: List[str], predictions: List[str]) -> float:
@staticmethod
def _exact_match(labels: List[str], predictions: List[str]) -> float:
"""
Measure the proportion of cases where the prediction is identical to the expected label.
"""
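For orientation, here is a minimal usage sketch of the refactored API shown above. It is not part of the diff and assumes the package layout in this PR, where both StatisticalEvaluator and StatisticalMetric are exported from haystack.components.eval. Construction accepts either the enum or its string value; a string is resolved through StatisticalMetric.from_string, which as written in this commit raises a KeyError for unrecognized names. The expected results mirror the tests below.

```python
# Minimal sketch (not from the diff): exercising the refactored component.
# Assumes the exports added in this PR: StatisticalEvaluator and StatisticalMetric
# from haystack.components.eval.
from haystack.components.eval import StatisticalEvaluator, StatisticalMetric

f1_eval = StatisticalEvaluator(metric=StatisticalMetric.F1)
em_eval = StatisticalEvaluator(metric="exact_match")  # string form, resolved via StatisticalMetric.from_string

# F1 is token-level word overlap. For label "Source" vs prediction "Open Source":
# precision = 1/2, recall = 1/1, so F1 = 2 * (1/2 * 1) / (1/2 + 1) = 2/3.
print(f1_eval.run(labels=["Source"], predictions=["Open Source"])["result"])     # ~0.667
print(em_eval.run(labels=["OpenSource"], predictions=["OpenSource"])["result"])  # 1.0
```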
38 changes: 21 additions & 17 deletions test/components/eval/test_statistical_evaluator.py
@@ -1,42 +1,46 @@
import pytest

from haystack.components.eval import StatisticalEvaluator
from haystack.components.eval import StatisticalEvaluator, StatisticalMetric


class TestStatisticalEvaluator:
def test_init_default(self):
evaluator = StatisticalEvaluator(metric=StatisticalEvaluator.Metric.F1)
assert evaluator._metric == StatisticalEvaluator.Metric.F1
evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
assert evaluator._metric == StatisticalMetric.F1

def test_init_with_string(self):
evaluator = StatisticalEvaluator(metric="exact_match")
assert evaluator._metric == StatisticalMetric.EM

def test_to_dict(self):
evaluator = StatisticalEvaluator(metric=StatisticalEvaluator.Metric.F1)
evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)

expected_dict = {
"type": "haystack.components.eval.statistical_evaluator.StatisticalEvaluator",
"init_parameters": {"metric": "F1"},
"init_parameters": {"metric": "f1"},
}
assert evaluator.to_dict() == expected_dict

def test_from_dict(self):
evaluator = StatisticalEvaluator.from_dict(
{
"type": "haystack.components.eval.statistical_evaluator.StatisticalEvaluator",
"init_parameters": {"metric": "F1"},
"init_parameters": {"metric": "f1"},
}
)

assert evaluator._metric == StatisticalEvaluator.Metric.F1
assert evaluator._metric == StatisticalMetric.F1


class TestStatisticalEvaluatorF1:
def test_run_with_empty_inputs(self):
evaluator = StatisticalEvaluator(metric=StatisticalEvaluator.Metric.F1)
evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
result = evaluator.run(labels=[], predictions=[])
assert len(result) == 1
assert result["result"] == 0.0

def test_run_with_different_lengths(self):
evaluator = StatisticalEvaluator(metric=StatisticalEvaluator.Metric.F1)
evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
labels = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
@@ -50,7 +54,7 @@ def test_run_with_different_lengths(self):
evaluator.run(labels=labels, predictions=predictions)

def test_run_with_matching_predictions(self):
evaluator = StatisticalEvaluator(metric=StatisticalEvaluator.Metric.F1)
evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
labels = ["OpenSource", "HaystackAI", "LLMs"]
predictions = ["OpenSource", "HaystackAI", "LLMs"]
result = evaluator.run(labels=labels, predictions=predictions)
@@ -59,15 +63,15 @@ def test_run_with_matching_predictions(self):
assert result["result"] == 1.0

def test_run_with_single_prediction(self):
evaluator = StatisticalEvaluator(metric=StatisticalEvaluator.Metric.F1)
evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)

result = evaluator.run(labels=["Source"], predictions=["Open Source"])
assert len(result) == 1
assert result["result"] == pytest.approx(2 / 3)

def test_run_with_mismatched_predictions(self):
labels = ["Source", "HaystackAI"]
evaluator = StatisticalEvaluator(metric=StatisticalEvaluator.Metric.F1)
evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
predictions = ["Open Source", "HaystackAI"]
result = evaluator.run(labels=labels, predictions=predictions)
assert len(result) == 1
@@ -76,13 +80,13 @@ def test_run_with_mismatched_predictions(self):

class TestStatisticalEvaluatorExactMatch:
def test_run_with_empty_inputs(self):
evaluator = StatisticalEvaluator(metric=StatisticalEvaluator.Metric.EM)
evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
result = evaluator.run(predictions=[], labels=[])
assert len(result) == 1
assert result["result"] == 0.0

def test_run_with_different_lengths(self):
evaluator = StatisticalEvaluator(metric=StatisticalEvaluator.Metric.EM)
evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
labels = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
@@ -97,21 +101,21 @@ def test_run_with_different_lengths(self):

def test_run_with_matching_predictions(self):
labels = ["OpenSource", "HaystackAI", "LLMs"]
evaluator = StatisticalEvaluator(metric=StatisticalEvaluator.Metric.EM)
evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
predictions = ["OpenSource", "HaystackAI", "LLMs"]
result = evaluator.run(labels=labels, predictions=predictions)

assert len(result) == 1
assert result["result"] == 1.0

def test_run_with_single_prediction(self):
evaluator = StatisticalEvaluator(metric=StatisticalEvaluator.Metric.EM)
evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
result = evaluator.run(labels=["OpenSource"], predictions=["OpenSource"])
assert len(result) == 1
assert result["result"] == 1.0

def test_run_with_mismatched_predictions(self):
evaluator = StatisticalEvaluator(metric=StatisticalEvaluator.Metric.EM)
evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
labels = ["Source", "HaystackAI", "LLMs"]
predictions = ["OpenSource", "HaystackAI", "LLMs"]
result = evaluator.run(labels=labels, predictions=predictions)
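To round out the serialization tests above, a short sketch (under the same assumptions as the earlier example) of the to_dict/from_dict round trip: the metric is stored by its string value ("f1") and rehydrated into a StatisticalMetric on load.

```python
# Sketch of the serialization round trip exercised by test_to_dict / test_from_dict.
from haystack.components.eval import StatisticalEvaluator, StatisticalMetric

evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)

data = evaluator.to_dict()
assert data == {
    "type": "haystack.components.eval.statistical_evaluator.StatisticalEvaluator",
    "init_parameters": {"metric": "f1"},
}

restored = StatisticalEvaluator.from_dict(data)
assert restored._metric == StatisticalMetric.F1
```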