fix tests
julian-risch committed Feb 14, 2024
1 parent d857880 commit ffc2c66
Showing 5 changed files with 105 additions and 48 deletions.
8 changes: 6 additions & 2 deletions integrations/ragas/example/example.py
@@ -39,8 +39,12 @@

# Each metric expects a specific set of parameters as input. Refer to the
# Ragas class' documentation for more details.
results = pipeline.run({"evaluator_context": {"questions": QUESTIONS, "contexts": CONTEXTS, "ground_truths": GROUND_TRUTHS},
"evaluator_aspect": {"questions": QUESTIONS, "contexts": CONTEXTS, "responses": RESPONSES}})
results = pipeline.run(
{
"evaluator_context": {"questions": QUESTIONS, "contexts": CONTEXTS, "ground_truths": GROUND_TRUTHS},
"evaluator_aspect": {"questions": QUESTIONS, "contexts": CONTEXTS, "responses": RESPONSES},
}
)


for component in ["evaluator_context", "evaluator_aspect"]:
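The loop body above is collapsed in this view. A minimal sketch (not part of this commit) of what iterating over the pipeline output could look like, assuming each evaluator component exposes its converted scores under a "results" key as lists of {"name": ..., "score": ...} dicts, matching the run() change further down in this diff:

# Illustrative only -- not part of this commit; assumes MetricResult.to_dict()
# yields "name" and "score" keys, as the evaluator code below suggests.
for component in ["evaluator_context", "evaluator_aspect"]:
    for query_results in results[component]["results"]:
        for result in query_results:
            print(f"{component}: {result['name']} = {result['score']}")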
@@ -6,7 +6,6 @@

from ragas import evaluate # type: ignore
from ragas.evaluation import Result # type: ignore
from ragas.metrics import AspectCritique # type: ignore
from ragas.metrics.base import Metric # type: ignore

from .metrics import (
@@ -104,7 +103,7 @@ def run(self, **inputs) -> Dict[str, Any]:

OutputConverters.validate_outputs(results)
converted_results = [
[result.to_dict()] for result in OutputConverters.extract_results(results, self.metric, self.metric_params)
[result.to_dict()] for result in self.descriptor.output_converter(results, self.metric, self.metric_params)
]

return {"results": converted_results}
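For orientation (not part of the commit), the value returned by run() above would look roughly like this for a single Faithfulness question; the score is invented:

# Illustrative sketch of run()'s return shape; 0.87 is a made-up score.
converted_results = [
    [{"name": "faithfulness", "score": 0.87}],
]
output = {"results": converted_results}
assert isinstance(output["results"][0][0]["score"], float)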
@@ -6,10 +6,10 @@

from ragas.evaluation import Result
from ragas.metrics import ( # type: ignore
AspectCritique, # type: ignore
AnswerCorrectness, # type: ignore
AnswerRelevancy, # type: ignore
AnswerSimilarity, # type: ignore
AspectCritique, # type: ignore
ContextPrecision, # type: ignore
ContextRecall, # type: ignore
ContextRelevancy, # type: ignore
@@ -40,7 +40,7 @@ def from_str(cls, string: str) -> "RagasMetric":
enum_map = {e.value: e for e in RagasMetric}
metric = enum_map.get(string)
if metric is None:
msg = f"Unknown Ragas enum value '{string}'. Supported enums: {list(enum_map.keys())}"
msg = f"Unknown Ragas metric '{string}'. Supported metrics: {list(enum_map.keys())}"
raise ValueError(msg)
return metric

@@ -119,6 +119,9 @@ class MetricDescriptor:
to set the input types of the evaluator component.
:param input_converter:
Callable that converts input parameters to the Ragas input format.
:param output_converter:
Callable that converts the Ragas output format to our output format.
Accepts a single output parameter and returns a list of results derived from it.
:param init_parameters:
Additional parameters that need to be passed to the metric class during initialization.
"""
@@ -128,6 +131,7 @@ class MetricDescriptor:
input_validator: Callable
input_parameters: Dict[str, Type]
input_converter: Callable[[Any], Iterable[Dict[str, str]]]
output_converter: Callable[[Result], List[MetricResult]]
init_parameters: Optional[Dict[str, Type[Any]]] = None

@classmethod
@@ -137,6 +141,7 @@ def new(
backend: Type[Metric],
input_validator,
input_converter: Callable[[Any], Iterable[Dict[str, str]]],
output_converter: Optional[Callable[[Result], List[MetricResult]]] = None,
*,
init_parameters: Optional[Dict[str, Type]] = None,
) -> "MetricDescriptor":
@@ -155,6 +160,7 @@ def new(
input_validator=input_validator,
input_parameters=input_parameters,
input_converter=input_converter,
output_converter=output_converter if output_converter is not None else OutputConverters.default,
init_parameters=init_parameters,
)
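A short sketch (not part of the commit) of the new parameter's fallback behaviour: omitting output_converter is equivalent to passing OutputConverters.default explicitly. All names are those defined in this module; the type-ignore comments used in METRIC_DESCRIPTORS below are omitted here.

# Both descriptors convert outputs the same way.
explicit = MetricDescriptor.new(
    RagasMetric.FAITHFULNESS,
    Faithfulness,
    InputValidators.validate_empty_metric_parameters,
    InputConverters.question_context_response,
    OutputConverters.default,
)
implicit = MetricDescriptor.new(
    RagasMetric.FAITHFULNESS,
    Faithfulness,
    InputValidators.validate_empty_metric_parameters,
    InputConverters.question_context_response,
)
assert implicit.output_converter is OutputConverters.default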

@@ -171,7 +177,8 @@ class InputValidators:
@staticmethod
def validate_empty_metric_parameters(metric: RagasMetric, metric_params: Dict[str, Any]):
if metric_params:
raise ValueError(f"Unexpected init parameters '{metric_params}' for metric '{metric}'.")
msg = f"Unexpected init parameters '{metric_params}' for metric '{metric}'."
raise ValueError(msg)

@staticmethod
def validate_aspect_critique_parameters(metric: RagasMetric, metric_params: Dict[str, Any]):
@@ -294,53 +301,79 @@ def validate_outputs(outputs: Result):
raise ValueError(msg)

@staticmethod
def extract_results(
output: Result, metric: RagasMetric, metric_params: Optional[Dict[str, Any]]
) -> List[MetricResult]:
def _extract_default_results(output: Result, metric_name: str) -> List[MetricResult]:
try:
metric_name = ""
if metric == RagasMetric.ASPECT_CRITIQUE and metric_params:
if "name" in metric_params:
metric_name = metric_params["name"]
elif "aspect" in metric_params:
metric_name = metric_params["aspect"].value
else:
metric_name = metric.value
output_scores: List[Dict[str, float]] = output.scores.to_list()
return [MetricResult(name=metric_name, score=metric_dict[metric_name]) for metric_dict in output_scores]
except KeyError as e:
msg = f"Ragas evaluator did not return an expected output for metric '{e.args[0]}'"
raise ValueError(msg) from e

@staticmethod
def default(output: Result, metric: RagasMetric, _) -> List[MetricResult]:
metric_name = metric.value
return OutputConverters._extract_default_results(output, metric_name)

@staticmethod
def aspect_critique(output: Result, _, metric_params: Dict[str, Any]) -> List[MetricResult]:
metric_name = metric_params["name"]
return OutputConverters._extract_default_results(output, metric_name)
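A self-contained sketch (not part of the commit) of what the two converters above produce. The fake classes only mimic the scores.to_list() shape that _extract_default_results relies on; they are not the real ragas Result.

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class _FakeScores:
    rows: List[Dict[str, float]]

    def to_list(self) -> List[Dict[str, float]]:
        return self.rows

@dataclass
class _FakeResult:
    scores: _FakeScores

fake = _FakeResult(scores=_FakeScores(rows=[{"faithfulness": 0.9}, {"faithfulness": 0.7}]))
print(OutputConverters.default(fake, RagasMetric.FAITHFULNESS, None))
# roughly: [MetricResult(name='faithfulness', score=0.9), MetricResult(name='faithfulness', score=0.7)]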


METRIC_DESCRIPTORS = {
RagasMetric.ANSWER_CORRECTNESS: MetricDescriptor.new(
RagasMetric.ANSWER_CORRECTNESS, AnswerCorrectness, InputValidators.validate_empty_metric_parameters, InputConverters.question_response_ground_truth # type: ignore
RagasMetric.ANSWER_CORRECTNESS,
AnswerCorrectness,
InputValidators.validate_empty_metric_parameters,
InputConverters.question_response_ground_truth, # type: ignore
),
RagasMetric.FAITHFULNESS: MetricDescriptor.new(
RagasMetric.FAITHFULNESS, Faithfulness, InputValidators.validate_empty_metric_parameters, InputConverters.question_context_response # type: ignore
RagasMetric.FAITHFULNESS,
Faithfulness,
InputValidators.validate_empty_metric_parameters,
InputConverters.question_context_response, # type: ignore
),
RagasMetric.ANSWER_SIMILARITY: MetricDescriptor.new(
RagasMetric.ANSWER_SIMILARITY, AnswerSimilarity, InputValidators.validate_empty_metric_parameters, InputConverters.response_ground_truth # type: ignore
RagasMetric.ANSWER_SIMILARITY,
AnswerSimilarity,
InputValidators.validate_empty_metric_parameters,
InputConverters.response_ground_truth, # type: ignore
),
RagasMetric.CONTEXT_PRECISION: MetricDescriptor.new(
RagasMetric.CONTEXT_PRECISION, ContextPrecision, InputValidators.validate_empty_metric_parameters, InputConverters.question_context_ground_truth # type: ignore
RagasMetric.CONTEXT_PRECISION,
ContextPrecision,
InputValidators.validate_empty_metric_parameters,
InputConverters.question_context_ground_truth, # type: ignore
),
RagasMetric.CONTEXT_UTILIZATION: MetricDescriptor.new(
RagasMetric.CONTEXT_UTILIZATION,
ContextUtilization, InputValidators.validate_empty_metric_parameters,
ContextUtilization,
InputValidators.validate_empty_metric_parameters,
InputConverters.question_context_response, # type: ignore
),
RagasMetric.CONTEXT_RECALL: MetricDescriptor.new(
RagasMetric.CONTEXT_RECALL, ContextRecall, InputValidators.validate_empty_metric_parameters, InputConverters.question_context_ground_truth # type: ignore
RagasMetric.CONTEXT_RECALL,
ContextRecall,
InputValidators.validate_empty_metric_parameters,
InputConverters.question_context_ground_truth, # type: ignore
),
RagasMetric.ASPECT_CRITIQUE: MetricDescriptor.new(
RagasMetric.ASPECT_CRITIQUE, AspectCritique, InputValidators.validate_aspect_critique_parameters, InputConverters.question_context_response # type: ignore
RagasMetric.ASPECT_CRITIQUE,
AspectCritique,
InputValidators.validate_aspect_critique_parameters,
InputConverters.question_context_response,
OutputConverters.aspect_critique, # type: ignore
),
RagasMetric.CONTEXT_RELEVANCY: MetricDescriptor.new(
RagasMetric.CONTEXT_RELEVANCY, ContextRelevancy, InputValidators.validate_empty_metric_parameters, InputConverters.question_context # type: ignore
RagasMetric.CONTEXT_RELEVANCY,
ContextRelevancy,
InputValidators.validate_empty_metric_parameters,
InputConverters.question_context, # type: ignore
),
RagasMetric.ANSWER_RELEVANCY: MetricDescriptor.new(
RagasMetric.ANSWER_RELEVANCY, AnswerRelevancy, InputValidators.validate_empty_metric_parameters, InputConverters.question_context_response # type: ignore
RagasMetric.ANSWER_RELEVANCY,
AnswerRelevancy,
InputValidators.validate_empty_metric_parameters,
InputConverters.question_context_response, # type: ignore
),
}
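To illustrate the dispatch through this table (again, not part of the commit), a tiny stand-in for the ragas Result suffices, since the converters only call .scores.to_list(); the params dict mirrors the AspectCritique tests below, with a shortened definition.

from types import SimpleNamespace

params = {"name": "harmfulness", "definition": "Does the submission cause harm?"}
fake_output = SimpleNamespace(scores=SimpleNamespace(to_list=lambda: [{"harmfulness": 1.0}]))

descriptor = METRIC_DESCRIPTORS[RagasMetric.ASPECT_CRITIQUE]
descriptor.input_validator(RagasMetric.ASPECT_CRITIQUE, params)  # raises ValueError for bad params
print(descriptor.output_converter(fake_output, RagasMetric.ASPECT_CRITIQUE, params))
# roughly: [MetricResult(name='harmfulness', score=1.0)]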
61 changes: 41 additions & 20 deletions integrations/ragas/tests/test_evaluator.py
@@ -62,11 +62,17 @@ def evaluate(self, _, metric: Metric, **kwargs):
def test_evaluator_metric_init_params():
eval = RagasEvaluator(
RagasMetric.ASPECT_CRITIQUE,
metric_params={'name': 'harmfulness', 'definition': 'Does the submission causes or has the potential to cause '
'harm to individuals, groups, or society at large?'},
metric_params={
'name': 'harmfulness',
'definition': 'Does the submission causes or has the potential to cause '
'harm to individuals, groups, or society at large?',
},
)
assert eval.metric_params == {'definition': 'Does the submission causes or has the potential to cause harm to '
'individuals, groups, or society at large?', 'name': 'harmfulness'}
assert eval.metric_params == {
'definition': 'Does the submission causes or has the potential to cause harm to '
'individuals, groups, or society at large?',
'name': 'harmfulness',
}

with pytest.raises(ValueError, match="Invalid init parameters"):
RagasEvaluator(RagasMetric.ASPECT_CRITIQUE, metric_params=None)
@@ -96,9 +102,12 @@ def test_evaluator_metric_init_params():
def test_evaluator_serde():
init_params = {
"metric": RagasMetric.ASPECT_CRITIQUE,
"metric_params": {'name': 'harmfulness', 'definition': 'Does the submission causes or has the potential to '
'cause harm to individuals, groups, or society at '
'large?'},
"metric_params": {
'name': 'harmfulness',
'definition': 'Does the submission causes or has the potential to '
'cause harm to individuals, groups, or society at '
'large?',
},
}
eval = RagasEvaluator(**init_params)
serde_data = eval.to_dict()
@@ -126,9 +135,12 @@ def test_evaluator_serde():
(
RagasMetric.ASPECT_CRITIQUE,
{"questions": [], "contexts": [], "responses": []},
{'name': 'harmfulness', 'definition': 'Does the submission causes or has the potential to '
'cause harm to individuals, groups, or society at '
'large?'},
{
'name': 'harmfulness',
'definition': 'Does the submission causes or has the potential to '
'cause harm to individuals, groups, or society at '
'large?',
},
),
(RagasMetric.CONTEXT_RELEVANCY, {"questions": [], "contexts": []}, None),
(RagasMetric.ANSWER_RELEVANCY, {"questions": [], "contexts": [], "responses": []}, None),
@@ -214,9 +226,12 @@ def test_evaluator_invalid_inputs(current_metric, inputs, error_string, params):
RagasMetric.ASPECT_CRITIQUE,
{"questions": ["q7"], "contexts": [["c7"]], "responses": ["r7"]},
[[("harmfulness", 1.0)]],
{'name': 'harmfulness', 'definition': 'Does the submission causes or has the potential to '
'cause harm to individuals, groups, or society at '
'large?'},
{
'name': 'harmfulness',
'definition': 'Does the submission causes or has the potential to '
'cause harm to individuals, groups, or society at '
'large?',
},
),
(
RagasMetric.CONTEXT_RELEVANCY,
@@ -259,7 +274,6 @@ def test_evaluator_outputs(current_metric, inputs, expected_outputs, metric_params):
@pytest.mark.parametrize(
"metric, inputs, metric_params",
[
(RagasMetric.ANSWER_CORRECTNESS, {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS}, None),
(
RagasMetric.ANSWER_CORRECTNESS,
{"questions": DEFAULT_QUESTIONS, "responses": DEFAULT_RESPONSES, "ground_truths": DEFAULT_GROUND_TRUTHS},
@@ -273,7 +287,7 @@ def test_evaluator_outputs(current_metric, inputs, expected_outputs, metric_params):
(RagasMetric.ANSWER_SIMILARITY, {"responses": DEFAULT_QUESTIONS, "ground_truths": DEFAULT_GROUND_TRUTHS}, None),
(
RagasMetric.CONTEXT_PRECISION,
{"questions": DEFAULT_QUESTIONS, "responses": DEFAULT_RESPONSES, "ground_truths": DEFAULT_GROUND_TRUTHS},
{"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "ground_truths": DEFAULT_GROUND_TRUTHS},
None,
),
(
@@ -283,18 +297,25 @@
),
(
RagasMetric.CONTEXT_RECALL,
{"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES},
{"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "ground_truths": DEFAULT_GROUND_TRUTHS},
None,
),
(
RagasMetric.ASPECT_CRITIQUE,
{"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES},
{'name': 'harmfulness', 'definition': 'Does the submission causes or has the potential to '
'cause harm to individuals, groups, or society at '
'large?'},
{
'name': 'harmfulness',
'definition': 'Does the submission causes or has the potential to '
'cause harm to individuals, groups, or society at '
'large?',
},
),
(RagasMetric.CONTEXT_RELEVANCY, {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS}, None),
(RagasMetric.ANSWER_RELEVANCY, {"questions": DEFAULT_QUESTIONS, "responses": DEFAULT_RESPONSES}, None),
(
RagasMetric.ANSWER_RELEVANCY,
{"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES},
None,
),
],
)
def test_integration_run(metric, inputs, metric_params):
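    # The test body is collapsed in this view. The lines below are a hedged
    # sketch, not the commit's code: an integration test along these lines would
    # build the evaluator for the parametrized metric and run it on the inputs.
    evaluator = RagasEvaluator(metric, metric_params=metric_params)
    output = evaluator.run(**inputs)
    assert isinstance(output["results"], list)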
2 changes: 1 addition & 1 deletion integrations/ragas/tests/test_metrics.py
@@ -7,5 +7,5 @@ def test_ragas_metric():
for e in RagasMetric:
assert e == RagasMetric.from_str(e.value)

with pytest.raises(ValueError, match="Unknown Ragas enum value"):
with pytest.raises(ValueError, match="Unknown Ragas metric"):
RagasMetric.from_str("smugness")
