Skip to content

Commit

Permalink
only output warning message if ground_truths is found
Browse files Browse the repository at this point in the history
  • Loading branch information
jjmachan committed Feb 2, 2024
1 parent 15d8272 commit f4fa4f1
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/ragas/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
remap_column_names,
validate_column_dtypes,
validate_evaluation_modes,
handle_deprecated_ground_truths,
)

if t.TYPE_CHECKING:
Expand Down Expand Up @@ -134,6 +135,7 @@ def evaluate(
# remap column names from the dataset
dataset = remap_column_names(dataset, column_map)
# validation
dataset = handle_deprecated_ground_truths(dataset)
validate_evaluation_modes(dataset, metrics)
validate_column_dtypes(dataset)

Expand Down
25 changes: 24 additions & 1 deletion src/ragas/validation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations

import logging

from datasets import Dataset, Sequence

from ragas.metrics._context_precision import ContextPrecision
from ragas.metrics.base import EvaluationMode, Metric

logger = logging.getLogger(__name__)


def remap_column_names(dataset: Dataset, column_map: dict[str, str]) -> Dataset:
"""
Expand All @@ -15,6 +19,26 @@ def remap_column_names(dataset: Dataset, column_map: dict[str, str]) -> Dataset:
return dataset.rename_columns(inverse_column_map)


def handle_deprecated_ground_truths(ds: Dataset) -> Dataset:
if "ground_truths" in ds.features:
column_names = "ground_truths"
if (
isinstance(ds.features[column_names], Sequence)
and ds.features[column_names].feature.dtype == "string"
):
logger.warning(
"passing column names as 'ground_truths' is deprecated and will be removed in the next version, please use 'ground_truth' instead. Note that `ground_truth` should be of type string and not Sequence[string] like `ground_truths`"
)
gt = [gt[0] for gt in ds["ground_truths"]]
ds = ds.add_column(
"ground_truth",
gt,
new_fingerprint=ds._fingerprint
+ "a", # adding random to fingerprint to avoid caching
)
return ds


def validate_column_dtypes(ds: Dataset):
for column_names in ["question", "answer", "ground_truth"]:
if column_names in ds.features:
Expand Down Expand Up @@ -56,7 +80,6 @@ def validate_evaluation_modes(ds: Dataset, metrics: list[Metric]):
3. (q,c)
4. (g,a)
"""

for m in metrics:
required_columns = set(EVALMODE_TO_COLUMNS[m.evaluation_mode])
available_columns = set(ds.features.keys())
Expand Down

0 comments on commit f4fa4f1

Please sign in to comment.