feat: Add StatisticalEvaluator component #6982

Merged · 4 commits · Feb 14, 2024
3 changes: 2 additions & 1 deletion haystack/components/eval/__init__.py
@@ -1,3 +1,4 @@
from .sas_evaluator import SASEvaluator
Contributor: This should go into haystack.components.evaluators

from .statistical_evaluator import StatisticalEvaluator

__all__ = ["SASEvaluator"]
__all__ = ["SASEvaluator", "StatisticalEvaluator"]
131 changes: 131 additions & 0 deletions haystack/components/eval/statistical_evaluator.py
@@ -0,0 +1,131 @@
import collections
from enum import Enum
from typing import Any, Dict, List, Optional

from numpy import array as np_array
from numpy import mean as np_mean

from haystack import default_from_dict, default_to_dict
from haystack.core.component import component

from .preprocess import _preprocess_text


@component
class StatisticalEvaluator:
"""
StatisticalEvaluator is a component that evaluates the performance of a model based on statistical metrics.
It's usually used in QA and Retrieval Augmented Generation (RAG) pipelines to evaluate the quality of the generated answers.

The supported metrics are:
- F1: Measures word overlap between predictions and labels.
- Exact Match: Measures the proportion of cases where prediction is identical to the expected label.
"""

class Metric(Enum):
Contributor: Let's extract this and move it out to the top-level namespace (components.evaluators). We should call it StatisticalMetrics (to disambiguate it from the others).

"""
Supported metrics
"""

F1 = "F1"
EM = "Exact Match"
Contributor (on lines +30 to +31): We should probably stick to snake_case like we do elsewhere.


def __init__(
self,
labels: List[str],
metric: Metric,
Contributor: Let's make this a Union[str, StatisticalMetric] and add a from_str function to the latter.

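A minimal sketch of what the two suggestions above might look like together: a top-level StatisticalMetric enum with snake_case values and a from_str helper, so the constructor could accept Union[str, StatisticalMetric]. The name, values, and placement are assumptions drawn from the review comments, not part of this PR:

from enum import Enum

class StatisticalMetric(Enum):
    """Hypothetical top-level metric enum (name taken from the review suggestion)."""

    F1 = "f1"
    EM = "exact_match"

    @classmethod
    def from_str(cls, string: str) -> "StatisticalMetric":
        # Map snake_case strings to enum members so callers can pass either form.
        mapping = {m.value: m for m in cls}
        if string not in mapping:
            raise ValueError(f"Unknown statistical metric '{string}'")
        return mapping[string]
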
regexes_to_ignore: Optional[List[str]] = None,
ignore_case: bool = False,
ignore_punctuation: bool = False,
ignore_numbers: bool = False,
):
"""
Creates a new instance of StatisticalEvaluator.

:param labels: The list of expected answers.
:param metric: Metric to use for evaluation in this component. Supported metrics are F1 and Exact Match.
:param regexes_to_ignore: A list of regular expressions. If provided, it removes substrings
matching these regular expressions from both predictions and labels before comparison. Defaults to None.
:param ignore_case: If True, performs case-insensitive comparison. Defaults to False.
:param ignore_punctuation: If True, removes punctuation from both predictions and labels before
comparison. Defaults to False.
:param ignore_numbers: If True, removes numerical digits from both predictions and labels
before comparison. Defaults to False.
"""
self._labels = labels
self._metric = metric
self._regexes_to_ignore = regexes_to_ignore
self._ignore_case = ignore_case
self._ignore_punctuation = ignore_punctuation
self._ignore_numbers = ignore_numbers

self._metric_function = {
StatisticalEvaluator.Metric.F1: self._f1,
StatisticalEvaluator.Metric.EM: self._exact_match,
}[self._metric]

def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
labels=self._labels,
metric=self._metric.value,
regexes_to_ignore=self._regexes_to_ignore,
ignore_case=self._ignore_case,
ignore_punctuation=self._ignore_punctuation,
ignore_numbers=self._ignore_numbers,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "StatisticalEvaluator":
data["init_parameters"]["metric"] = StatisticalEvaluator.Metric(data["init_parameters"]["metric"])
return default_from_dict(cls, data)

@component.output_types(result=float)
def run(self, predictions: List[str]) -> Dict[str, Any]:
Member: Yeah, I also noticed the comment about the length of these predictions. Why not add a check here for zero length, so we can omit the checks in all metrics?

Contributor (author): As of now I moved only the metrics that we already had from the old API. Others will be added in future PRs, and I'm not sure whether those will return the same values as F1 and Exact Match when the length is zero. That's the only reason I've done it like this.

if len(predictions) != len(self._labels):
raise ValueError("The number of predictions and labels must be the same.")

predictions = _preprocess_text(
predictions, self._regexes_to_ignore, self._ignore_case, self._ignore_punctuation, self._ignore_numbers
)
labels = _preprocess_text(
self._labels, self._regexes_to_ignore, self._ignore_case, self._ignore_punctuation, self._ignore_numbers
)

return {"result": self._metric_function(labels, predictions)}

def _f1(self, labels: List[str], predictions: List[str]):
Contributor: Can be a @staticmethod.

"""
Measure word overlap between predictions and labels.
"""
if len(predictions) == 0:
# We expect callers of this function already checked if predictions and labels are equal length
return 0.0

scores: List[float] = []
tokenized_predictions = [pred.split() for pred in predictions]
tokenized_labels = [label.split() for label in labels]
for label_tokens, prediction_tokens in zip(tokenized_labels, tokenized_predictions):
common = collections.Counter(label_tokens) & collections.Counter(prediction_tokens)
num_same = sum(common.values())
if len(label_tokens) == 0 or len(prediction_tokens) == 0:
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
return int(label_tokens == prediction_tokens)
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(label_tokens)
f1 = (2 * precision * recall) / (precision + recall)
scores.append(f1)

return np_mean(scores)
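
As a purely illustrative check of the per-sample F1 computed above (the label and prediction strings are made up for the example, not taken from this PR):

import collections

label_tokens = "the cat sat".split()
prediction_tokens = "the cat".split()
common = collections.Counter(label_tokens) & collections.Counter(prediction_tokens)
num_same = sum(common.values())                      # 2 overlapping tokens
precision = num_same / len(prediction_tokens)        # 2 / 2 = 1.0
recall = num_same / len(label_tokens)                # 2 / 3
f1 = 2 * precision * recall / (precision + recall)   # 0.8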

def _exact_match(self, labels: List[str], predictions: List[str]) -> float:
Contributor: Can be a @staticmethod.

"""
Measure the proportion of cases where the prediction is identical to the expected label.
"""
if len(predictions) == 0:
# We expect callers of this function already checked if predictions and labels are equal length
return 0.0
score_list = np_array(predictions) == np_array(labels)
return np_mean(score_list)
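
For context, a minimal usage sketch of the new component based on the signatures in this diff. The import path follows haystack.components.eval as added here, although a review comment above suggests haystack.components.evaluators; the labels and predictions are toy values chosen for illustration:

from haystack.components.eval import StatisticalEvaluator

evaluator = StatisticalEvaluator(
    labels=["Berlin", "Paris"],
    metric=StatisticalEvaluator.Metric.EM,
)
# run() compares the predictions against the labels given at construction time
# and returns the score under the "result" key.
output = evaluator.run(predictions=["Berlin", "Lyon"])
print(output["result"])  # expected to be 0.5 for these toy values
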
121 changes: 1 addition & 120 deletions haystack/evaluation/eval.py
@@ -1,11 +1,7 @@
import collections
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
from typing import Any, Callable, Dict, List, Union

from haystack import Pipeline
from haystack.core.component import Component
from haystack.evaluation.eval_utils import get_answers_from_output, preprocess_text
from haystack.evaluation.metrics import Metric, MetricsResult


@@ -45,8 +41,6 @@ def __init__(
Metric.RECALL: self._calculate_recall,
Metric.MRR: self._calculate_mrr,
Metric.MAP: self._calculate_map,
Metric.F1: self._calculate_f1,
Metric.EM: self._calculate_em,
}

def calculate_metrics(self, metric: Union[Metric, Callable[..., MetricsResult]], **kwargs) -> MetricsResult:
@@ -71,119 +65,6 @@ def _calculate_map(self):
def _calculate_mrr(self):
return MetricsResult({"mean_reciprocal_rank": None})

def _compute_f1_single(self, label_toks: List[str], pred_toks: List[str]) -> float:
"""
Compute F1 score for a single sample.
"""
common: collections.Counter = collections.Counter(label_toks) & collections.Counter(pred_toks)
num_same = sum(common.values())
if len(label_toks) == 0 or len(pred_toks) == 0:
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
return int(label_toks == pred_toks)
if num_same == 0:
return 0
precision = 1.0 * num_same / len(pred_toks)
recall = 1.0 * num_same / len(label_toks)
f1 = (2 * precision * recall) / (precision + recall)
return f1

def _calculate_f1(
self,
output_key: str,
regexes_to_ignore: Optional[List[str]] = None,
ignore_case: bool = False,
ignore_punctuation: bool = False,
ignore_numbers: bool = False,
) -> MetricsResult:
"""
Calculates the F1 score between two lists of predictions and labels.
F1 score measures the word overlap between the predicted text and the corresponding ground truth label.

:param output_key: The key of the output to use for comparison.
:param regexes_to_ignore (list, optional): A list of regular expressions. If provided, it removes substrings
matching these regular expressions from both predictions and labels before comparison. Defaults to None.
:param ignore_case (bool, optional): If True, performs case-insensitive comparison. Defaults to False.
:param ignore_punctuation (bool, optional): If True, removes punctuation from both predictions and labels before
comparison. Defaults to False.
:param ignore_numbers (bool, optional): If True, removes numerical digits from both predictions and labels
before comparison. Defaults to False.
:return: A MetricsResult object containing the calculated F1 score.
"""

predictions = get_answers_from_output(
outputs=self.outputs, output_key=output_key, runnable_type=self.runnable_type
)
labels = get_answers_from_output(
outputs=self.expected_outputs, output_key=output_key, runnable_type=self.runnable_type
)

if len(predictions) != len(labels):
raise ValueError("The number of predictions and labels must be the same.")
if len(predictions) == len(labels) == 0:
# Return F1 as 0 for no inputs
return MetricsResult({"f1": 0.0})

predictions = preprocess_text(predictions, regexes_to_ignore, ignore_case, ignore_punctuation, ignore_numbers)
labels = preprocess_text(labels, regexes_to_ignore, ignore_case, ignore_punctuation, ignore_numbers)

# Tokenize by splitting on spaces
tokenized_predictions = [pred.split() for pred in predictions]
tokenized_labels = [label.split() for label in labels]

f1_scores = [
self._compute_f1_single(label_toks, pred_toks)
for label_toks, pred_toks in zip(tokenized_labels, tokenized_predictions)
]

f1 = np.mean(f1_scores)

return MetricsResult({"f1": f1})

def _calculate_em(
self,
output_key: str,
regexes_to_ignore: Optional[List[str]] = None,
ignore_case: bool = False,
ignore_punctuation: bool = False,
ignore_numbers: bool = False,
) -> MetricsResult:
"""
Calculates the Exact Match (EM) score between two lists of predictions and labels.
Exact Match (EM) score measures the percentage of samples where the predicted text exactly matches the
corresponding ground truth label.

:param output_key: The key of the output to use for comparison.
:param regexes_to_ignore (list, optional): A list of regular expressions. If provided, it removes substrings
matching these regular expressions from both predictions and labels before comparison. Defaults to None.
:param ignore_case (bool, optional): If True, performs case-insensitive comparison. Defaults to False.
:param ignore_punctuation (bool, optional): If True, removes punctuation from both predictions and labels before
comparison. Defaults to False.
:param ignore_numbers (bool, optional): If True, removes numerical digits from both predictions and labels
before comparison. Defaults to False.
:return: A MetricsResult object containing the calculated Exact Match (EM) score.
"""

predictions = get_answers_from_output(
outputs=self.outputs, output_key=output_key, runnable_type=self.runnable_type
)
labels = get_answers_from_output(
outputs=self.expected_outputs, output_key=output_key, runnable_type=self.runnable_type
)

if len(predictions) != len(labels):
raise ValueError("The number of predictions and labels must be the same.")
if len(predictions) == len(labels) == 0:
# Return Exact Match as 0 for no inputs
return MetricsResult({"exact_match": 0.0})

predictions = preprocess_text(predictions, regexes_to_ignore, ignore_case, ignore_punctuation, ignore_numbers)
labels = preprocess_text(labels, regexes_to_ignore, ignore_case, ignore_punctuation, ignore_numbers)

score_list = np.array(predictions) == np.array(labels)
exact_match_score = np.mean(score_list)

return MetricsResult({"exact_match": exact_match_score})


def eval(
runnable: Union[Pipeline, Component], inputs: List[Dict[str, Any]], expected_outputs: List[Dict[str, Any]]
@@ -0,0 +1,4 @@
---
features:
- |
Add `StatisticalEvaluator`; this component can be used to calculate different statistical metrics from answers returned by LLMs.