Skip to content

Commit

Permalink
fix: fix serialization of DocumentRecallEvaluator (#7662)
Browse files Browse the repository at this point in the history
* fix serialization of DocumentRecallEvaluator

* add requested tests
  • Loading branch information
anakin87 authored May 8, 2024
1 parent f14bc53 commit 9446714
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 1 deletion.
12 changes: 11 additions & 1 deletion haystack/components/evaluators/document_recall.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Dict, List, Union

from haystack.core.component import component
from haystack import component, default_to_dict
from haystack.dataclasses import Document


Expand Down Expand Up @@ -74,6 +74,7 @@ def __init__(self, mode: Union[str, RecallMode] = RecallMode.SINGLE_HIT):

mode_functions = {RecallMode.SINGLE_HIT: self._recall_single_hit, RecallMode.MULTI_HIT: self._recall_multi_hit}
self.mode_function = mode_functions[mode]
self.mode = mode

def _recall_single_hit(self, ground_truth_documents: List[Document], retrieved_documents: List[Document]) -> float:
unique_truths = {g.content for g in ground_truth_documents}
Expand Down Expand Up @@ -117,3 +118,12 @@ def run(
scores.append(score)

return {"score": sum(scores) / len(retrieved_documents), "individual_scores": scores}

def to_dict(self) -> Dict[str, Any]:
    """
    Convert this component into its dictionary representation.

    The configured recall mode is passed to `default_to_dict` as the `mode`
    init parameter, converted to a string with `str()`.

    :returns:
        Dictionary with serialized data.
    """
    mode_as_text = str(self.mode)
    return default_to_dict(self, mode=mode_as_text)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
Add `to_dict` method to `DocumentRecallEvaluator` to allow proper serialization of the component.
31 changes: 31 additions & 0 deletions test/components/evaluators/test_document_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from haystack.components.evaluators.document_recall import DocumentRecallEvaluator, RecallMode
from haystack.dataclasses import Document
from haystack import default_from_dict


def test_init_with_unknown_mode_string():
Expand Down Expand Up @@ -78,6 +79,21 @@ def test_run_with_different_lengths(self, evaluator):
retrieved_documents=[[Document(content="Berlin")]],
)

def test_to_dict(self, evaluator):
    """The serialized evaluator must carry its import path and `single_hit` as mode."""
    expected = {
        "type": "haystack.components.evaluators.document_recall.DocumentRecallEvaluator",
        "init_parameters": {"mode": "single_hit"},
    }
    assert evaluator.to_dict() == expected

def test_from_dict(self):
    """Deserializing a dict whose mode is `single_hit` must yield `RecallMode.SINGLE_HIT`."""
    init_parameters = {"mode": "single_hit"}
    serialized = {
        "type": "haystack.components.evaluators.document_recall.DocumentRecallEvaluator",
        "init_parameters": init_parameters,
    }
    restored = default_from_dict(DocumentRecallEvaluator, serialized)
    assert restored.mode == RecallMode.SINGLE_HIT


class TestDocumentRecallEvaluatorMultiHit:
@pytest.fixture
Expand Down Expand Up @@ -152,3 +168,18 @@ def test_run_with_different_lengths(self, evaluator):
ground_truth_documents=[[Document(content="Berlin")], [Document(content="Paris")]],
retrieved_documents=[[Document(content="Berlin")]],
)

def test_to_dict(self, evaluator):
    """The serialized evaluator must carry its import path and `multi_hit` as mode."""
    expected = {
        "type": "haystack.components.evaluators.document_recall.DocumentRecallEvaluator",
        "init_parameters": {"mode": "multi_hit"},
    }
    assert evaluator.to_dict() == expected

def test_from_dict(self):
    """Deserializing a dict whose mode is `multi_hit` must yield `RecallMode.MULTI_HIT`."""
    init_parameters = {"mode": "multi_hit"}
    serialized = {
        "type": "haystack.components.evaluators.document_recall.DocumentRecallEvaluator",
        "init_parameters": init_parameters,
    }
    restored = default_from_dict(DocumentRecallEvaluator, serialized)
    assert restored.mode == RecallMode.MULTI_HIT

0 comments on commit 9446714

Please sign in to comment.