From a4a3e5655665fc716b286c6590e1718a22edfd26 Mon Sep 17 00:00:00 2001 From: ikka Date: Wed, 11 Sep 2024 16:15:06 +0530 Subject: [PATCH] fix: v1 to v2 dataset (#1275) fixes: #1271 --- src/ragas/evaluation.py | 12 ++++++++---- src/ragas/utils.py | 11 +++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/ragas/evaluation.py b/src/ragas/evaluation.py index b9b3d09593..cd4fabf975 100644 --- a/src/ragas/evaluation.py +++ b/src/ragas/evaluation.py @@ -34,7 +34,12 @@ ) from ragas.metrics.critique import AspectCritique from ragas.run_config import RunConfig -from ragas.utils import REQUIRED_COLS_v1, get_feature_language, safe_nanmean +from ragas.utils import ( + convert_v1_to_v2_dataset, + convert_v2_to_v1_dataset, + get_feature_language, + safe_nanmean, +) from ragas.validation import ( remap_column_names, validate_required_columns, @@ -164,7 +169,7 @@ def evaluate( # remap column names from the dataset v1_input = True dataset = remap_column_names(dataset, column_map) - dataset = remap_column_names(dataset, REQUIRED_COLS_v1) + dataset = convert_v1_to_v2_dataset(dataset) # validation dataset = EvaluationDataset.from_list(dataset.to_list()) @@ -310,8 +315,7 @@ def evaluate( # convert to v.1 dataset dataset = dataset.to_hf_dataset() if v1_input: - cols = {k: v for v, k in REQUIRED_COLS_v1.items()} - dataset = remap_column_names(dataset, cols) + dataset = convert_v2_to_v1_dataset(dataset) cost_cb = ragas_callbacks["cost_cb"] if "cost_cb" in ragas_callbacks else None result = Result( diff --git a/src/ragas/utils.py b/src/ragas/utils.py index 258dbad6ad..b9dcac8053 100644 --- a/src/ragas/utils.py +++ b/src/ragas/utils.py @@ -7,6 +7,7 @@ from functools import lru_cache import numpy as np +from datasets import Dataset if t.TYPE_CHECKING: from ragas.metrics.base import Metric @@ -197,3 +198,13 @@ def get_required_columns_v1(metric: Metric): def convert_row_v1_to_v2(row: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: required_cols_v2 = {k: v for v, k in REQUIRED_COLS_v1.items()} return {required_cols_v2[k]: v for k, v in row.items() if k in required_cols_v2} + + +def convert_v1_to_v2_dataset(dataset: Dataset) -> Dataset: + columns_map = {v: k for k, v in REQUIRED_COLS_v1.items() if v in dataset.features} + return dataset.rename_columns(columns_map) + + +def convert_v2_to_v1_dataset(dataset: Dataset) -> Dataset: + columns_map = {k: v for k, v in REQUIRED_COLS_v1.items() if k in dataset.features} + return dataset.rename_columns(columns_map)