From 9297fca520bbebac742f74db0c55242eb32632fe Mon Sep 17 00:00:00 2001
From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Date: Wed, 14 Feb 2024 16:16:22 +0100
Subject: [PATCH] feat: Add `SASEvaluator` component (#6980)

* Add SASEvaluator component

* Add release notes

* Delete old tests

* Remove SAS metric in old API

* Avoid importing whole numpy package
---
 haystack/components/eval/__init__.py          |   3 +
 haystack/components/eval/preprocess.py        |  38 ++
 haystack/components/eval/sas_evaluator.py     | 156 ++++
 haystack/evaluation/eval.py                   | 111 +-----
 haystack/evaluation/metrics.py                |   1 -
 .../notes/sas-evaluator-7858ea6c38f80bc7.yaml |   4 +
 test/components/eval/test_preprocess.py       | 125 +++++++
 test/components/eval/test_sas_evaluator.py    | 292 +++++++++++++++
 test/evaluation/test_eval_sas.py              | 347 ------------------
 9 files changed, 619 insertions(+), 458 deletions(-)
 create mode 100644 haystack/components/eval/__init__.py
 create mode 100644 haystack/components/eval/preprocess.py
 create mode 100644 haystack/components/eval/sas_evaluator.py
 create mode 100644 releasenotes/notes/sas-evaluator-7858ea6c38f80bc7.yaml
 create mode 100644 test/components/eval/test_preprocess.py
 create mode 100644 test/components/eval/test_sas_evaluator.py
 delete mode 100644 test/evaluation/test_eval_sas.py

diff --git a/haystack/components/eval/__init__.py b/haystack/components/eval/__init__.py
new file mode 100644
index 0000000000..9477cb1242
--- /dev/null
+++ b/haystack/components/eval/__init__.py
@@ -0,0 +1,3 @@
+from .sas_evaluator import SASEvaluator
+
+__all__ = ["SASEvaluator"]
diff --git a/haystack/components/eval/preprocess.py b/haystack/components/eval/preprocess.py
new file mode 100644
index 0000000000..9be1dd93ff
--- /dev/null
+++ b/haystack/components/eval/preprocess.py
@@ -0,0 +1,38 @@
+import re
+import string
+from typing import List, Optional
+
+
+def _preprocess_text(
+    texts: List[str],
+    regexes_to_ignore: Optional[List[str]] = None,
+    ignore_case: bool = False,
+    ignore_punctuation: bool = False,
+    ignore_numbers: bool = False,
+) -> List[str]:
+    """
+    Preprocess the input texts to remove unwanted characters.
+
+    :param regexes_to_ignore (list, optional): A list of regular expressions. If provided, it removes substrings
+        matching these regular expressions from the text. Defaults to None.
+    :param ignore_case (bool, optional): If True, converts all characters to lowercase. Defaults to False.
+    :param ignore_punctuation (bool, optional): If True, removes punctuation from the text. Defaults to False.
+    :param ignore_numbers (bool, optional): If True, removes numerical digits from the text. Defaults to False.
+    :return: A list of preprocessed strings.
+ """ + if regexes_to_ignore: + combined_regex = "|".join(regexes_to_ignore) + texts = [re.sub(combined_regex, "", text, flags=re.IGNORECASE) for text in texts] + + if ignore_case: + texts = [text.lower() for text in texts] + + if ignore_punctuation: + translator = str.maketrans("", "", string.punctuation) + texts = [text.translate(translator) for text in texts] + + if ignore_numbers: + translator = str.maketrans("", "", string.digits) + texts = [text.translate(translator) for text in texts] + + return texts diff --git a/haystack/components/eval/sas_evaluator.py b/haystack/components/eval/sas_evaluator.py new file mode 100644 index 0000000000..8b4c30352e --- /dev/null +++ b/haystack/components/eval/sas_evaluator.py @@ -0,0 +1,156 @@ +from typing import Any, Dict, List, Optional + +from numpy import mean as np_mean + +from haystack import component, default_from_dict, default_to_dict +from haystack.lazy_imports import LazyImport +from haystack.utils import ComponentDevice, expit +from haystack.utils.auth import Secret, deserialize_secrets_inplace + +from .preprocess import _preprocess_text + +with LazyImport(message="Run 'pip install scikit-learn \"sentence-transformers>=2.2.0\"'") as metrics_import: + from sentence_transformers import CrossEncoder, SentenceTransformer, util + from transformers import AutoConfig + + +@component +class SASEvaluator: + """ + SASEvaluator computes the Semantic Answer Similarity (SAS) between a list of predictions and a list of labels. + It's usually used in Retrieval Augmented Generation (RAG) pipelines to evaluate the quality of the generated answers. + + The SAS is computed using a pre-trained model from the Hugging Face model hub. The model can be either a + Bi-Encoder or a Cross-Encoder. The choice of the model is based on the `model` parameter. + The default model is `sentence-transformers/paraphrase-multilingual-mpnet-base-v2`. + """ + + def __init__( + self, + labels: List[str], + model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2", + regexes_to_ignore: Optional[List[str]] = None, + ignore_case: bool = False, + ignore_punctuation: bool = False, + ignore_numbers: bool = False, + batch_size: int = 32, + device: Optional[ComponentDevice] = None, + token: Secret = Secret.from_env_var("HF_API_TOKEN", strict=False), + ): + """ + Creates a new instance of SASEvaluator. + + :param labels: The list of expected answers. + :param model: SentenceTransformers semantic textual similarity model, should be path or string pointing to + a downloadable model. + :param regexes_to_ignore: A list of regular expressions. If provided, it removes substrings + matching these regular expressions from both predictions and labels before comparison. Defaults to None. + :param ignore_case: If True, performs case-insensitive comparison. Defaults to False. + :param ignore_punctuation: If True, removes punctuation from both predictions and labels before + comparison. Defaults to False. + :param ignore_numbers: If True, removes numerical digits from both predictions and labels + before comparison. Defaults to False. + :param batch_size: Number of prediction-label pairs to encode at once. + :param device: The device on which the model is loaded. If `None`, the default device is automatically + selected. + :param token: The Hugging Face token for HTTP bearer authorization. + You can find your HF token at https://huggingface.co/settings/tokens. 
+ """ + metrics_import.check() + + self._labels = labels + self._model = model + self._regexes_to_ignore = regexes_to_ignore + self._ignore_case = ignore_case + self._ignore_punctuation = ignore_punctuation + self._ignore_numbers = ignore_numbers + self._batch_size = batch_size + self._device = device + self._token = token + + def to_dict(self) -> Dict[str, Any]: + return default_to_dict( + self, + labels=self._labels, + regexes_to_ignore=self._regexes_to_ignore, + ignore_case=self._ignore_case, + ignore_punctuation=self._ignore_punctuation, + ignore_numbers=self._ignore_numbers, + model=self._model, + batch_size=self._batch_size, + device=self._device.to_dict() if self._device else None, + token=self._token.to_dict() if self._token else None, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SASEvaluator": + deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + if device := data.get("init_parameters", {}).get("device"): + data["init_parameters"]["device"] = ComponentDevice.from_dict(device) + return default_from_dict(cls, data) + + @component.output_types(sas=float, scores=List[float]) + def run(self, predictions: List[str]) -> Dict[str, Any]: + if len(predictions) != len(self._labels): + raise ValueError("The number of predictions and labels must be the same.") + + if len(predictions) == 0: + return {"sas": 0.0, "scores": [0.0]} + + token = self._token.resolve_value() if self._token else None + + predictions = _preprocess_text( + predictions, self._regexes_to_ignore, self._ignore_case, self._ignore_punctuation, self._ignore_numbers + ) + labels = _preprocess_text( + self._labels, self._regexes_to_ignore, self._ignore_case, self._ignore_punctuation, self._ignore_numbers + ) + config = AutoConfig.from_pretrained(self._model, use_auth_token=token) + cross_encoder_used = False + if config.architectures: + cross_encoder_used = any(arch.endswith("ForSequenceClassification") for arch in config.architectures) + + device = ComponentDevice.resolve_device(self._device) + + # Based on the Model string we can load either Bi-Encoders or Cross Encoders. 
+        # Similarity computation changes for both approaches
+
+        if cross_encoder_used:
+            # For Cross Encoders we create a list of pairs of predictions and labels
+            similarity_model = CrossEncoder(
+                self._model,
+                device=device.to_torch_str(),
+                tokenizer_args={"use_auth_token": token},
+                automodel_args={"use_auth_token": token},
+            )
+            sentence_pairs = [[pred, label] for pred, label in zip(predictions, labels)]
+            similarity_scores = similarity_model.predict(
+                sentence_pairs, batch_size=self._batch_size, convert_to_numpy=True
+            )
+
+            # Cross Encoders may return raw logits instead of normalized scores
+            # We normalize the scores if they are larger than 1
+            if (similarity_scores > 1).any():
+                similarity_scores = expit(similarity_scores)
+
+            # Convert scores to list of floats from numpy array
+            similarity_scores = similarity_scores.tolist()
+
+        else:
+            # For Bi-encoders we create embeddings separately for predictions and labels
+            similarity_model = SentenceTransformer(self._model, device=device.to_torch_str(), use_auth_token=token)
+            predictions_embeddings = similarity_model.encode(
+                predictions, batch_size=self._batch_size, convert_to_tensor=True
+            )
+            label_embeddings = similarity_model.encode(labels, batch_size=self._batch_size, convert_to_tensor=True)
+
+            # Compute cosine-similarities
+            scores = util.cos_sim(predictions_embeddings, label_embeddings)
+
+            # cos_sim computes cosine similarity between all pairs of vectors in predictions_embeddings and label_embeddings
+            # It returns a matrix with shape (len(predictions), len(labels))
+            similarity_scores = [scores[i][i].item() for i in range(len(predictions))]
+
+        sas_score = np_mean(similarity_scores)
+
+        return {"sas": sas_score, "scores": similarity_scores}
diff --git a/haystack/evaluation/eval.py b/haystack/evaluation/eval.py
index 3ce375ee24..500b423e86 100644
--- a/haystack/evaluation/eval.py
+++ b/haystack/evaluation/eval.py
@@ -1,5 +1,5 @@
 import collections
-from typing import Any, Callable, Dict, List, Union, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 
@@ -8,13 +8,6 @@
 from haystack.evaluation.eval_utils import get_answers_from_output, preprocess_text
 from haystack.evaluation.metrics import Metric, MetricsResult
-from haystack.lazy_imports import LazyImport
-from haystack.utils import ComponentDevice, expit
-
-with LazyImport(message="Run 'pip install scikit-learn \"sentence-transformers>=2.2.0\"'") as metrics_import:
-    from sentence_transformers import SentenceTransformer, CrossEncoder, util
-    from transformers import AutoConfig
-
 
 
 class EvaluationResult:
     """
@@ -54,7 +47,6 @@ def __init__(
             Metric.MAP: self._calculate_map,
             Metric.F1: self._calculate_f1,
             Metric.EM: self._calculate_em,
-            Metric.SAS: self._calculate_sas,
         }
 
     def calculate_metrics(self, metric: Union[Metric, Callable[..., MetricsResult]], **kwargs) -> MetricsResult:
@@ -192,107 +184,6 @@ def _calculate_em(
         return MetricsResult({"exact_match": exact_match_score})
 
-    def _calculate_sas(
-        self,
-        output_key: str,
-        regexes_to_ignore: Optional[List[str]] = None,
-        ignore_case: bool = False,
-        ignore_punctuation: bool = False,
-        ignore_numbers: bool = False,
-        model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
-        batch_size: int = 32,
-        device: Optional[ComponentDevice] = None,
-        token: Optional[Union[str, bool]] = None,
-    ) -> MetricsResult:
-        """
-        Calculates the Semantic Answer Similarity (SAS) score between two lists of predictions and labels.
- Semantic Answer Similarity (SAS) score measures the Transformer-based similarity between the predicted text and - the corresponding ground truth label. - - :param output_key: The key of the output to use for comparison. - :param regexes_to_ignore (list, optional): A list of regular expressions. If provided, it removes substrings - matching these regular expressions from both predictions and labels before comparison. Defaults to None. - :param ignore_case (bool, optional): If True, performs case-insensitive comparison. Defaults to False. - :param ignore_punctuation (bool, optional): If True, removes punctuation from both predictions and labels before - comparison. Defaults to False. - :param ignore_numbers (bool, optional): If True, removes numerical digits from both predictions and labels - before comparison. Defaults to False. - :param model: SentenceTransformers semantic textual similarity model, should be path or string pointing to - a downloadable model. - :param batch_size: Number of prediction-label pairs to encode at once. - :param device: The device on which the model is loaded. If `None`, the default device is automatically - selected. - :param token: The token to use as HTTP bearer authorization for private models from Huggingface. - If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface). - Additional information can be found here: - https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained - :return: A MetricsResult object containing the calculated Semantic Answer Similarity (SAS) score and the - list of similarity scores obtained for each prediction-label pair. - """ - metrics_import.check() - - predictions = get_answers_from_output( - outputs=self.outputs, output_key=output_key, runnable_type=self.runnable_type - ) - labels = get_answers_from_output( - outputs=self.expected_outputs, output_key=output_key, runnable_type=self.runnable_type - ) - - if len(predictions) != len(labels): - raise ValueError("The number of predictions and labels must be the same.") - if len(predictions) == len(labels) == 0: - # Return SAS as 0 for no inputs - return MetricsResult({"sas": 0.0, "scores": [0.0]}) - - predictions = preprocess_text(predictions, regexes_to_ignore, ignore_case, ignore_punctuation, ignore_numbers) - labels = preprocess_text(labels, regexes_to_ignore, ignore_case, ignore_punctuation, ignore_numbers) - - config = AutoConfig.from_pretrained(model, use_auth_token=token) - cross_encoder_used = False - if config.architectures: - cross_encoder_used = any(arch.endswith("ForSequenceClassification") for arch in config.architectures) - - device = ComponentDevice.resolve_device(device) - - # Based on the Model string we can load either Bi-Encoders or Cross Encoders. 
- # Similarity computation changes for both approaches - - if cross_encoder_used: - # For Cross Encoders we create a list of pairs of predictions and labels - similarity_model = CrossEncoder( - model, - device=device.to_torch_str(), - tokenizer_args={"use_auth_token": token}, - automodel_args={"use_auth_token": token}, - ) - sentence_pairs = [[pred, label] for pred, label in zip(predictions, labels)] - similarity_scores = similarity_model.predict(sentence_pairs, batch_size=batch_size, convert_to_numpy=True) - - # All Cross Encoders do not return a set of logits scores that are normalized - # We normalize scores if they are larger than 1 - if (similarity_scores > 1).any(): - similarity_scores = expit(similarity_scores) - - # Convert scores to list of floats from numpy array - similarity_scores = similarity_scores.tolist() - - else: - # For Bi-encoders we create embeddings separately for predictions and labels - similarity_model = SentenceTransformer(model, device=device.to_torch_str(), use_auth_token=token) - pred_embeddings = similarity_model.encode(predictions, batch_size=batch_size, convert_to_tensor=True) - label_embeddings = similarity_model.encode(labels, batch_size=batch_size, convert_to_tensor=True) - - # Compute cosine-similarities - scores = util.cos_sim(pred_embeddings, label_embeddings) - - # cos_sim computes cosine similarity between all pairs of vectors in pred_embeddings and label_embeddings - # It returns a matrix with shape (len(predictions), len(labels)) - similarity_scores = [scores[i][i].item() for i in range(len(predictions))] - - sas_score = np.mean(similarity_scores) - - return MetricsResult({"sas": sas_score, "scores": similarity_scores}) - def eval( runnable: Union[Pipeline, Component], inputs: List[Dict[str, Any]], expected_outputs: List[Dict[str, Any]] diff --git a/haystack/evaluation/metrics.py b/haystack/evaluation/metrics.py index fbe2fec8af..ee94b42fd8 100644 --- a/haystack/evaluation/metrics.py +++ b/haystack/evaluation/metrics.py @@ -14,7 +14,6 @@ class Metric(Enum): MAP = "Mean Average Precision" F1 = "F1" EM = "Exact Match" - SAS = "Semantic Answer Similarity" class MetricsResult(dict): diff --git a/releasenotes/notes/sas-evaluator-7858ea6c38f80bc7.yaml b/releasenotes/notes/sas-evaluator-7858ea6c38f80bc7.yaml new file mode 100644 index 0000000000..7447437d89 --- /dev/null +++ b/releasenotes/notes/sas-evaluator-7858ea6c38f80bc7.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Add `SASEvaluator`, this Component can be used to calculate the Semantic Answer Similarity of answers returned by LLMs. diff --git a/test/components/eval/test_preprocess.py b/test/components/eval/test_preprocess.py new file mode 100644 index 0000000000..5e8971a775 --- /dev/null +++ b/test/components/eval/test_preprocess.py @@ -0,0 +1,125 @@ +from haystack.components.eval.preprocess import _preprocess_text + + +def test_preprocess_text_default_parameters(): + """ + Test preprocess_text with default parameters. + There should be no changes to the input text. + """ + texts = ["Test, Output-1!", "Test, Output-2!"] + expected_output = ["Test, Output-1!", "Test, Output-2!"] + actual_output = _preprocess_text(texts) + + assert actual_output == expected_output + + +def test_preprocess_text_ignore_case(): + """ + Test preprocess_text with ignore_case=True. 
+ + """ + texts = ["Test, Output-1!"] + expected_output = ["test, output-1!"] + + actual_output = _preprocess_text(texts, ignore_case=True) + + assert actual_output == expected_output + + +def test_preprocess_text_ignore_punctuation(): + """ + Test preprocess_text with ignore_punctuation=True. + """ + texts = ["Test, Output-1!"] + expected_output = ["Test Output1"] + + actual_output = _preprocess_text(texts, ignore_punctuation=True) + + assert actual_output == expected_output + + +# Preprocess text with ignore_numbers=True. +def test_preprocess_text_ignore_numbers(): + """ + Test preprocess_text with ignore_numbers=True. It should be able to remove numbers from the input. + """ + texts = ["Test, Output-1!"] + expected_output = ["Test, Output-!"] + + actual_output = _preprocess_text(texts, ignore_numbers=True) + + assert actual_output == expected_output + + +def test_preprocess_text_regexes_to_ignore(): + """ + Test preprocess_text with a list of regex patterns to ignore. + """ + texts = ["Test, Output-1!"] + expected_output = ["Test Output"] + + # Use regex patterns to remove digits and non-alphanumeric characters + actual_output = _preprocess_text(texts, regexes_to_ignore=[r"\d", r"[^\w\s]"]) + + assert actual_output == expected_output + + +def test_preprocess_text_empty_list(): + """ + Test preprocess_text with empty list of texts. + """ + texts = [] + expected_output = [] + + actual_output = _preprocess_text(texts) + + assert actual_output == expected_output + + +def test_preprocess_text_all_ignore_parameters(): + """ + Test preprocess_text with all ignore parameters set to True. + """ + texts = ["Test, Output-1!"] + expected_output = ["test output"] + + actual_output = _preprocess_text(texts, ignore_case=True, ignore_punctuation=True, ignore_numbers=True) + + assert actual_output == expected_output + + +def test_preprocess_text_regexes_to_ignore_empty_string(): + """ + Test preprocess_text with regexes_to_ignore=[""]. + """ + texts = ["Test, Output-1!"] + expected_output = ["Test, Output-1!"] + + actual_output = _preprocess_text(texts, regexes_to_ignore=[""]) + + assert actual_output == expected_output + + +# Preprocess text with regexes_to_ignore=[".*"]. +def test_preprocess_text_regexes_to_ignore_dot_star(): + """ + Test preprocess_text with regexes_to_ignore=[".*"]. + """ + texts = ["Test, Output-1!"] + expected_output = [""] + + actual_output = _preprocess_text(texts, regexes_to_ignore=[".*"]) + + assert actual_output == expected_output + + +def test_preprocess_text_regexes_to_ignore_same_substring(): + """ + Test preprocess_text with regexes_to_ignore where all the regex patterns match the same substring. 
+ """ + texts = ["Test, Output-1!"] + expected_output = ["Test, Output-!"] + + actual_output = _preprocess_text(texts, regexes_to_ignore=[r"\d", r"\d"]) + + assert actual_output == expected_output diff --git a/test/components/eval/test_sas_evaluator.py b/test/components/eval/test_sas_evaluator.py new file mode 100644 index 0000000000..0a7811bfe5 --- /dev/null +++ b/test/components/eval/test_sas_evaluator.py @@ -0,0 +1,292 @@ +import pytest + +from haystack.components.eval import SASEvaluator +from haystack.utils.device import ComponentDevice + + +class TestSASEvaluator: + def test_init_default(self, monkeypatch): + monkeypatch.setenv("HF_API_TOKEN", "fake-token") + labels = ["label1", "label2", "label3"] + evaluator = SASEvaluator(labels=labels) + + assert evaluator._labels == labels + assert evaluator._regexes_to_ignore is None + assert evaluator._ignore_case is False + assert evaluator._ignore_punctuation is False + assert evaluator._ignore_numbers is False + assert evaluator._model == "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" + assert evaluator._batch_size == 32 + assert evaluator._device is None + assert evaluator._token.resolve_value() == "fake-token" + + def test_to_dict(self, monkeypatch): + monkeypatch.setenv("HF_API_TOKEN", "fake-token") + + labels = ["label1", "label2", "label3"] + evaluator = SASEvaluator(labels=labels, device=ComponentDevice.from_str("cuda:0")) + + expected_dict = { + "type": "haystack.components.eval.sas_evaluator.SASEvaluator", + "init_parameters": { + "labels": labels, + "regexes_to_ignore": None, + "ignore_case": False, + "ignore_punctuation": False, + "ignore_numbers": False, + "model": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2", + "batch_size": 32, + "device": {"type": "single", "device": "cuda:0"}, + "token": {"type": "env_var", "env_vars": ["HF_API_TOKEN"], "strict": False}, + }, + } + assert evaluator.to_dict() == expected_dict + + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("HF_API_TOKEN", "fake-token") + evaluator = SASEvaluator.from_dict( + { + "type": "haystack.components.eval.sas_evaluator.SASEvaluator", + "init_parameters": { + "labels": ["label1", "label2", "label3"], + "regexes_to_ignore": None, + "ignore_case": False, + "ignore_punctuation": False, + "ignore_numbers": False, + "model": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2", + "batch_size": 32, + "device": {"type": "single", "device": "cuda:0"}, + "token": {"type": "env_var", "env_vars": ["HF_API_TOKEN"], "strict": False}, + }, + } + ) + + assert evaluator._labels == ["label1", "label2", "label3"] + assert evaluator._regexes_to_ignore is None + assert evaluator._ignore_case is False + assert evaluator._ignore_punctuation is False + assert evaluator._ignore_numbers is False + assert evaluator._model == "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" + assert evaluator._batch_size == 32 + assert evaluator._device.to_torch_str() == "cuda:0" + assert evaluator._token.resolve_value() == "fake-token" + + @pytest.mark.integration + def test_run_with_empty_inputs(self): + evaluator = SASEvaluator(labels=[]) + result = evaluator.run(predictions=[]) + assert len(result) == 2 + assert result["sas"] == 0.0 + assert result["scores"] == [0.0] + + @pytest.mark.integration + def test_run_with_different_lengths(self): + labels = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + ] + evaluator = SASEvaluator(labels=labels) + + predictions 
= [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", + ] + with pytest.raises(ValueError): + evaluator.run(predictions) + + @pytest.mark.integration + def test_run_with_matching_predictions(self): + labels = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", + ] + evaluator = SASEvaluator(labels=labels) + predictions = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", + ] + result = evaluator.run(predictions=predictions) + + assert len(result) == 2 + assert result["sas"] == pytest.approx(1.0) + assert result["scores"] == pytest.approx([1.0, 1.0, 1.0]) + + @pytest.mark.integration + def test_run_with_single_prediction(self): + labels = ["US $2.3 billion"] + evaluator = SASEvaluator(labels=labels) + + result = evaluator.run(predictions=["A construction budget of US $2.3 billion"]) + assert len(result) == 2 + assert result["sas"] == pytest.approx(0.689089, abs=1e-5) + assert result["scores"] == pytest.approx([0.689089], abs=1e-5) + + @pytest.mark.integration + def test_run_with_mismatched_predictions(self): + labels = [ + "US $2.3 billion", + "Paris's cultural magnificence is symbolized by the Eiffel Tower", + "Japan was transformed into a modernized world power after the Meiji Restoration.", + ] + evaluator = SASEvaluator(labels=labels) + predictions = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", + ] + result = evaluator.run(predictions=predictions) + assert len(result) == 2 + assert result["sas"] == pytest.approx(0.8227189) + assert result["scores"] == pytest.approx([0.689089, 0.870389, 0.908679], abs=1e-5) + + @pytest.mark.integration + def test_run_with_ignore_case(self): + labels = [ + "A construction budget of US $2.3 BILLION", + "The EIFFEL TOWER, completed in 1889, symbolizes Paris's cultural magnificence.", + "The MEIJI RESTORATION in 1868 transformed Japan into a modernized world power.", + ] + evaluator = SASEvaluator(labels=labels, ignore_case=True) + predictions = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", + ] + result = evaluator.run(predictions=predictions) + assert len(result) == 2 + assert result["sas"] == pytest.approx(1.0) + assert result["scores"] == pytest.approx([1.0, 1.0, 1.0]) + + @pytest.mark.integration + def test_run_with_ignore_punctuation(self): + labels = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower completed in 1889 symbolizes Paris's cultural magnificence", + "The Meiji Restoration in 1868 transformed Japan into a modernized world power", + ] + evaluator = SASEvaluator(labels=labels, ignore_punctuation=True) + predictions = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration, in 1868, transformed Japan into a 
modernized world power.", + ] + result = evaluator.run(predictions=predictions) + assert len(result) == 2 + assert result["sas"] == pytest.approx(1.0) + assert result["scores"] == pytest.approx([1.0, 1.0, 1.0]) + + @pytest.mark.integration + def test_run_with_ignore_numbers(self): + labels = [ + "A construction budget of US $10.3 billion", + "The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration, in 1989, transformed Japan into a modernized world power.", + ] + evaluator = SASEvaluator(labels=labels, ignore_numbers=True) + predictions = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.", + ] + result = evaluator.run(predictions=predictions) + assert result["sas"] == pytest.approx(1.0) + assert result["scores"] == pytest.approx([1.0, 1.0, 1.0]) + + @pytest.mark.integration + def test_run_with_regex_to_ignore(self): + labels = [ + "A construction budget of US $10.3 billion", + "The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration, in 1989, transformed Japan into a modernized world power.", + ] + evaluator = SASEvaluator(labels=labels, regexes_to_ignore=[r"\d+"]) + predictions = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.", + ] + result = evaluator.run(predictions=predictions) + assert len(result) == 2 + assert result["sas"] == pytest.approx(1.0) + assert result["scores"] == pytest.approx([1.0, 1.0, 1.0]) + + @pytest.mark.integration + def test_run_with_multiple_regex_to_ignore(self): + labels = [ + "A construction budget of US $10.3 billion", + "The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration, in 1989, transformed Japan into a modernized world power.", + ] + evaluator = SASEvaluator(labels=labels, regexes_to_ignore=[r"\d+", r"[^\w\s]"]) + predictions = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.", + ] + result = evaluator.run(predictions=predictions) + assert len(result) == 2 + assert result["sas"] == pytest.approx(1.0) + assert result["scores"] == pytest.approx([1.0, 1.0, 1.0]) + + @pytest.mark.integration + def test_run_with_multiple_ignore_parameters(self): + labels = [ + "A construction budget of US $10.3 billion", + "The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration, in 1989, transformed Japan into a modernized world power.", + ] + evaluator = SASEvaluator( + labels=labels, + ignore_numbers=True, + ignore_punctuation=True, + ignore_case=True, + regexes_to_ignore=[r"[^\w\s\d]+"], + ) + predictions = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.", + ] + result = evaluator.run(predictions=predictions) + assert len(result) == 2 + assert result["sas"] == pytest.approx(1.0) + assert result["scores"] == pytest.approx([1.0, 1.0, 1.0]) + + @pytest.mark.integration + def test_run_with_bi_encoder_model(self): + 
labels = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", + ] + evaluator = SASEvaluator(labels=labels, model="sentence-transformers/all-mpnet-base-v2") + predictions = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", + ] + result = evaluator.run(predictions=predictions) + assert len(result) == 2 + assert result["sas"] == pytest.approx(1.0) + assert result["scores"] == pytest.approx([1.0, 1.0, 1.0]) + + @pytest.mark.integration + def test_run_with_cross_encoder_model(self): + labels = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", + ] + evaluator = SASEvaluator(labels=labels, model="cross-encoder/ms-marco-MiniLM-L-6-v2") + predictions = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", + ] + result = evaluator.run(predictions=predictions) + assert len(result) == 2 + assert result["sas"] == pytest.approx(0.999967, abs=1e-5) + assert result["scores"] == pytest.approx([0.9999765157699585, 0.999968409538269, 0.9999572038650513], abs=1e-5) diff --git a/test/evaluation/test_eval_sas.py b/test/evaluation/test_eval_sas.py deleted file mode 100644 index 5b1fb70a8a..0000000000 --- a/test/evaluation/test_eval_sas.py +++ /dev/null @@ -1,347 +0,0 @@ -import pytest - -from haystack import Pipeline -from haystack.dataclasses import GeneratedAnswer -from haystack.evaluation.eval import EvaluationResult - - -class TestSAS: - def create_evaluation_result(self, predictions, labels): - """ - Creates an evaluation result of a RAG pipeline using the list of predictions and labels for testing the - Semantic Answer Similarity (SAS) Metric. - """ - runnable = Pipeline() - inputs = [] - outputs = [ - {"answer_builder": {"answers": [GeneratedAnswer(data=pred, query="", documents=[], meta={})]}} - for pred in predictions - ] - expected_outputs = [ - {"answer_builder": {"answers": [GeneratedAnswer(data=label, query="", documents=[], meta={})]}} - for label in labels - ] - evaluation_result = EvaluationResult(runnable, inputs, outputs, expected_outputs) - return evaluation_result - - def test_sas_empty_inputs(self): - """ - Test calculation of Semantic Answer Similarity (SAS) Score with empty inputs. 
- """ - runnable = Pipeline() - inputs = [] - outputs = [ - {"answer_builder": {"answers": []}}, - {"answer_builder": {"answers": []}}, - {"answer_builder": {"answers": []}}, - ] - expected_outputs = [ - {"answer_builder": {"answers": []}}, - {"answer_builder": {"answers": []}}, - {"answer_builder": {"answers": []}}, - ] - evaluation_result = EvaluationResult(runnable, inputs, outputs, expected_outputs) - # Expecting 0% SAS for empty inputs - sas_result = evaluation_result._calculate_sas( - output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2" - ) - - assert sas_result["sas"] == 0.0 - assert sas_result["scores"] == [0.0] - - def test_calculate_sas_with_different_lengths(self): - """ - Test calculation of Semantic Answer Similarity (SAS) Score with default parameters. - """ - predictions = [ - "A construction budget of US $2.3 billion", - "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", - "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", - ] - labels = [ - "A construction budget of US $2.3 billion", - "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", - ] - evaluation_result = self.create_evaluation_result(predictions, labels) - - with pytest.raises(ValueError, match="The number of predictions and labels must be the same."): - evaluation_result._calculate_sas( - output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2" - ) - - @pytest.mark.integration - def test_sas_same_inputs(self): - """ - Test calculation of Semantic Answer Similarity (SAS) Score with default parameters. - """ - predictions = [ - "A construction budget of US $2.3 billion", - "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", - "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", - ] - labels = [ - "A construction budget of US $2.3 billion", - "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", - "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", - ] - evaluation_result = self.create_evaluation_result(predictions, labels) - sas_result = evaluation_result._calculate_sas( - output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2" - ) - - assert sas_result["sas"] == pytest.approx(1.0) - assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0]) - - @pytest.mark.integration - def test_sas_single_word(self): - """ - Test calculation of Semantic Answer Similarity (SAS) Score with single-word inputs. - """ - predictions = ["A construction budget of US $2.3 billion"] - labels = ["US $2.3 billion"] - - evaluation_result = self.create_evaluation_result(predictions, labels) - sas_result = evaluation_result._calculate_sas( - output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2" - ) - - assert sas_result["sas"] == pytest.approx(0.689089, abs=1e-5) - assert sas_result["scores"] == pytest.approx([0.689089], abs=1e-5) - - @pytest.mark.integration - def test_sas_negative_case(self): - """ - Test calculation of Semantic Answer Similarity (SAS) Score with deliberately mismatched predictions and labels. 
- """ - predictions = [ - "A construction budget of US $2.3 billion", - "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", - "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", - ] - labels = [ - "US $2.3 billion", - "Paris's cultural magnificence is symbolized by the Eiffel Tower", - "Japan was transformed into a modernized world power after the Meiji Restoration.", - ] - - evaluation_result = self.create_evaluation_result(predictions, labels) - sas_result = evaluation_result._calculate_sas( - output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2" - ) - - assert sas_result["sas"] == pytest.approx(0.8227189) - assert sas_result["scores"] == pytest.approx([0.689089, 0.870389, 0.908679], abs=1e-5) - - @pytest.mark.integration - def test_sas_ignore_case(self): - """ - Test calculation of Semantic Answer Similarity (SAS) Score with ignoring case sensitivity. - """ - predictions = [ - "A construction budget of US $2.3 billion", - "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", - "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", - ] - labels = [ - "A construction budget of US $2.3 BILLION", - "The EIFFEL TOWER, completed in 1889, symbolizes Paris's cultural magnificence.", - "The MEIJI RESTORATION in 1868 transformed Japan into a modernized world power.", - ] - - evaluation_result = self.create_evaluation_result(predictions, labels) - # SAS after case ignoring - sas_result = evaluation_result._calculate_sas( - output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2", ignore_case=True - ) - - assert sas_result["sas"] == pytest.approx(1.0) - assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0]) - - @pytest.mark.integration - def test_sas_ignore_punctuation(self): - """ - Test calculation of Semantic Answer Similarity (SAS) Score with ignoring punctuation. - """ - predictions = [ - "A construction budget of US $2.3 billion", - "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", - "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.", - ] - labels = [ - "A construction budget of US $2.3 billion", - "The Eiffel Tower completed in 1889 symbolizes Paris's cultural magnificence", - "The Meiji Restoration in 1868 transformed Japan into a modernized world power", - ] - - evaluation_result = self.create_evaluation_result(predictions, labels) - # SAS after ignoring punctuation - sas_result = evaluation_result._calculate_sas( - output_key="answers", - model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2", - ignore_punctuation=True, - ) - - assert sas_result["sas"] == pytest.approx(1.0) - assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0]) - - @pytest.mark.integration - def test_sas_ignore_numbers(self): - """ - Test calculation of Semantic Answer Similarity (SAS) Score with ignoring numbers. 
- """ - predictions = [ - "A construction budget of US $2.3 billion", - "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", - "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.", - ] - labels = [ - "A construction budget of US $10.3 billion", - "The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.", - "The Meiji Restoration, in 1989, transformed Japan into a modernized world power.", - ] - - evaluation_result = self.create_evaluation_result(predictions, labels) - # SAS after ignoring numbers - sas_result = evaluation_result._calculate_sas( - output_key="answers", - model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2", - ignore_numbers=True, - ) - - assert sas_result["sas"] == pytest.approx(1.0) - assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0]) - - @pytest.mark.integration - def test_sas_regex_ignore(self): - """ - Test calculation of Semantic Answer Similarity (SAS) Score with ignoring specific regex patterns. - """ - predictions = [ - "A construction budget of US $2.3 billion", - "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", - "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.", - ] - labels = [ - "A construction budget of US $10.3 billion", - "The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.", - "The Meiji Restoration, in 1989, transformed Japan into a modernized world power.", - ] - - evaluation_result = self.create_evaluation_result(predictions, labels) - # Ignore numeric patterns - regex_to_ignore = [r"\d+"] - sas_result = evaluation_result._calculate_sas( - output_key="answers", - model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2", - regexes_to_ignore=regex_to_ignore, - ) - - assert sas_result["sas"] == pytest.approx(1.0) - assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0]) - - @pytest.mark.integration - def test_sas_multiple_ignore_regex(self): - """ - Test calculation of Semantic Answer Similarity (SAS) Score with multiple ignoring parameters. - """ - predictions = [ - "A construction budget of US $2.3 billion", - "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", - "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.", - ] - labels = [ - "A construction budget of US #10.3 billion", - "The Eiffel Tower!!, completed in 2005, symbolizes Paris's cultural magnificence.", - "The **Meiji Restoration**, in 1989, transformed Japan into a modernized world power.", - ] - - evaluation_result = self.create_evaluation_result(predictions, labels) - # Ignore numeric patterns and punctuation excluding whitespaces - regex_to_ignore = [r"\d+", r"[^\w\s]"] - sas_result = evaluation_result._calculate_sas( - output_key="answers", - model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2", - regexes_to_ignore=regex_to_ignore, - ) - - assert sas_result["sas"] == pytest.approx(1.0) - assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0]) - - @pytest.mark.integration - def test_sas_multiple_ignore_combination(self): - """ - Test calculation of Semantic Answer Similarity (SAS) Score with multiple ignoring parameters combined. 
- """ - predictions = [ - "A construction budget of US $2.3 billion", - "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", - "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.", - ] - labels = [ - "A construction budget of US #10.3 BILLION", - "The EIFFEL TOWER!!, completed in 2005, symbolizes Paris's cultural magnificence.", - "The **MEIJI RESTORATION**, in 1989, transformed Japan into a modernized world power.", - ] - - evaluation_result = self.create_evaluation_result(predictions, labels) - # Ignore only special characters using regex - regex_to_ignore = [r"[^\w\s\d]+"] - sas_result = evaluation_result._calculate_sas( - output_key="answers", - model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2", - ignore_numbers=True, - ignore_punctuation=True, - ignore_case=True, - regexes_to_ignore=regex_to_ignore, - ) - - assert sas_result["sas"] == pytest.approx(1.0) - assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0]) - - @pytest.mark.integration - def test_sas_bi_encoder(self): - """ - Test calculation of Semantic Answer Similarity (SAS) Score using a Bi-Encoder model. - """ - predictions = [ - "A construction budget of US $2.3 billion", - "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", - "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", - ] - labels = [ - "A construction budget of US $2.3 billion", - "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", - "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", - ] - evaluation_result = self.create_evaluation_result(predictions, labels) - sas_result = evaluation_result._calculate_sas( - output_key="answers", model="sentence-transformers/all-mpnet-base-v2" - ) - - assert sas_result["sas"] == pytest.approx(1.0) - assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0]) - - @pytest.mark.integration - def test_sas_cross_encoder(self): - """ - Test calculation of Semantic Answer Similarity (SAS) Score using a Cross Encoder model. - """ - predictions = [ - "A construction budget of US $2.3 billion", - "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", - "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", - ] - labels = [ - "A construction budget of US $2.3 billion", - "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", - "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", - ] - evaluation_result = self.create_evaluation_result(predictions, labels) - sas_result = evaluation_result._calculate_sas( - output_key="answers", model="cross-encoder/ms-marco-MiniLM-L-6-v2" - ) - - assert sas_result["sas"] == pytest.approx(0.999967, abs=1e-5) - assert sas_result["scores"] == pytest.approx( - [0.9999765157699585, 0.999968409538269, 0.9999572038650513], abs=1e-5 - )