From f4fa4f1e827bafedb471cc01967c631f61eded15 Mon Sep 17 00:00:00 2001
From: jjmachan
Date: Thu, 1 Feb 2024 22:44:32 -0800
Subject: [PATCH] only output warning message if ground_truths is found

---
 src/ragas/evaluation.py |  2 ++
 src/ragas/validation.py | 25 ++++++++++++++++++++++++-
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/src/ragas/evaluation.py b/src/ragas/evaluation.py
index 858d0ab46..a188f1499 100644
--- a/src/ragas/evaluation.py
+++ b/src/ragas/evaluation.py
@@ -21,6 +21,7 @@
     remap_column_names,
     validate_column_dtypes,
     validate_evaluation_modes,
+    handle_deprecated_ground_truths,
 )
 
 if t.TYPE_CHECKING:
@@ -134,6 +135,7 @@ def evaluate(
     # remap column names from the dataset
     dataset = remap_column_names(dataset, column_map)
     # validation
+    dataset = handle_deprecated_ground_truths(dataset)
     validate_evaluation_modes(dataset, metrics)
     validate_column_dtypes(dataset)
 
diff --git a/src/ragas/validation.py b/src/ragas/validation.py
index 17a6e3f4f..d5f2d2e96 100644
--- a/src/ragas/validation.py
+++ b/src/ragas/validation.py
@@ -1,10 +1,14 @@
 from __future__ import annotations
 
+import logging
+
 from datasets import Dataset, Sequence
 
 from ragas.metrics._context_precision import ContextPrecision
 from ragas.metrics.base import EvaluationMode, Metric
 
+logger = logging.getLogger(__name__)
+
 
 def remap_column_names(dataset: Dataset, column_map: dict[str, str]) -> Dataset:
     """
@@ -15,6 +19,26 @@ def remap_column_names(dataset: Dataset, column_map: dict[str, str]) -> Dataset:
     return dataset.rename_columns(inverse_column_map)
 
 
+def handle_deprecated_ground_truths(ds: Dataset) -> Dataset:
+    if "ground_truths" in ds.features:
+        column_names = "ground_truths"
+        if (
+            isinstance(ds.features[column_names], Sequence)
+            and ds.features[column_names].feature.dtype == "string"
+        ):
+            logger.warning(
+                "passing column names as 'ground_truths' is deprecated and will be removed in the next version, please use 'ground_truth' instead. Note that `ground_truth` should be of type string and not Sequence[string] like `ground_truths`"
+            )
+            gt = [gt[0] for gt in ds["ground_truths"]]
+            ds = ds.add_column(
+                "ground_truth",
+                gt,
+                new_fingerprint=ds._fingerprint
+                + "a",  # adding random to fingerprint to avoid caching
+            )
+    return ds
+
+
 def validate_column_dtypes(ds: Dataset):
     for column_names in ["question", "answer", "ground_truth"]:
         if column_names in ds.features:
@@ -56,7 +80,6 @@ def validate_evaluation_modes(ds: Dataset, metrics: list[Metric]):
     3. (q,c)
     4. (g,a)
     """
-
     for m in metrics:
         required_columns = set(EVALMODE_TO_COLUMNS[m.evaluation_mode])
         available_columns = set(ds.features.keys())
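
Note (not part of the patch): below is a minimal usage sketch of the deprecation path this patch adds. It assumes the patched `ragas.validation` module is importable, that the `datasets` library is installed and infers list-of-string columns as `Sequence` features, and the sample rows are purely illustrative.

    from datasets import Dataset

    from ragas.validation import handle_deprecated_ground_truths

    # Old-style dataset: `ground_truths` is a Sequence[string] column.
    ds = Dataset.from_dict(
        {
            "question": ["What is the capital of France?"],
            "answer": ["Paris"],
            "contexts": [["Paris is the capital of France."]],
            "ground_truths": [["Paris"]],
        }
    )

    # Logs the deprecation warning and adds a `ground_truth` string column
    # built from the first element of each `ground_truths` list.
    ds = handle_deprecated_ground_truths(ds)
    print(ds["ground_truth"])  # ['Paris']

The original `ground_truths` column is left in place; only the new `ground_truth` column is added, so code that still reads `ground_truths` is unaffected during the deprecation window.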