forked from deepset-ai/haystack-core-integrations
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_retriever.py
105 lines (95 loc) · 4.5 KB
/
test_retriever.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import pytest
from haystack.document_stores.types import FilterPolicy
from haystack_integrations.components.retrievers.chroma import ChromaQueryTextRetriever
from haystack_integrations.document_stores.chroma import ChromaDocumentStore
@pytest.mark.integration
def test_retriever_init(request):
    """Check how the retriever constructor handles the ``filter_policy`` argument.

    The string ``"replace"`` must end up as ``FilterPolicy.REPLACE`` on the
    retriever, while an unrecognised policy string must raise ``ValueError``.
    The collection is named after the running test node.
    """
    store = ChromaDocumentStore(
        collection_name=request.node.name,
        embedding_function="HuggingFaceEmbeddingFunction",
        api_key="1234567890",
    )

    # A known policy string is accepted and exposed as the enum member.
    accepted = ChromaQueryTextRetriever(
        store,
        filters={"foo": "bar"},
        top_k=99,
        filter_policy="replace",
    )
    assert accepted.filter_policy == FilterPolicy.REPLACE

    # An unknown policy string is rejected at construction time.
    with pytest.raises(ValueError):
        ChromaQueryTextRetriever(
            store,
            filters={"foo": "bar"},
            top_k=99,
            filter_policy="unknown",
        )
@pytest.mark.integration
def test_retriever_to_json(request):
    """Check that ``to_dict()`` serialises the retriever and its nested store.

    The retriever is built without an explicit ``filter_policy``; the expected
    payload nevertheless contains ``"filter_policy": "replace"``.
    """
    store = ChromaDocumentStore(
        collection_name=request.node.name,
        embedding_function="HuggingFaceEmbeddingFunction",
        api_key="1234567890",
    )
    serialized = ChromaQueryTextRetriever(store, filters={"foo": "bar"}, top_k=99).to_dict()

    # Expected serialisation of the nested document store.
    expected_store = {
        "type": "haystack_integrations.document_stores.chroma.document_store.ChromaDocumentStore",
        "init_parameters": {
            "collection_name": "test_retriever_to_json",
            "embedding_function": "HuggingFaceEmbeddingFunction",
            "persist_path": None,
            "api_key": "1234567890",
            "distance_function": "l2",
        },
    }
    # Expected serialisation of the retriever wrapping that store.
    expected = {
        "type": "haystack_integrations.components.retrievers.chroma.retriever.ChromaQueryTextRetriever",
        "init_parameters": {
            "filters": {"foo": "bar"},
            "top_k": 99,
            "filter_policy": "replace",
            "document_store": expected_store,
        },
    }

    assert serialized == expected
@pytest.mark.integration
def test_retriever_from_json(request):
    """Check that ``from_dict()`` rebuilds a retriever from a serialised payload.

    The payload carries an explicit ``"filter_policy": "replace"`` and a nested
    document store; the test verifies the store settings, the filters, the
    ``top_k`` value and the resulting filter policy.
    """
    store_payload = {
        "type": "haystack_integrations.document_stores.chroma.document_store.ChromaDocumentStore",
        "init_parameters": {
            "collection_name": "test_retriever_from_json",
            "embedding_function": "HuggingFaceEmbeddingFunction",
            "persist_path": ".",
            "api_key": "1234567890",
            "distance_function": "l2",
        },
    }
    payload = {
        "type": "haystack_integrations.components.retrievers.chroma.retriever.ChromaQueryTextRetriever",
        "init_parameters": {
            "filters": {"bar": "baz"},
            "top_k": 42,
            "filter_policy": "replace",
            "document_store": store_payload,
        },
    }

    restored = ChromaQueryTextRetriever.from_dict(payload)

    # Settings of the nested document store.
    restored_store = restored.document_store
    assert restored_store._collection_name == request.node.name
    assert restored_store._embedding_function == "HuggingFaceEmbeddingFunction"
    assert restored_store._embedding_function_params == {"api_key": "1234567890"}
    assert restored_store._persist_path == "."

    # Settings of the retriever itself.
    assert restored.filters == {"bar": "baz"}
    assert restored.top_k == 42
    assert restored.filter_policy == FilterPolicy.REPLACE
@pytest.mark.integration
def test_retriever_from_json_no_filter_policy(request):
    """Check ``from_dict()`` on a payload that has no ``filter_policy`` entry.

    The retriever must still be rebuilt with the given store settings, filters
    and ``top_k``, and its filter policy must be ``FilterPolicy.REPLACE``.
    """
    store_payload = {
        "type": "haystack_integrations.document_stores.chroma.document_store.ChromaDocumentStore",
        "init_parameters": {
            "collection_name": "test_retriever_from_json_no_filter_policy",
            "embedding_function": "HuggingFaceEmbeddingFunction",
            "persist_path": ".",
            "api_key": "1234567890",
            "distance_function": "l2",
        },
    }
    # Note: "filter_policy" is deliberately absent from the init parameters.
    payload = {
        "type": "haystack_integrations.components.retrievers.chroma.retriever.ChromaQueryTextRetriever",
        "init_parameters": {
            "filters": {"bar": "baz"},
            "top_k": 42,
            "document_store": store_payload,
        },
    }

    restored = ChromaQueryTextRetriever.from_dict(payload)

    # Settings of the nested document store.
    restored_store = restored.document_store
    assert restored_store._collection_name == request.node.name
    assert restored_store._embedding_function == "HuggingFaceEmbeddingFunction"
    assert restored_store._embedding_function_params == {"api_key": "1234567890"}
    assert restored_store._persist_path == "."

    # Settings of the retriever itself.
    assert restored.filters == {"bar": "baz"}
    assert restored.top_k == 42
    # The policy falls back to REPLACE when the payload does not specify one.
    assert restored.filter_policy == FilterPolicy.REPLACE