Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add BM25 and Hybrid Search Retrievers to Azure AI Search Integration #1175

Merged
merged 16 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .bm25_retriever import AzureAISearchBM25Retriever
from .embedding_retriever import AzureAISearchEmbeddingRetriever
from .hybrid_retriever import AzureAISearchHybridRetriever

# Public API of the retrievers package: one retriever per supported search mode.
__all__ = ["AzureAISearchBM25Retriever", "AzureAISearchEmbeddingRetriever", "AzureAISearchHybridRetriever"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import logging
from typing import Any, Dict, List, Optional, Union

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.document_stores.types import FilterPolicy
from haystack.document_stores.types.filter_policy import apply_filter_policy

from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters

logger = logging.getLogger(__name__)


@component
class AzureAISearchBM25Retriever:
    """
    Retrieves documents from the AzureAISearchDocumentStore using BM25 retrieval.

    Must be connected to the AzureAISearchDocumentStore to run.
    """

    def __init__(
        self,
        *,
        document_store: AzureAISearchDocumentStore,
        filters: Optional[Dict[str, Any]] = None,
        top_k: int = 10,
        filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
        **kwargs,
    ):
        """
        Create the AzureAISearchBM25Retriever component.

        :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever.
        :param filters: Filters applied when fetching documents from the Document Store.
            Filters are applied during the BM25 search to ensure the Retriever returns
            `top_k` matching documents.
        :param top_k: Maximum number of documents to return.
        :param filter_policy: Policy to determine how filters are applied, i.e. how the filters passed
            to `run` are combined with the `filters` given here. May be a `FilterPolicy` member or its
            string form.
        :param kwargs: Additional keyword arguments to pass to the Azure AI's search endpoint.
            Some of the supported parameters:
                - `query_type`: A string indicating the type of query to perform. Possible values are
                  'simple','full' and 'semantic'.
                - `semantic_configuration_name`: The name of semantic configuration to be used when
                  processing semantic queries.
            For more information on parameters, see the
            [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/).
        :raises TypeError: If `document_store` is not an instance of AzureAISearchDocumentStore.
        """
        # Validate first so a rejected instance is never left half-initialised.
        # TypeError is a subclass of Exception, so existing `except Exception` callers keep working.
        if not isinstance(document_store, AzureAISearchDocumentStore):
            message = "document_store must be an instance of AzureAISearchDocumentStore"
            raise TypeError(message)

        self._filters = filters or {}
        self._top_k = top_k
        self._document_store = document_store
        self._filter_policy = (
            filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
        )
        self._kwargs = kwargs

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            filters=self._filters,
            top_k=self._top_k,
            document_store=self._document_store.to_dict(),
            filter_policy=self._filter_policy.value,
            # Extra search kwargs are stored flat so they are fed back to `__init__` as `**kwargs`.
            **self._kwargs,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchBM25Retriever":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from. Its `init_parameters` entry is updated in place.

        :returns:
            Deserialized component.
        """
        data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict(
            data["init_parameters"]["document_store"]
        )

        # Pipelines serialized with old versions of the component might not
        # have the filter_policy field.
        if "filter_policy" in data["init_parameters"]:
            data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
        return default_from_dict(cls, data)

    @component.output_types(documents=List[Document])
    def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
        """Retrieve documents from the AzureAISearchDocumentStore.

        :param query: Text of the query.
        :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on
            the `filter_policy` chosen at retriever initialization. When omitted, the filters given at
            initialization are used.
        :param top_k: The maximum number of documents to retrieve. Falls back to the value given at
            initialization when omitted.
        :returns: A dictionary with the following keys:
            - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore.
        :raises RuntimeError: If the underlying BM25 retrieval fails for any reason.
        """
        top_k = top_k or self._top_k

        # Always resolve through the filter policy so that init-time filters still apply
        # when the caller passes no runtime filters.
        applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters)
        # An empty result means "no filtering"; only non-empty filters are normalized.
        normalized_filters = _normalize_filters(applied_filters) if applied_filters else ""

        try:
            docs = self._document_store._bm25_retrieval(
                query=query,
                filters=normalized_filters,
                top_k=top_k,
                **self._kwargs,
            )
        except Exception as e:
            msg = (
                "An error occurred during the bm25 retrieval process from the AzureAISearchDocumentStore. "
                "Ensure that the query is valid and the document store is correctly configured."
            )
            raise RuntimeError(msg) from e

        return {"documents": docs}
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import logging
from typing import Any, Dict, List, Optional, Union

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.document_stores.types import FilterPolicy
from haystack.document_stores.types.filter_policy import apply_filter_policy

from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters

logger = logging.getLogger(__name__)


@component
class AzureAISearchHybridRetriever:
    """
    Retrieves documents from the AzureAISearchDocumentStore using a hybrid (vector + BM25) retrieval.

    Must be connected to the AzureAISearchDocumentStore to run.
    """

    def __init__(
        self,
        *,
        document_store: AzureAISearchDocumentStore,
        filters: Optional[Dict[str, Any]] = None,
        top_k: int = 10,
        filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
        **kwargs,
    ):
        """
        Create the AzureAISearchHybridRetriever component.

        :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever.
        :param filters: Filters applied when fetching documents from the Document Store.
            Filters are applied during the hybrid search to ensure the Retriever returns
            `top_k` matching documents.
        :param top_k: Maximum number of documents to return.
        :param filter_policy: Policy to determine how filters are applied, i.e. how the filters passed
            to `run` are combined with the `filters` given here. May be a `FilterPolicy` member or its
            string form.
        :param kwargs: Additional keyword arguments to pass to the Azure AI's search endpoint.
            Some of the supported parameters:
                - `query_type`: A string indicating the type of query to perform. Possible values are
                  'simple','full' and 'semantic'.
                - `semantic_configuration_name`: The name of semantic configuration to be used when
                  processing semantic queries.
            For more information on parameters, see the
            [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/).
        :raises TypeError: If `document_store` is not an instance of AzureAISearchDocumentStore.
        """
        # Validate first so a rejected instance is never left half-initialised.
        # TypeError is a subclass of Exception, so existing `except Exception` callers keep working.
        if not isinstance(document_store, AzureAISearchDocumentStore):
            message = "document_store must be an instance of AzureAISearchDocumentStore"
            raise TypeError(message)

        self._filters = filters or {}
        self._top_k = top_k
        self._document_store = document_store
        self._filter_policy = (
            filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
        )
        self._kwargs = kwargs

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            filters=self._filters,
            top_k=self._top_k,
            document_store=self._document_store.to_dict(),
            filter_policy=self._filter_policy.value,
            # Extra search kwargs are stored flat so they are fed back to `__init__` as `**kwargs`.
            **self._kwargs,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchHybridRetriever":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from. Its `init_parameters` entry is updated in place.

        :returns:
            Deserialized component.
        """
        data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict(
            data["init_parameters"]["document_store"]
        )

        # Pipelines serialized with old versions of the component might not
        # have the filter_policy field.
        if "filter_policy" in data["init_parameters"]:
            data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
        return default_from_dict(cls, data)

    @component.output_types(documents=List[Document])
    def run(
        self,
        query: str,
        query_embedding: List[float],
        filters: Optional[Dict[str, Any]] = None,
        top_k: Optional[int] = None,
    ):
        """Retrieve documents from the AzureAISearchDocumentStore.

        :param query: Text of the query.
        :param query_embedding: A list of floats representing the query embedding.
        :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on
            the `filter_policy` chosen at retriever initialization. When omitted, the filters given at
            initialization are used.
        :param top_k: The maximum number of documents to retrieve. Falls back to the value given at
            initialization when omitted.
        :returns: A dictionary with the following keys:
            - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore.
        :raises RuntimeError: If the underlying hybrid retrieval fails for any reason.
        """
        top_k = top_k or self._top_k

        # Always resolve through the filter policy so that init-time filters still apply
        # when the caller passes no runtime filters.
        applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters)
        # An empty result means "no filtering"; only non-empty filters are normalized.
        normalized_filters = _normalize_filters(applied_filters) if applied_filters else ""

        try:
            docs = self._document_store._hybrid_retrieval(
                query=query,
                query_embedding=query_embedding,
                filters=normalized_filters,
                top_k=top_k,
                **self._kwargs,
            )
        except Exception as e:
            msg = (
                "An error occurred during the hybrid retrieval process from the AzureAISearchDocumentStore. "
                "Ensure that the query and query_embedding are valid and the document store is correctly configured."
            )
            raise RuntimeError(msg) from e

        return {"documents": docs}
