From 252d27df70f90c933fd9a870b5293da04370f29d Mon Sep 17 00:00:00 2001 From: Trivan Menezes <47679108+ttmenezes@users.noreply.github.com> Date: Mon, 11 Nov 2024 15:16:07 -0800 Subject: [PATCH 01/12] new azure retrievers --- .../retrievers/azure_ai_search/__init__.py | 4 +- .../azure_ai_search/bm25_retriever.py | 116 +++++++++++++++++ .../azure_ai_search/hybrid_retriever.py | 120 ++++++++++++++++++ .../azure_ai_search/document_store.py | 52 ++++++++ 4 files changed, 291 insertions(+), 1 deletion(-) create mode 100644 integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py create mode 100644 integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py index eb75ffa6c..eebe990f3 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py @@ -1,3 +1,5 @@ from .embedding_retriever import AzureAISearchEmbeddingRetriever +from .bm25_retriever import AzureAISearchBM25Retriever +from .hybrid_retriever import AzureAISearchHybridRetriever -__all__ = ["AzureAISearchEmbeddingRetriever"] +__all__ = ["AzureAISearchEmbeddingRetriever", "AzureAISearchBM25Retriever", "AzureAISearchHybridRetriever"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py new file mode 100644 index 000000000..65e273b73 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py @@ -0,0 +1,116 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters + +logger = logging.getLogger(__name__) + + +@component +class AzureAISearchBM25Retriever: + """ + Retrieves documents from the AzureAISearchDocumentStore using BM25 retrieval. + Must be connected to the AzureAISearchDocumentStore to run. + + """ + + def __init__( + self, + *, + document_store: AzureAISearchDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + ): + """ + Create the AzureAISearchBM25Retriever component. + + :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. + :param filters: Filters applied when fetching documents from the Document Store. + Filters are applied during the BM25 search to ensure the Retriever returns + `top_k` matching documents. + :param top_k: Maximum number of documents to return. + :filter_policy: Policy to determine how filters are applied. Possible options: + + """ + self._filters = filters or {} + self._top_k = top_k + self._document_store = document_store + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + + if not isinstance(document_store, AzureAISearchDocumentStore): + message = "document_store must be an instance of AzureAISearchDocumentStore" + raise Exception(message) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self._filters, + top_k=self._top_k, + document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchBM25Retriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. + """ + data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if "filter_policy" in data["init_parameters"]: + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): + """Retrieve documents from the AzureAISearchDocumentStore. + + :param query: Text of the query. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. + :param top_k: the maximum number of documents to retrieve. + :returns: a dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. + """ + + top_k = top_k or self._top_k + if filters is not None: + applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) + normalized_filters = normalize_filters(applied_filters) + else: + normalized_filters = "" + + try: + docs = self._document_store._bm25_retrieval( + query=query, + filters=normalized_filters, + top_k=top_k, + ) + except Exception as e: + raise e + + return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py new file mode 100644 index 000000000..eb28c4e73 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py @@ -0,0 +1,120 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters + +logger = logging.getLogger(__name__) + + +@component +class AzureAISearchHybridRetriever: + """ + Retrieves documents from the AzureAISearchDocumentStore using a hybrid (vector + BM25) retrieval. + Must be connected to the AzureAISearchDocumentStore to run. + + """ + + def __init__( + self, + *, + document_store: AzureAISearchDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + ): + """ + Create the AzureAISearchHybridRetriever component. + + :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. + :param filters: Filters applied when fetching documents from the Document Store. + Filters are applied during the hybrid search to ensure the Retriever returns + `top_k` matching documents. + :param top_k: Maximum number of documents to return. + :filter_policy: Policy to determine how filters are applied. Possible options: + + """ + self._filters = filters or {} + self._top_k = top_k + self._document_store = document_store + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + + if not isinstance(document_store, AzureAISearchDocumentStore): + message = "document_store must be an instance of AzureAISearchDocumentStore" + raise Exception(message) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self._filters, + top_k=self._top_k, + document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchHybridRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. + """ + data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if "filter_policy" in data["init_parameters"]: + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, query: str, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None + ): + """Retrieve documents from the AzureAISearchDocumentStore. + + :param query: Text of the query. + :param query_embedding: floats representing the query embedding + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. + :param top_k: the maximum number of documents to retrieve. + :returns: a dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. + """ + + top_k = top_k or self._top_k + if filters is not None: + applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) + normalized_filters = normalize_filters(applied_filters) + else: + normalized_filters = "" + + try: + docs = self._document_store._hybrid_retrieval( + query=query, + query_embedding=query_embedding, + filters=normalized_filters, + top_k=top_k, + ) + except Exception as e: + raise e + + return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 0b59b6e37..4e3b3b44a 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -438,3 +438,55 @@ def _embedding_retrieval( result = self.client.search(search_text=None, vector_queries=[vector_query], select=fields, filter=filters) azure_docs = list(result) return self._convert_search_result_to_documents(azure_docs) + + def _bm25_retrieval( + self, + query: str, + top_k: int = 10, + fields: Optional[List[str]] = None, + filters: Optional[Dict[str, Any]] = None, + ) -> List[Document]: + """ + Retrieves documents that are most similar to `query`, using the BM25 algorithm + + This method is not meant to be part of the public interface of + `AzureAISearchDocumentStore` nor called directly. + `AzureAISearchBM25Retriever` uses this method directly and is the public interface for it. + + :param query: Text of the query. + :param filters: Filters applied to the retrieved Documents. Defaults to None. + Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. + :param top_k: Maximum number of Documents to return, defaults to 10 + + :raises ValueError: If `query` is an empty string + :returns: List of Document that are most similar to `query` + """ + + if query is None: + msg = "query must not be None" + raise ValueError(msg) + + result = self.client.search(search_text=query, select=fields, filter=filters, top=top_k, query_type="simple") + azure_docs = list(result) + return self._convert_search_result_to_documents(azure_docs) + + def _hybrid_retrieval( + self, + query: str, + query_embedding: List[float], + top_k: int = 10, + fields: Optional[List[str]] = None, + filters: Optional[Dict[str, Any]] = None, + ) -> List[Document]: + + if query is None: + msg = "query must not be None" + raise ValueError(msg) + if not query_embedding: + msg = "query_embedding must be a non-empty list of floats" + raise ValueError(msg) + + vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding") + result = self.client.search(search_text=query, vector_queries=[vector_query], select=fields, filter=filters, top=top_k, query_type="simple") + azure_docs = list(result) + return self._convert_search_result_to_documents(azure_docs) \ No newline at end of file From 0f38ea8186f1316e9f121b35c2705ed3a6c69882 Mon Sep 17 00:00:00 2001 From: Trivan Menezes <47679108+ttmenezes@users.noreply.github.com> Date: Mon, 11 Nov 2024 15:34:18 -0800 Subject: [PATCH 02/12] fix styling --- .../retrievers/azure_ai_search/__init__.py | 4 ++-- .../retrievers/azure_ai_search/hybrid_retriever.py | 6 +++++- .../azure_ai_search/document_store.py | 12 +++++++++--- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py index eebe990f3..56dc30db4 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py @@ -1,5 +1,5 @@ -from .embedding_retriever import AzureAISearchEmbeddingRetriever from .bm25_retriever import AzureAISearchBM25Retriever +from .embedding_retriever import AzureAISearchEmbeddingRetriever from .hybrid_retriever import AzureAISearchHybridRetriever -__all__ = ["AzureAISearchEmbeddingRetriever", "AzureAISearchBM25Retriever", "AzureAISearchHybridRetriever"] +__all__ = ["AzureAISearchBM25Retriever", "AzureAISearchEmbeddingRetriever", "AzureAISearchHybridRetriever"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py index eb28c4e73..77cc0c586 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py @@ -86,7 +86,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchHybridRetriever": @component.output_types(documents=List[Document]) def run( - self, query: str, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None + self, + query: str, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None ): """Retrieve documents from the AzureAISearchDocumentStore. diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 4e3b3b44a..737a0bd11 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -438,7 +438,7 @@ def _embedding_retrieval( result = self.client.search(search_text=None, vector_queries=[vector_query], select=fields, filter=filters) azure_docs = list(result) return self._convert_search_result_to_documents(azure_docs) - + def _bm25_retrieval( self, query: str, @@ -487,6 +487,12 @@ def _hybrid_retrieval( raise ValueError(msg) vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding") - result = self.client.search(search_text=query, vector_queries=[vector_query], select=fields, filter=filters, top=top_k, query_type="simple") + result = self.client.search( + search_text=query, + vector_queries=[vector_query], + select=fields, filter=filters, + top=top_k, query_type="simple" + ) azure_docs = list(result) - return self._convert_search_result_to_documents(azure_docs) \ No newline at end of file + return self._convert_search_result_to_documents(azure_docs) + \ No newline at end of file From 41672f7a81e685489dc5942cef8afb5b0b674fcc Mon Sep 17 00:00:00 2001 From: Trivan Menezes <47679108+ttmenezes@users.noreply.github.com> Date: Mon, 11 Nov 2024 15:39:16 -0800 Subject: [PATCH 03/12] fix whitespace --- .../retrievers/azure_ai_search/hybrid_retriever.py | 8 ++++---- .../document_stores/azure_ai_search/document_store.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py index 77cc0c586..ace032a8d 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py @@ -86,10 +86,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchHybridRetriever": @component.output_types(documents=List[Document]) def run( - self, - query: str, - query_embedding: List[float], - filters: Optional[Dict[str, Any]] = None, + self, + query: str, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None ): """Retrieve documents from the AzureAISearchDocumentStore. diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 737a0bd11..3f4427909 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -488,9 +488,9 @@ def _hybrid_retrieval( vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding") result = self.client.search( - search_text=query, - vector_queries=[vector_query], - select=fields, filter=filters, + search_text=query, + vector_queries=[vector_query], + select=fields, filter=filters, top=top_k, query_type="simple" ) azure_docs = list(result) From 914d27fbd77d94eb7a5151d2c3d8557800d85bcd Mon Sep 17 00:00:00 2001 From: Trivan Menezes <47679108+ttmenezes@users.noreply.github.com> Date: Mon, 11 Nov 2024 15:40:56 -0800 Subject: [PATCH 04/12] fix whitespace --- .../document_stores/azure_ai_search/document_store.py | 1 - 1 file changed, 1 deletion(-) diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 3f4427909..05d92f523 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -495,4 +495,3 @@ def _hybrid_retrieval( ) azure_docs = list(result) return self._convert_search_result_to_documents(azure_docs) - \ No newline at end of file From ee41bf0e3660457a2ef22f8919f6f4fdf7d7f68d Mon Sep 17 00:00:00 2001 From: Trivan Menezes <47679108+ttmenezes@users.noreply.github.com> Date: Mon, 11 Nov 2024 15:47:10 -0800 Subject: [PATCH 05/12] fix linting --- .../retrievers/azure_ai_search/hybrid_retriever.py | 2 +- .../document_stores/azure_ai_search/document_store.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py index ace032a8d..fbe7752aa 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py @@ -90,7 +90,7 @@ def run( query: str, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, - top_k: Optional[int] = None + top_k: Optional[int] = None, ): """Retrieve documents from the AzureAISearchDocumentStore. diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 05d92f523..68a2db22e 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -490,8 +490,10 @@ def _hybrid_retrieval( result = self.client.search( search_text=query, vector_queries=[vector_query], - select=fields, filter=filters, - top=top_k, query_type="simple" + select=fields, + filter=filters, + top=top_k, + query_type="simple", ) azure_docs = list(result) return self._convert_search_result_to_documents(azure_docs) From 1cb63f489b4a5be63b1140f37bf2e4e3abe7896e Mon Sep 17 00:00:00 2001 From: Trivan Menezes <47679108+ttmenezes@users.noreply.github.com> Date: Wed, 13 Nov 2024 11:49:41 -0800 Subject: [PATCH 06/12] tests for bm25 and hybrid retrievers --- .../tests/test_bm25_retriever.py | 128 ++++++++++++++++ .../tests/test_hybrid_retriever.py | 145 ++++++++++++++++++ 2 files changed, 273 insertions(+) create mode 100644 integrations/azure_ai_search/tests/test_bm25_retriever.py create mode 100644 integrations/azure_ai_search/tests/test_hybrid_retriever.py diff --git a/integrations/azure_ai_search/tests/test_bm25_retriever.py b/integrations/azure_ai_search/tests/test_bm25_retriever.py new file mode 100644 index 000000000..d0c6d0da9 --- /dev/null +++ b/integrations/azure_ai_search/tests/test_bm25_retriever.py @@ -0,0 +1,128 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import List +from unittest.mock import Mock + +import pytest +from azure.core.exceptions import HttpResponseError +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from numpy.random import rand # type: ignore + +from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchBM25Retriever +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore + + +def test_init_default(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + retriever = AzureAISearchBM25Retriever(document_store=mock_store) + assert retriever._document_store == mock_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = AzureAISearchBM25Retriever(document_store=mock_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + AzureAISearchBM25Retriever(document_store=mock_store, filter_policy="unknown") + + +def test_to_dict(): + document_store = AzureAISearchDocumentStore(hosts="some fake host") + retriever = AzureAISearchBM25Retriever(document_store=document_store) + res = retriever.to_dict() + assert res == { + "type": "haystack_integrations.components.retrievers.azure_ai_search.bm25_retriever.AzureAISearchBM25Retriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": { + "profiles": [ + {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} + ], + "algorithms": [ + { + "name": "cosine-algorithm-config", + "kind": "hnsw", + "parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"}, + } + ], + }, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + + +def test_from_dict(): + data = { + "type": "haystack_integrations.components.retrievers.azure_ai_search.bm25_retriever.AzureAISearchBM25Retriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "metadata_fields": None, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + retriever = AzureAISearchBM25Retriever.from_dict(data) + assert isinstance(retriever._document_store, AzureAISearchDocumentStore) + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +@pytest.mark.integration +class TestRetriever: + + def test_run(self, document_store: AzureAISearchDocumentStore): + docs = [Document(id="1", content="Test document")] + document_store.write_documents(docs) + retriever = AzureAISearchBM25Retriever(document_store=document_store) + res = retriever.run(query="Test document") + assert res["documents"] == docs + + def test_document_retrieval(self, document_store: AzureAISearchDocumentStore): + docs = [ + Document(content="This is first document"), + Document(content="This is second document"), + Document(content="This is third document"), + ] + + document_store.write_documents(docs) + retriever = AzureAISearchBM25Retriever(document_store=document_store) + results = retriever.run(query="This is first document") + assert results["documents"][0].content == "This is first document" diff --git a/integrations/azure_ai_search/tests/test_hybrid_retriever.py b/integrations/azure_ai_search/tests/test_hybrid_retriever.py new file mode 100644 index 000000000..2447949fd --- /dev/null +++ b/integrations/azure_ai_search/tests/test_hybrid_retriever.py @@ -0,0 +1,145 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import List +from unittest.mock import Mock + +import pytest +from azure.core.exceptions import HttpResponseError +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from numpy.random import rand # type: ignore + +from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchHybridRetriever +from haystack_integrations.document_stores.azure_ai_search import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore + + +def test_init_default(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + retriever = AzureAISearchHybridRetriever(document_store=mock_store) + assert retriever._document_store == mock_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = AzureAISearchHybridRetriever(document_store=mock_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + AzureAISearchHybridRetriever(document_store=mock_store, filter_policy="unknown") + + +def test_to_dict(): + document_store = AzureAISearchDocumentStore(hosts="some fake host") + retriever = AzureAISearchHybridRetriever(document_store=document_store) + res = retriever.to_dict() + assert res == { + "type": "haystack_integrations.components.retrievers.azure_ai_search.hybrid_retriever.AzureAISearchHybridRetriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": { + "profiles": [ + {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} + ], + "algorithms": [ + { + "name": "cosine-algorithm-config", + "kind": "hnsw", + "parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"}, + } + ], + }, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + + +def test_from_dict(): + data = { + "type": "haystack_integrations.components.retrievers.azure_ai_search.hybrid_retriever.AzureAISearchHybridRetriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": DEFAULT_VECTOR_SEARCH, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + retriever = AzureAISearchHybridRetriever.from_dict(data) + assert isinstance(retriever._document_store, AzureAISearchDocumentStore) + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +@pytest.mark.integration +class TestRetriever: + + def test_run(self, document_store: AzureAISearchDocumentStore): + docs = [Document(id="1")] + document_store.write_documents(docs) + retriever = AzureAISearchHybridRetriever(document_store=document_store) + res = retriever.run(query="Test document", query_embedding=[0.1] * 768) + assert res["documents"] == docs + + def test_embedding_retrieval(self, document_store: AzureAISearchDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + second_best_embedding = [0.8] * 200 + [0.1] * 300 + [0.2] * 268 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="This is first document", embedding=most_similar_embedding), + Document(content="This is second document", embedding=second_best_embedding), + Document(content="This is third document", embedding=another_embedding), + ] + + document_store.write_documents(docs) + retriever = AzureAISearchHybridRetriever(document_store=document_store) + results = retriever.run(query="This is first document", query_embedding=query_embedding) + assert results["documents"][0].content == "This is first document" + + def test_empty_query_embedding(self, document_store: AzureAISearchDocumentStore): + query_embedding: List[float] = [] + with pytest.raises(ValueError): + document_store._hybrid_retrieval(query="", query_embedding=query_embedding) + + def test_query_embedding_wrong_dimension(self, document_store: AzureAISearchDocumentStore): + query_embedding = [0.1] * 4 + with pytest.raises(HttpResponseError): + document_store._hybrid_retrieval(query="", query_embedding=query_embedding) From 7b197d273cb7cff09acc22aba4ed90aa1d065198 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Fri, 15 Nov 2024 16:45:00 +0100 Subject: [PATCH 07/12] Enable kwargs for semantic ranking in retrievers --- .../azure_ai_search/bm25_retriever.py | 26 ++++++++-- .../azure_ai_search/hybrid_retriever.py | 37 +++++++++----- .../azure_ai_search/document_store.py | 48 ++++++++++++------- .../tests/test_bm25_retriever.py | 7 +-- 4 files changed, 80 insertions(+), 38 deletions(-) diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py index 65e273b73..476144545 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py @@ -5,7 +5,7 @@ from haystack.document_stores.types import FilterPolicy from haystack.document_stores.types.filter_policy import apply_filter_policy -from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters logger = logging.getLogger(__name__) @@ -25,6 +25,7 @@ def __init__( filters: Optional[Dict[str, Any]] = None, top_k: int = 10, filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + **kwargs, ): """ Create the AzureAISearchBM25Retriever component. @@ -34,7 +35,16 @@ def __init__( Filters are applied during the BM25 search to ensure the Retriever returns `top_k` matching documents. :param top_k: Maximum number of documents to return. - :filter_policy: Policy to determine how filters are applied. Possible options: + :filter_policy: Policy to determine how filters are applied. + :param kwargs: Additional keyword arguments to pass to the Azure AI's search endpoint. + Some of the supported parameters: + - `query_type`: A string indicating the type of query to perform. Possible values are + 'simple','full' and 'semantic'. + - `semantic_configuration_name`: The name of semantic configuration to be used when + processing semantic queries. + For more information on parameters, see the + [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). + """ self._filters = filters or {} @@ -43,7 +53,7 @@ def __init__( self._filter_policy = ( filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) - + self._kwargs = kwargs if not isinstance(document_store, AzureAISearchDocumentStore): message = "document_store must be an instance of AzureAISearchDocumentStore" raise Exception(message) @@ -61,6 +71,7 @@ def to_dict(self) -> Dict[str, Any]: top_k=self._top_k, document_store=self._document_store.to_dict(), filter_policy=self._filter_policy.value, + **self._kwargs, ) @classmethod @@ -100,7 +111,7 @@ def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optio top_k = top_k or self._top_k if filters is not None: applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) - normalized_filters = normalize_filters(applied_filters) + normalized_filters = _normalize_filters(applied_filters) else: normalized_filters = "" @@ -109,8 +120,13 @@ def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optio query=query, filters=normalized_filters, top_k=top_k, + **self._kwargs, ) except Exception as e: - raise e + msg = ( + "An error occurred during the bm25 retrieval process from the AzureAISearchDocumentStore. " + "Ensure that the query is valid and the document store is correctly configured." + ) + raise RuntimeError(msg) from e return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py index fbe7752aa..ce0a17e2e 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py @@ -5,7 +5,7 @@ from haystack.document_stores.types import FilterPolicy from haystack.document_stores.types.filter_policy import apply_filter_policy -from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters logger = logging.getLogger(__name__) @@ -25,6 +25,7 @@ def __init__( filters: Optional[Dict[str, Any]] = None, top_k: int = 10, filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + **kwargs, ): """ Create the AzureAISearchHybridRetriever component. @@ -34,7 +35,16 @@ def __init__( Filters are applied during the hybrid search to ensure the Retriever returns `top_k` matching documents. :param top_k: Maximum number of documents to return. - :filter_policy: Policy to determine how filters are applied. Possible options: + :filter_policy: Policy to determine how filters are applied. + :param kwargs: Additional keyword arguments to pass to the Azure AI's search endpoint. + Some of the supported parameters: + - `query_type`: A string indicating the type of query to perform. Possible values are + 'simple','full' and 'semantic'. + - `semantic_configuration_name`: The name of semantic configuration to be used when + processing semantic queries. + For more information on parameters, see the + [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). + """ self._filters = filters or {} @@ -43,6 +53,7 @@ def __init__( self._filter_policy = ( filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) + self._kwargs = kwargs if not isinstance(document_store, AzureAISearchDocumentStore): message = "document_store must be an instance of AzureAISearchDocumentStore" @@ -61,6 +72,7 @@ def to_dict(self) -> Dict[str, Any]: top_k=self._top_k, document_store=self._document_store.to_dict(), filter_policy=self._filter_policy.value, + **self._kwargs, ) @classmethod @@ -95,30 +107,31 @@ def run( """Retrieve documents from the AzureAISearchDocumentStore. :param query: Text of the query. - :param query_embedding: floats representing the query embedding + :param query_embedding: A list of floats representing the query embedding :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on - the `filter_policy` chosen at retriever initialization. See init method docstring for more + the `filter_policy` chosen at retriever initialization. See `__init__` method docstring for more details. - :param top_k: the maximum number of documents to retrieve. - :returns: a dictionary with the following keys: + :param top_k: The maximum number of documents to retrieve. + :returns: A dictionary with the following keys: - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. """ top_k = top_k or self._top_k if filters is not None: applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) - normalized_filters = normalize_filters(applied_filters) + normalized_filters = _normalize_filters(applied_filters) else: normalized_filters = "" try: docs = self._document_store._hybrid_retrieval( - query=query, - query_embedding=query_embedding, - filters=normalized_filters, - top_k=top_k, + query=query, query_embedding=query_embedding, filters=normalized_filters, top_k=top_k, **self._kwargs ) except Exception as e: - raise e + msg = ( + "An error occurred during the hybrid retrieval process from the AzureAISearchDocumentStore. " + "Ensure that the query and query_embedding are valid and the document store is correctly configured." + ) + raise RuntimeError(msg) from e return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 00682baba..cf0657495 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -421,7 +421,7 @@ def _embedding_retrieval( ) -> List[Document]: """ Retrieves documents that are most similar to the query embedding using a vector similarity metric. - It uses the vector configuration of the document store. By default it uses the HNSW algorithm + It uses the vector configuration specified in the document store. By default, it uses the HNSW algorithm with cosine similarity. This method is not meant to be part of the public interface of @@ -429,13 +429,12 @@ def _embedding_retrieval( `AzureAISearchEmbeddingRetriever` uses this method directly and is the public interface for it. :param query_embedding: Embedding of the query. - :param top_k: Maximum number of Documents to return, defaults to 10. - :param filters: Filters applied to the retrieved Documents. Defaults to None. - Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. + :param top_k: Maximum number of Documents to return. + :param filters: Filters applied to the retrieved Documents. :param kwargs: Optional keyword arguments to pass to the Azure AI's search endpoint. - :raises ValueError: If `query_embedding` is an empty list - :returns: List of Document that are most similar to `query_embedding` + :raises ValueError: If `query_embedding` is an empty list. + :returns: List of Document that are most similar to `query_embedding`. """ if not query_embedding: @@ -451,30 +450,31 @@ def _bm25_retrieval( self, query: str, top_k: int = 10, - fields: Optional[List[str]] = None, filters: Optional[Dict[str, Any]] = None, + **kwargs, ) -> List[Document]: """ - Retrieves documents that are most similar to `query`, using the BM25 algorithm + Retrieves documents that are most similar to `query`, using the BM25 algorithm. This method is not meant to be part of the public interface of `AzureAISearchDocumentStore` nor called directly. `AzureAISearchBM25Retriever` uses this method directly and is the public interface for it. :param query: Text of the query. - :param filters: Filters applied to the retrieved Documents. Defaults to None. - Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. - :param top_k: Maximum number of Documents to return, defaults to 10 + :param filters: Filters applied to the retrieved Documents. + :param top_k: Maximum number of Documents to return. + :param kwargs: Optional keyword arguments to pass to the Azure AI's search endpoint. - :raises ValueError: If `query` is an empty string - :returns: List of Document that are most similar to `query` + + :raises ValueError: If `query` is an empty string. + :returns: List of Document that are most similar to `query`. """ if query is None: msg = "query must not be None" raise ValueError(msg) - result = self.client.search(search_text=query, select=fields, filter=filters, top=top_k, query_type="simple") + result = self.client.search(search_text=query, filter=filters, top=top_k, query_type="simple", **kwargs) azure_docs = list(result) return self._convert_search_result_to_documents(azure_docs) @@ -483,9 +483,25 @@ def _hybrid_retrieval( query: str, query_embedding: List[float], top_k: int = 10, - fields: Optional[List[str]] = None, filters: Optional[Dict[str, Any]] = None, + **kwargs, ) -> List[Document]: + """ + Retrieves documents similar to query using the vector configuration in the document store and + the BM25 algorithm. This method combines vector similarity and BM25 for improved retrieval. + + This method is not meant to be part of the public interface of + `AzureAISearchDocumentStore` nor called directly. + `AzureAISearchHybridRetriever` uses this method directly and is the public interface for it. + + :param query: Text of the query. + :param filters: Filters applied to the retrieved Documents. + :param top_k: Maximum number of Documents to return. + :param kwargs: Optional keyword arguments to pass to the Azure AI's search endpoint. + + :raises ValueError: If `query` or `query_embedding` is empty. + :returns: List of Document that are most similar to `query`. + """ if query is None: msg = "query must not be None" @@ -498,10 +514,10 @@ def _hybrid_retrieval( result = self.client.search( search_text=query, vector_queries=[vector_query], - select=fields, filter=filters, top=top_k, query_type="simple", + **kwargs, ) azure_docs = list(result) return self._convert_search_result_to_documents(azure_docs) diff --git a/integrations/azure_ai_search/tests/test_bm25_retriever.py b/integrations/azure_ai_search/tests/test_bm25_retriever.py index d0c6d0da9..e6631a16b 100644 --- a/integrations/azure_ai_search/tests/test_bm25_retriever.py +++ b/integrations/azure_ai_search/tests/test_bm25_retriever.py @@ -2,14 +2,11 @@ # # SPDX-License-Identifier: Apache-2.0 import os -from typing import List from unittest.mock import Mock import pytest -from azure.core.exceptions import HttpResponseError from haystack.dataclasses import Document from haystack.document_stores.types import FilterPolicy -from numpy.random import rand # type: ignore from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchBM25Retriever from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore @@ -35,7 +32,7 @@ def test_to_dict(): retriever = AzureAISearchBM25Retriever(document_store=document_store) res = retriever.to_dict() assert res == { - "type": "haystack_integrations.components.retrievers.azure_ai_search.bm25_retriever.AzureAISearchBM25Retriever", # noqa: E501 + "type": "haystack_integrations.components.retrievers.azure_ai_search.bm25_retriever.AzureAISearchBM25Retriever", "init_parameters": { "filters": {}, "top_k": 10, @@ -73,7 +70,7 @@ def test_to_dict(): def test_from_dict(): data = { - "type": "haystack_integrations.components.retrievers.azure_ai_search.bm25_retriever.AzureAISearchBM25Retriever", # noqa: E501 + "type": "haystack_integrations.components.retrievers.azure_ai_search.bm25_retriever.AzureAISearchBM25Retriever", "init_parameters": { "filters": {}, "top_k": 10, From 5ac7cbec5c5e0c470ae99581e08c80fae9a21258 Mon Sep 17 00:00:00 2001 From: Trivan Menezes <47679108+ttmenezes@users.noreply.github.com> Date: Mon, 18 Nov 2024 08:30:38 -0800 Subject: [PATCH 08/12] Update integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py Co-authored-by: Julian Risch --- .../components/retrievers/azure_ai_search/bm25_retriever.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py index 476144545..cc8ab8217 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py @@ -35,7 +35,7 @@ def __init__( Filters are applied during the BM25 search to ensure the Retriever returns `top_k` matching documents. :param top_k: Maximum number of documents to return. - :filter_policy: Policy to determine how filters are applied. + :param filter_policy: Policy to determine how filters are applied. :param kwargs: Additional keyword arguments to pass to the Azure AI's search endpoint. Some of the supported parameters: - `query_type`: A string indicating the type of query to perform. Possible values are From 422d76b73ddeebd4da2ef980c4ec144e5e475930 Mon Sep 17 00:00:00 2001 From: Trivan Menezes <47679108+ttmenezes@users.noreply.github.com> Date: Mon, 18 Nov 2024 09:10:17 -0800 Subject: [PATCH 09/12] improve error handling and function comments --- .../retrievers/azure_ai_search/bm25_retriever.py | 5 +++-- .../retrievers/azure_ai_search/hybrid_retriever.py | 9 +++++---- .../document_stores/azure_ai_search/document_store.py | 5 ++++- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py index cc8ab8217..dfbee32f6 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py @@ -44,7 +44,8 @@ def __init__( processing semantic queries. For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). - + :raises TypeError: If the document store is not an instance of AzureAISearchDocumentStore. + :raises RuntimeError: If the query is not valid, or if the document store is not correctly configured. """ self._filters = filters or {} @@ -56,7 +57,7 @@ def __init__( self._kwargs = kwargs if not isinstance(document_store, AzureAISearchDocumentStore): message = "document_store must be an instance of AzureAISearchDocumentStore" - raise Exception(message) + raise TypeError(message) def to_dict(self) -> Dict[str, Any]: """ diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py index ce0a17e2e..4f7554dc7 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py @@ -35,7 +35,7 @@ def __init__( Filters are applied during the hybrid search to ensure the Retriever returns `top_k` matching documents. :param top_k: Maximum number of documents to return. - :filter_policy: Policy to determine how filters are applied. + :param filter_policy: Policy to determine how filters are applied. :param kwargs: Additional keyword arguments to pass to the Azure AI's search endpoint. Some of the supported parameters: - `query_type`: A string indicating the type of query to perform. Possible values are @@ -44,8 +44,9 @@ def __init__( processing semantic queries. For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). - - + :raises TypeError: If the document store is not an instance of AzureAISearchDocumentStore. + :raises RuntimeError: If query or query_embedding are invalid, or if document store is not correctly configured. + """ self._filters = filters or {} self._top_k = top_k @@ -57,7 +58,7 @@ def __init__( if not isinstance(document_store, AzureAISearchDocumentStore): message = "document_store must be an instance of AzureAISearchDocumentStore" - raise Exception(message) + raise TypeError(message) def to_dict(self) -> Dict[str, Any]: """ diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index cf0657495..9580c794d 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -240,6 +240,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D Writes the provided documents to search index. :param documents: documents to write to the index. + :param policy: Policy to determine how duplicates are handled. + :raises ValueError: If the documents are not of type Document. + :raises TypeError: If the document ids are not strings. :return: the number of documents added to index. """ @@ -247,7 +250,7 @@ def _convert_input_document(documents: Document): document_dict = asdict(documents) if not isinstance(document_dict["id"], str): msg = f"Document id {document_dict['id']} is not a string, " - raise Exception(msg) + raise TypeError(msg) index_document = self._convert_haystack_documents_to_azure(document_dict) return index_document From 284759071512a855e6776484d100a848efab1d00 Mon Sep 17 00:00:00 2001 From: Trivan Menezes <47679108+ttmenezes@users.noreply.github.com> Date: Mon, 18 Nov 2024 09:17:57 -0800 Subject: [PATCH 10/12] add error descriptions to function comments --- .../components/retrievers/azure_ai_search/bm25_retriever.py | 1 + .../components/retrievers/azure_ai_search/hybrid_retriever.py | 1 + 2 files changed, 2 insertions(+) diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py index dfbee32f6..839fb6cf2 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py @@ -105,6 +105,7 @@ def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optio the `filter_policy` chosen at retriever initialization. See init method docstring for more details. :param top_k: the maximum number of documents to retrieve. + :raises RuntimeError: If an error occurs during the BM25 retrieval process. :returns: a dictionary with the following keys: - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. """ diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py index 4f7554dc7..36c70a9af 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py @@ -113,6 +113,7 @@ def run( the `filter_policy` chosen at retriever initialization. See `__init__` method docstring for more details. :param top_k: The maximum number of documents to retrieve. + :raises RuntimeError: If an error occurs during the hybrid retrieval process. :returns: A dictionary with the following keys: - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. """ From 5db804dec401417531ade28e3345c3c3416e9533 Mon Sep 17 00:00:00 2001 From: Trivan Menezes <47679108+ttmenezes@users.noreply.github.com> Date: Mon, 18 Nov 2024 09:29:57 -0800 Subject: [PATCH 11/12] address style comments --- .../components/retrievers/azure_ai_search/hybrid_retriever.py | 1 - 1 file changed, 1 deletion(-) diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py index 36c70a9af..fc8983754 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py @@ -46,7 +46,6 @@ def __init__( [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). :raises TypeError: If the document store is not an instance of AzureAISearchDocumentStore. :raises RuntimeError: If query or query_embedding are invalid, or if document store is not correctly configured. - """ self._filters = filters or {} self._top_k = top_k From 458e0d57b62c330a64f2637f1e2039fea571db47 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 20 Nov 2024 12:37:52 +0100 Subject: [PATCH 12/12] Added mock tests to the retrievers --- .../azure_ai_search/bm25_retriever.py | 3 +- .../azure_ai_search/embedding_retriever.py | 3 +- .../azure_ai_search/hybrid_retriever.py | 3 +- .../azure_ai_search/document_store.py | 1 + .../tests/test_bm25_retriever.py | 50 ++++++++++++++ .../tests/test_embedding_retriever.py | 60 ++++++++++++++++ .../tests/test_hybrid_retriever.py | 68 ++++++++++++++++++- 7 files changed, 184 insertions(+), 4 deletions(-) diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py index 839fb6cf2..4a1c7f98c 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py @@ -111,7 +111,8 @@ def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optio """ top_k = top_k or self._top_k - if filters is not None: + filters = filters or self._filters + if filters: applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) normalized_filters = _normalize_filters(applied_filters) else: diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py index af48b74fb..69fad7208 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py @@ -107,7 +107,8 @@ def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = """ top_k = top_k or self._top_k - if filters is not None: + filters = filters or self._filters + if filters: applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) normalized_filters = _normalize_filters(applied_filters) else: diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py index fc8983754..79282933f 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py @@ -118,7 +118,8 @@ def run( """ top_k = top_k or self._top_k - if filters is not None: + filters = filters or self._filters + if filters: applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) normalized_filters = _normalize_filters(applied_filters) else: diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 9580c794d..137ff621c 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -498,6 +498,7 @@ def _hybrid_retrieval( `AzureAISearchHybridRetriever` uses this method directly and is the public interface for it. :param query: Text of the query. + :param query_embedding: Embedding of the query. :param filters: Filters applied to the retrieved Documents. :param top_k: Maximum number of Documents to return. :param kwargs: Optional keyword arguments to pass to the Azure AI's search endpoint. diff --git a/integrations/azure_ai_search/tests/test_bm25_retriever.py b/integrations/azure_ai_search/tests/test_bm25_retriever.py index e6631a16b..6ebb20949 100644 --- a/integrations/azure_ai_search/tests/test_bm25_retriever.py +++ b/integrations/azure_ai_search/tests/test_bm25_retriever.py @@ -98,6 +98,56 @@ def test_from_dict(): assert retriever._filter_policy == FilterPolicy.REPLACE +def test_run(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._bm25_retrieval.return_value = [Document(content="Test doc")] + retriever = AzureAISearchBM25Retriever(document_store=mock_store) + res = retriever.run(query="Test query") + mock_store._bm25_retrieval.assert_called_once_with( + query="Test query", + filters="", + top_k=10, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + + +def test_run_init_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._bm25_retrieval.return_value = [Document(content="Test doc")] + retriever = AzureAISearchBM25Retriever( + document_store=mock_store, filters={"field": "type", "operator": "==", "value": "article"}, top_k=11 + ) + res = retriever.run(query="Test query") + mock_store._bm25_retrieval.assert_called_once_with( + query="Test query", + filters="type eq 'article'", + top_k=11, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + + +def test_run_time_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._bm25_retrieval.return_value = [Document(content="Test doc")] + retriever = AzureAISearchBM25Retriever( + document_store=mock_store, + filters={"field": "type", "operator": "==", "value": "article"}, + top_k=11, + select="name", + ) + res = retriever.run(query="Test query", filters={"field": "type", "operator": "==", "value": "book"}, top_k=5) + mock_store._bm25_retrieval.assert_called_once_with( + query="Test query", filters="type eq 'book'", top_k=5, select="name" + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + + @pytest.mark.skipif( not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", diff --git a/integrations/azure_ai_search/tests/test_embedding_retriever.py b/integrations/azure_ai_search/tests/test_embedding_retriever.py index d4615ec44..576ecda08 100644 --- a/integrations/azure_ai_search/tests/test_embedding_retriever.py +++ b/integrations/azure_ai_search/tests/test_embedding_retriever.py @@ -103,6 +103,66 @@ def test_from_dict(): assert retriever._filter_policy == FilterPolicy.REPLACE +def test_run(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchEmbeddingRetriever(document_store=mock_store) + res = retriever.run(query_embedding=[0.5, 0.7]) + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.5, 0.7], + filters="", + top_k=10, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_init_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchEmbeddingRetriever( + document_store=mock_store, + filters={"field": "type", "operator": "==", "value": "article"}, + top_k=11, + ) + res = retriever.run(query_embedding=[0.5, 0.7]) + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.5, 0.7], + filters="type eq 'article'", + top_k=11, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_time_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchEmbeddingRetriever( + document_store=mock_store, + filters={"field": "type", "operator": "==", "value": "article"}, + top_k=11, + select="name", + ) + res = retriever.run( + query_embedding=[0.5, 0.7], filters={"field": "type", "operator": "==", "value": "book"}, top_k=9 + ) + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.5, 0.7], + filters="type eq 'book'", + top_k=9, + select="name", + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + @pytest.mark.skipif( not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", diff --git a/integrations/azure_ai_search/tests/test_hybrid_retriever.py b/integrations/azure_ai_search/tests/test_hybrid_retriever.py index 2447949fd..bf305c4fe 100644 --- a/integrations/azure_ai_search/tests/test_hybrid_retriever.py +++ b/integrations/azure_ai_search/tests/test_hybrid_retriever.py @@ -103,6 +103,72 @@ def test_from_dict(): assert retriever._filter_policy == FilterPolicy.REPLACE +def test_run(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._hybrid_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchHybridRetriever(document_store=mock_store) + res = retriever.run(query_embedding=[0.5, 0.7], query="Test query") + mock_store._hybrid_retrieval.assert_called_once_with( + query="Test query", + query_embedding=[0.5, 0.7], + filters="", + top_k=10, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_init_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._hybrid_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchHybridRetriever( + document_store=mock_store, + filters={"field": "type", "operator": "==", "value": "article"}, + top_k=11, + ) + res = retriever.run(query_embedding=[0.5, 0.7], query="Test query") + mock_store._hybrid_retrieval.assert_called_once_with( + query="Test query", + query_embedding=[0.5, 0.7], + filters="type eq 'article'", + top_k=11, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_time_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._hybrid_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchHybridRetriever( + document_store=mock_store, + filters={"field": "type", "operator": "==", "value": "article"}, + top_k=11, + select="name", + ) + res = retriever.run( + query_embedding=[0.5, 0.7], + query="Test query", + filters={"field": "type", "operator": "==", "value": "book"}, + top_k=9, + ) + mock_store._hybrid_retrieval.assert_called_once_with( + query="Test query", + query_embedding=[0.5, 0.7], + filters="type eq 'book'", + top_k=9, + select="name", + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + @pytest.mark.skipif( not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", @@ -117,7 +183,7 @@ def test_run(self, document_store: AzureAISearchDocumentStore): res = retriever.run(query="Test document", query_embedding=[0.1] * 768) assert res["documents"] == docs - def test_embedding_retrieval(self, document_store: AzureAISearchDocumentStore): + def test_hybrid_retrieval(self, document_store: AzureAISearchDocumentStore): query_embedding = [0.1] * 768 most_similar_embedding = [0.8] * 768 second_best_embedding = [0.8] * 200 + [0.1] * 300 + [0.2] * 268