From 701a790ebd997673400486a8b332bb1465635fa8 Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Tue, 29 Oct 2024 17:49:29 +0100 Subject: [PATCH] feat: efficient knn filtering support for OpenSearch (#1134) * feat: efficient filtering support for OpenSearch * add hint about supported knn engines to docstring * Apply suggestions from code review --- .../opensearch/embedding_retriever.py | 12 +++++ .../opensearch/document_store.py | 8 +++- .../opensearch/tests/test_document_store.py | 44 +++++++++++++++++++ .../tests/test_embedding_retriever.py | 15 ++++++- 4 files changed, 76 insertions(+), 3 deletions(-) diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py index e159634cf..1e9bb9132 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py @@ -31,6 +31,7 @@ def __init__( filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, custom_query: Optional[Dict[str, Any]] = None, raise_on_failure: bool = True, + efficient_filtering: bool = False, ): """ Create the OpenSearchEmbeddingRetriever component. @@ -85,6 +86,8 @@ def __init__( :param raise_on_failure: If `True`, raises an exception if the API call fails. If `False`, logs a warning and returns an empty list. + :param efficient_filtering: If `True`, the filter will be applied during the approximate kNN search. + This is only supported for knn engines "faiss" and "lucene" and does not work with the default "nmslib". :raises ValueError: If `document_store` is not an instance of OpenSearchDocumentStore. """ @@ -100,6 +103,7 @@ def __init__( ) self._custom_query = custom_query self._raise_on_failure = raise_on_failure + self._efficient_filtering = efficient_filtering def to_dict(self) -> Dict[str, Any]: """ @@ -116,6 +120,7 @@ def to_dict(self) -> Dict[str, Any]: filter_policy=self._filter_policy.value, custom_query=self._custom_query, raise_on_failure=self._raise_on_failure, + efficient_filtering=self._efficient_filtering, ) @classmethod @@ -146,6 +151,7 @@ def run( filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None, custom_query: Optional[Dict[str, Any]] = None, + efficient_filtering: Optional[bool] = None, ): """ Retrieve documents using a vector similarity metric. @@ -196,6 +202,9 @@ def run( ) ``` + :param efficient_filtering: If `True`, the filter will be applied during the approximate kNN search. + This is only supported for knn engines "faiss" and "lucene" and does not work with the default "nmslib". + :returns: Dictionary with key "documents" containing the retrieved Documents. - documents: List of Document similar to `query_embedding`. @@ -208,6 +217,8 @@ def run( top_k = self._top_k if custom_query is None: custom_query = self._custom_query + if efficient_filtering is None: + efficient_filtering = self._efficient_filtering docs: List[Document] = [] @@ -217,6 +228,7 @@ def run( filters=filters, top_k=top_k, custom_query=custom_query, + efficient_filtering=efficient_filtering, ) except Exception as e: if self._raise_on_failure: diff --git a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py index 6f7a6c96e..4ec2420b3 100644 --- a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py +++ b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py @@ -438,6 +438,7 @@ def _embedding_retrieval( filters: Optional[Dict[str, Any]] = None, top_k: int = 10, custom_query: Optional[Dict[str, Any]] = None, + efficient_filtering: bool = False, ) -> List[Document]: """ Retrieves documents that are most similar to the query embedding using a vector similarity metric. @@ -474,6 +475,8 @@ def _embedding_retrieval( } ``` + :param efficient_filtering: If `True`, the filter will be applied during the approximate kNN search. + This is only supported for knn engines "faiss" and "lucene" and does not work with the default "nmslib". :raises ValueError: If `query_embedding` is an empty list :returns: List of Document that are most similar to `query_embedding` """ @@ -509,7 +512,10 @@ def _embedding_retrieval( } if filters: - body["query"]["bool"]["filter"] = normalize_filters(filters) + if efficient_filtering: + body["query"]["bool"]["must"][0]["knn"]["embedding"]["filter"] = normalize_filters(filters) + else: + body["query"]["bool"]["filter"] = normalize_filters(filters) body["size"] = top_k diff --git a/integrations/opensearch/tests/test_document_store.py b/integrations/opensearch/tests/test_document_store.py index 9cc4bf4ea..043f59891 100644 --- a/integrations/opensearch/tests/test_document_store.py +++ b/integrations/opensearch/tests/test_document_store.py @@ -337,6 +337,27 @@ def document_store_embedding_dim_4(self, request): yield store store.client.indices.delete(index=index, params={"ignore": [400, 404]}) + @pytest.fixture + def document_store_embedding_dim_4_faiss(self, request): + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. + """ + hosts = ["https://localhost:9200"] + # Use a different index for each test so we can run them in parallel + index = f"{request.node.name}" + + store = OpenSearchDocumentStore( + hosts=hosts, + index=index, + http_auth=("admin", "admin"), + verify_certs=False, + embedding_dim=4, + method={"space_type": "innerproduct", "engine": "faiss", "name": "hnsw"}, + ) + yield store + store.client.indices.delete(index=index, params={"ignore": [400, 404]}) + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): """ The OpenSearchDocumentStore.filter_documents() method returns a Documents with their score set. @@ -690,6 +711,29 @@ def test_embedding_retrieval_with_filters(self, document_store_embedding_dim_4: assert len(results) == 1 assert results[0].content == "Not very similar document with meta field" + def test_embedding_retrieval_with_filters_efficient_filtering( + self, document_store_embedding_dim_4_faiss: OpenSearchDocumentStore + ): + docs = [ + Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document( + content="Not very similar document with meta field", + embedding=[0.0, 0.8, 0.3, 0.9], + meta={"meta_field": "custom_value"}, + ), + ] + document_store_embedding_dim_4_faiss.write_documents(docs) + + filters = {"field": "meta_field", "operator": "==", "value": "custom_value"} + results = document_store_embedding_dim_4_faiss._embedding_retrieval( + query_embedding=[0.1, 0.1, 0.1, 0.1], + filters=filters, + efficient_filtering=True, + ) + assert len(results) == 1 + assert results[0].content == "Not very similar document with meta field" + def test_embedding_retrieval_pagination(self, document_store_embedding_dim_4: OpenSearchDocumentStore): """ Test that handling of pagination works as expected, when the matching documents are > 10. diff --git a/integrations/opensearch/tests/test_embedding_retriever.py b/integrations/opensearch/tests/test_embedding_retriever.py index 75c191946..84e9828ca 100644 --- a/integrations/opensearch/tests/test_embedding_retriever.py +++ b/integrations/opensearch/tests/test_embedding_retriever.py @@ -19,6 +19,7 @@ def test_init_default(): assert retriever._filters == {} assert retriever._top_k == 10 assert retriever._filter_policy == FilterPolicy.REPLACE + assert retriever._efficient_filtering is False retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filter_policy="replace") assert retriever._filter_policy == FilterPolicy.REPLACE @@ -82,6 +83,7 @@ def test_to_dict(_mock_opensearch_client): "filter_policy": "replace", "custom_query": {"some": "custom query"}, "raise_on_failure": True, + "efficient_filtering": False, }, } @@ -101,6 +103,7 @@ def test_from_dict(_mock_opensearch_client): "filter_policy": "replace", "custom_query": {"some": "custom query"}, "raise_on_failure": False, + "efficient_filtering": True, }, } retriever = OpenSearchEmbeddingRetriever.from_dict(data) @@ -110,6 +113,7 @@ def test_from_dict(_mock_opensearch_client): assert retriever._custom_query == {"some": "custom query"} assert retriever._raise_on_failure is False assert retriever._filter_policy == FilterPolicy.REPLACE + assert retriever._efficient_filtering is True # For backwards compatibility with older versions of the retriever without a filter policy data = { @@ -139,6 +143,7 @@ def test_run(): filters={}, top_k=10, custom_query=None, + efficient_filtering=False, ) assert len(res) == 1 assert len(res["documents"]) == 1 @@ -150,7 +155,11 @@ def test_run_init_params(): mock_store = Mock(spec=OpenSearchDocumentStore) mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] retriever = OpenSearchEmbeddingRetriever( - document_store=mock_store, filters={"from": "init"}, top_k=11, custom_query="custom_query" + document_store=mock_store, + filters={"from": "init"}, + top_k=11, + custom_query="custom_query", + efficient_filtering=True, ) res = retriever.run(query_embedding=[0.5, 0.7]) mock_store._embedding_retrieval.assert_called_once_with( @@ -158,6 +167,7 @@ def test_run_init_params(): filters={"from": "init"}, top_k=11, custom_query="custom_query", + efficient_filtering=True, ) assert len(res) == 1 assert len(res["documents"]) == 1 @@ -169,12 +179,13 @@ def test_run_time_params(): mock_store = Mock(spec=OpenSearchDocumentStore) mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filters={"from": "init"}, top_k=11) - res = retriever.run(query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9) + res = retriever.run(query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9, efficient_filtering=True) mock_store._embedding_retrieval.assert_called_once_with( query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9, custom_query=None, + efficient_filtering=True, ) assert len(res) == 1 assert len(res["documents"]) == 1