Original file line number Diff line number Diff line change
Expand Up @@ -421,21 +421,20 @@ def _embedding_retrieval(
) -> List[Document]:
"""
Retrieves documents that are most similar to the query embedding using a vector similarity metric.
It uses the vector configuration of the document store. By default it uses the HNSW algorithm
It uses the vector configuration specified in the document store. By default, it uses the HNSW algorithm
with cosine similarity.

This method is not meant to be part of the public interface of
`AzureAISearchDocumentStore` nor called directly.
`AzureAISearchEmbeddingRetriever` uses this method directly and is the public interface for it.

:param query_embedding: Embedding of the query.
:param top_k: Maximum number of Documents to return, defaults to 10.
:param filters: Filters applied to the retrieved Documents. Defaults to None.
Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned.
:param top_k: Maximum number of Documents to return.
:param filters: Filters applied to the retrieved Documents.
:param kwargs: Optional keyword arguments to pass to the Azure AI's search endpoint.

:raises ValueError: If `query_embedding` is an empty list
:returns: List of Document that are most similar to `query_embedding`
:raises ValueError: If `query_embedding` is an empty list.
:returns: List of Document that are most similar to `query_embedding`.
ttmenezes marked this conversation as resolved.
Show resolved Hide resolved
"""

if not query_embedding:
Expand All @@ -446,3 +445,79 @@ def _embedding_retrieval(
result = self.client.search(vector_queries=[vector_query], filter=filters, **kwargs)
azure_docs = list(result)
return self._convert_search_result_to_documents(azure_docs)

def _bm25_retrieval(
    self,
    query: str,
    top_k: int = 10,
    filters: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> List[Document]:
    """
    Retrieves documents that are most similar to `query`, using the BM25 algorithm.

    This method is not meant to be part of the public interface of
    `AzureAISearchDocumentStore` nor called directly.
    `AzureAISearchBM25Retriever` uses this method directly and is the public interface for it.

    :param query: Text of the query.
    :param top_k: Maximum number of Documents to return.
    :param filters: Filters applied to the retrieved Documents. The value is forwarded unchanged as the
        `filter` argument of the search call.
        NOTE(review): `AzureAISearchBM25Retriever.run` passes an already-normalized filter string here,
        which does not match the `Dict` annotation - confirm the intended type.
    :param kwargs: Optional keyword arguments to pass to the Azure AI's search endpoint.
        `query_type` defaults to "simple" and may be overridden here.

    :raises ValueError: If `query` is None.
    :returns: List of Document that are most similar to `query`.
    """
    if query is None:
        msg = "query must not be None"
        raise ValueError(msg)

    # Merge the default into kwargs instead of passing `query_type` twice: a caller-supplied
    # `query_type` would otherwise raise "got multiple values for keyword argument".
    search_kwargs = {"query_type": "simple", **kwargs}
    result = self.client.search(search_text=query, filter=filters, top=top_k, **search_kwargs)
    azure_docs = list(result)
    return self._convert_search_result_to_documents(azure_docs)

def _hybrid_retrieval(
    self,
    query: str,
    query_embedding: List[float],
    top_k: int = 10,
    filters: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> List[Document]:
    """
    Retrieves documents similar to query using the vector configuration in the document store and
    the BM25 algorithm. This method combines vector similarity and BM25 for improved retrieval.

    This method is not meant to be part of the public interface of
    `AzureAISearchDocumentStore` nor called directly.
    `AzureAISearchHybridRetriever` uses this method directly and is the public interface for it.

    :param query: Text of the query.
    :param query_embedding: Embedding of the query, searched against the "embedding" field.
    :param top_k: Maximum number of Documents to return. Also used as the number of nearest
        neighbors requested from the vector query.
    :param filters: Filters applied to the retrieved Documents. The value is forwarded unchanged as the
        `filter` argument of the search call.
        NOTE(review): `AzureAISearchHybridRetriever.run` passes an already-normalized filter string here,
        which does not match the `Dict` annotation - confirm the intended type.
    :param kwargs: Optional keyword arguments to pass to the Azure AI's search endpoint.
        `query_type` defaults to "simple" and may be overridden here.

    :raises ValueError: If `query` is None or `query_embedding` is empty.
    :returns: List of Document that are most similar to `query`.
    """
    if query is None:
        msg = "query must not be None"
        raise ValueError(msg)
    if not query_embedding:
        msg = "query_embedding must be a non-empty list of floats"
        raise ValueError(msg)

    vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding")
    # Merge the default into kwargs instead of passing `query_type` twice: a caller-supplied
    # `query_type` would otherwise raise "got multiple values for keyword argument".
    search_kwargs = {"query_type": "simple", **kwargs}
    result = self.client.search(
        search_text=query,
        vector_queries=[vector_query],
        filter=filters,
        top=top_k,
        **search_kwargs,
    )
    azure_docs = list(result)
    return self._convert_search_result_to_documents(azure_docs)
Loading