Skip to content

Commit

Permalink
fix: v1 to v2 dataset (#1275)
Browse files Browse the repository at this point in the history
fixes: #1271
  • Loading branch information
shahules786 authored Sep 11, 2024
1 parent 3076f50 commit a4a3e56
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/ragas/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@
)
from ragas.metrics.critique import AspectCritique
from ragas.run_config import RunConfig
from ragas.utils import REQUIRED_COLS_v1, get_feature_language, safe_nanmean
from ragas.utils import (
convert_v1_to_v2_dataset,
convert_v2_to_v1_dataset,
get_feature_language,
safe_nanmean,
)
from ragas.validation import (
remap_column_names,
validate_required_columns,
Expand Down Expand Up @@ -164,7 +169,7 @@ def evaluate(
# remap column names from the dataset
v1_input = True
dataset = remap_column_names(dataset, column_map)
dataset = remap_column_names(dataset, REQUIRED_COLS_v1)
dataset = convert_v1_to_v2_dataset(dataset)
# validation
dataset = EvaluationDataset.from_list(dataset.to_list())

Expand Down Expand Up @@ -310,8 +315,7 @@ def evaluate(
# convert to v.1 dataset
dataset = dataset.to_hf_dataset()
if v1_input:
cols = {k: v for v, k in REQUIRED_COLS_v1.items()}
dataset = remap_column_names(dataset, cols)
dataset = convert_v2_to_v1_dataset(dataset)

cost_cb = ragas_callbacks["cost_cb"] if "cost_cb" in ragas_callbacks else None
result = Result(
Expand Down
11 changes: 11 additions & 0 deletions src/ragas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from functools import lru_cache

import numpy as np
from datasets import Dataset

if t.TYPE_CHECKING:
from ragas.metrics.base import Metric
Expand Down Expand Up @@ -197,3 +198,13 @@ def get_required_columns_v1(metric: Metric):
def convert_row_v1_to_v2(row: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
    """Return a copy of *row* with its keys translated via ``REQUIRED_COLS_v1``.

    ``REQUIRED_COLS_v1`` is inverted (value -> key) to build the lookup, so
    each key of ``row`` that matches a *value* of ``REQUIRED_COLS_v1`` is
    replaced by the corresponding *key*. Entries of ``row`` whose key is not
    in the lookup are dropped from the result.

    NOTE(review): presumably the values of ``REQUIRED_COLS_v1`` are the v1
    column names and its keys the v2 names -- confirm against its definition.
    """
    # Inverted lookup: REQUIRED_COLS_v1 value -> REQUIRED_COLS_v1 key.
    name_lookup = {}
    for new_name, old_name in REQUIRED_COLS_v1.items():
        name_lookup[old_name] = new_name

    converted = {}
    for column, value in row.items():
        if column in name_lookup:
            converted[name_lookup[column]] = value
    return converted


def convert_v1_to_v2_dataset(dataset: Dataset) -> Dataset:
    """Rename the columns of *dataset* from ``REQUIRED_COLS_v1`` values to keys.

    For every ``key: value`` pair in ``REQUIRED_COLS_v1`` whose *value* is a
    member of ``dataset.features``, that column is renamed to *key*. Pairs
    whose value is not in ``dataset.features`` are left out of the rename map.

    Returns whatever ``dataset.rename_columns`` returns for that map.
    """
    renames = {}
    for new_name, old_name in REQUIRED_COLS_v1.items():
        # Only rename columns that actually exist on this dataset.
        if old_name in dataset.features:
            renames[old_name] = new_name
    return dataset.rename_columns(renames)


def convert_v2_to_v1_dataset(dataset: Dataset) -> Dataset:
    """Rename the columns of *dataset* from ``REQUIRED_COLS_v1`` keys to values.

    For every ``key: value`` pair in ``REQUIRED_COLS_v1`` whose *key* is a
    member of ``dataset.features``, that column is renamed to *value*. Pairs
    whose key is not in ``dataset.features`` are left out of the rename map.

    Returns whatever ``dataset.rename_columns`` returns for that map.
    """
    renames = {}
    for current_name, target_name in REQUIRED_COLS_v1.items():
        # Only rename columns that actually exist on this dataset.
        if current_name in dataset.features:
            renames[current_name] = target_name
    return dataset.rename_columns(renames)

0 comments on commit a4a3e56

Please sign in to comment.