From 252d27df70f90c933fd9a870b5293da04370f29d Mon Sep 17 00:00:00 2001 From: Trivan Menezes <47679108+ttmenezes@users.noreply.github.com> Date: Mon, 11 Nov 2024 15:16:07 -0800 Subject: [PATCH] new azure retrievers --- .../retrievers/azure_ai_search/__init__.py | 4 +- .../azure_ai_search/bm25_retriever.py | 116 +++++++++++++++++ .../azure_ai_search/hybrid_retriever.py | 120 ++++++++++++++++++ .../azure_ai_search/document_store.py | 52 ++++++++ 4 files changed, 291 insertions(+), 1 deletion(-) create mode 100644 integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py create mode 100644 integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py index eb75ffa6c..eebe990f3 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py @@ -1,3 +1,5 @@ from .embedding_retriever import AzureAISearchEmbeddingRetriever +from .bm25_retriever import AzureAISearchBM25Retriever +from .hybrid_retriever import AzureAISearchHybridRetriever -__all__ = ["AzureAISearchEmbeddingRetriever"] +__all__ = ["AzureAISearchEmbeddingRetriever", "AzureAISearchBM25Retriever", "AzureAISearchHybridRetriever"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py new file mode 100644 index 000000000..65e273b73 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py @@ -0,0 +1,116 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters + +logger = logging.getLogger(__name__) + + +@component +class AzureAISearchBM25Retriever: + """ + Retrieves documents from the AzureAISearchDocumentStore using BM25 retrieval. + Must be connected to the AzureAISearchDocumentStore to run. + + """ + + def __init__( + self, + *, + document_store: AzureAISearchDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + ): + """ + Create the AzureAISearchBM25Retriever component. + + :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. + :param filters: Filters applied when fetching documents from the Document Store. + Filters are applied during the BM25 search to ensure the Retriever returns + `top_k` matching documents. + :param top_k: Maximum number of documents to return. + :filter_policy: Policy to determine how filters are applied. Possible options: + + """ + self._filters = filters or {} + self._top_k = top_k + self._document_store = document_store + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + + if not isinstance(document_store, AzureAISearchDocumentStore): + message = "document_store must be an instance of AzureAISearchDocumentStore" + raise Exception(message) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self._filters, + top_k=self._top_k, + document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchBM25Retriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. + """ + data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if "filter_policy" in data["init_parameters"]: + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): + """Retrieve documents from the AzureAISearchDocumentStore. + + :param query: Text of the query. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. + :param top_k: the maximum number of documents to retrieve. + :returns: a dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. + """ + + top_k = top_k or self._top_k + if filters is not None: + applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) + normalized_filters = normalize_filters(applied_filters) + else: + normalized_filters = "" + + try: + docs = self._document_store._bm25_retrieval( + query=query, + filters=normalized_filters, + top_k=top_k, + ) + except Exception as e: + raise e + + return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py new file mode 100644 index 000000000..eb28c4e73 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py @@ -0,0 +1,120 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters + +logger = logging.getLogger(__name__) + + +@component +class AzureAISearchHybridRetriever: + """ + Retrieves documents from the AzureAISearchDocumentStore using a hybrid (vector + BM25) retrieval. + Must be connected to the AzureAISearchDocumentStore to run. + + """ + + def __init__( + self, + *, + document_store: AzureAISearchDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + ): + """ + Create the AzureAISearchHybridRetriever component. + + :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. + :param filters: Filters applied when fetching documents from the Document Store. + Filters are applied during the hybrid search to ensure the Retriever returns + `top_k` matching documents. + :param top_k: Maximum number of documents to return. + :filter_policy: Policy to determine how filters are applied. Possible options: + + """ + self._filters = filters or {} + self._top_k = top_k + self._document_store = document_store + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + + if not isinstance(document_store, AzureAISearchDocumentStore): + message = "document_store must be an instance of AzureAISearchDocumentStore" + raise Exception(message) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self._filters, + top_k=self._top_k, + document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchHybridRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. + """ + data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if "filter_policy" in data["init_parameters"]: + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, query: str, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None + ): + """Retrieve documents from the AzureAISearchDocumentStore. + + :param query: Text of the query. + :param query_embedding: floats representing the query embedding + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. + :param top_k: the maximum number of documents to retrieve. + :returns: a dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. + """ + + top_k = top_k or self._top_k + if filters is not None: + applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) + normalized_filters = normalize_filters(applied_filters) + else: + normalized_filters = "" + + try: + docs = self._document_store._hybrid_retrieval( + query=query, + query_embedding=query_embedding, + filters=normalized_filters, + top_k=top_k, + ) + except Exception as e: + raise e + + return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 0b59b6e37..4e3b3b44a 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -438,3 +438,55 @@ def _embedding_retrieval( result = self.client.search(search_text=None, vector_queries=[vector_query], select=fields, filter=filters) azure_docs = list(result) return self._convert_search_result_to_documents(azure_docs) + + def _bm25_retrieval( + self, + query: str, + top_k: int = 10, + fields: Optional[List[str]] = None, + filters: Optional[Dict[str, Any]] = None, + ) -> List[Document]: + """ + Retrieves documents that are most similar to `query`, using the BM25 algorithm + + This method is not meant to be part of the public interface of + `AzureAISearchDocumentStore` nor called directly. + `AzureAISearchBM25Retriever` uses this method directly and is the public interface for it. + + :param query: Text of the query. + :param filters: Filters applied to the retrieved Documents. Defaults to None. + Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. + :param top_k: Maximum number of Documents to return, defaults to 10 + + :raises ValueError: If `query` is an empty string + :returns: List of Document that are most similar to `query` + """ + + if query is None: + msg = "query must not be None" + raise ValueError(msg) + + result = self.client.search(search_text=query, select=fields, filter=filters, top=top_k, query_type="simple") + azure_docs = list(result) + return self._convert_search_result_to_documents(azure_docs) + + def _hybrid_retrieval( + self, + query: str, + query_embedding: List[float], + top_k: int = 10, + fields: Optional[List[str]] = None, + filters: Optional[Dict[str, Any]] = None, + ) -> List[Document]: + + if query is None: + msg = "query must not be None" + raise ValueError(msg) + if not query_embedding: + msg = "query_embedding must be a non-empty list of floats" + raise ValueError(msg) + + vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding") + result = self.client.search(search_text=query, vector_queries=[vector_query], select=fields, filter=filters, top=top_k, query_type="simple") + azure_docs = list(result) + return self._convert_search_result_to_documents(azure_docs) \ No newline at end of file