Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add BM25 and Hybrid Search Retrievers to Azure AI Search Integration #1175

Merged
merged 16 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .bm25_retriever import AzureAISearchBM25Retriever
from .embedding_retriever import AzureAISearchEmbeddingRetriever
from .hybrid_retriever import AzureAISearchHybridRetriever

__all__ = ["AzureAISearchEmbeddingRetriever"]
__all__ = ["AzureAISearchBM25Retriever", "AzureAISearchEmbeddingRetriever", "AzureAISearchHybridRetriever"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import logging
from typing import Any, Dict, List, Optional, Union

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.document_stores.types import FilterPolicy
from haystack.document_stores.types.filter_policy import apply_filter_policy

from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters

logger = logging.getLogger(__name__)


@component
class AzureAISearchBM25Retriever:
"""
Retrieves documents from the AzureAISearchDocumentStore using BM25 retrieval.
Must be connected to the AzureAISearchDocumentStore to run.

"""

def __init__(
self,
*,
document_store: AzureAISearchDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
Amnah199 marked this conversation as resolved.
Show resolved Hide resolved
**kwargs,
):
"""
Create the AzureAISearchBM25Retriever component.

:param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever.
:param filters: Filters applied when fetching documents from the Document Store.
Filters are applied during the BM25 search to ensure the Retriever returns
`top_k` matching documents.
:param top_k: Maximum number of documents to return.
:param filter_policy: Policy to determine how filters are applied.
:param kwargs: Additional keyword arguments to pass to the Azure AI's search endpoint.
Some of the supported parameters:
- `query_type`: A string indicating the type of query to perform. Possible values are
'simple','full' and 'semantic'.
- `semantic_configuration_name`: The name of semantic configuration to be used when
processing semantic queries.
For more information on parameters, see the
[official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/).
ttmenezes marked this conversation as resolved.
Show resolved Hide resolved
:raises TypeError: If the document store is not an instance of AzureAISearchDocumentStore.
:raises RuntimeError: If the query is not valid, or if the document store is not correctly configured.

"""
self._filters = filters or {}
self._top_k = top_k
self._document_store = document_store
self._filter_policy = (
filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
)
self._kwargs = kwargs
if not isinstance(document_store, AzureAISearchDocumentStore):
message = "document_store must be an instance of AzureAISearchDocumentStore"
raise TypeError(message)

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.

:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
filters=self._filters,
top_k=self._top_k,
document_store=self._document_store.to_dict(),
filter_policy=self._filter_policy.value,
**self._kwargs,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchBM25Retriever":
"""
Deserializes the component from a dictionary.

:param data:
Dictionary to deserialize from.

:returns:
Deserialized component.
"""
data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)

# Pipelines serialized with old versions of the component might not
# have the filter_policy field.
if "filter_policy" in data["init_parameters"]:
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
"""Retrieve documents from the AzureAISearchDocumentStore.

:param query: Text of the query.
:param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on
the `filter_policy` chosen at retriever initialization. See init method docstring for more
details.
:param top_k: the maximum number of documents to retrieve.
:raises RuntimeError: If an error occurs during the BM25 retrieval process.
:returns: a dictionary with the following keys:
ttmenezes marked this conversation as resolved.
Show resolved Hide resolved
- `documents`: A list of documents retrieved from the AzureAISearchDocumentStore.
"""

top_k = top_k or self._top_k
filters = filters or self._filters
if filters:
applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters)
normalized_filters = _normalize_filters(applied_filters)
else:
normalized_filters = ""

try:
docs = self._document_store._bm25_retrieval(
query=query,
filters=normalized_filters,
top_k=top_k,
**self._kwargs,
)
except Exception as e:
msg = (
"An error occurred during the bm25 retrieval process from the AzureAISearchDocumentStore. "
"Ensure that the query is valid and the document store is correctly configured."
)
raise RuntimeError(msg) from e

return {"documents": docs}
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] =
"""

top_k = top_k or self._top_k
if filters is not None:
filters = filters or self._filters
if filters:
applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters)
normalized_filters = _normalize_filters(applied_filters)
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import logging
from typing import Any, Dict, List, Optional, Union

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.document_stores.types import FilterPolicy
from haystack.document_stores.types.filter_policy import apply_filter_policy

from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters

logger = logging.getLogger(__name__)


@component
class AzureAISearchHybridRetriever:
"""
Retrieves documents from the AzureAISearchDocumentStore using a hybrid (vector + BM25) retrieval.
Must be connected to the AzureAISearchDocumentStore to run.

"""

def __init__(
self,
*,
document_store: AzureAISearchDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
**kwargs,
):
"""
Create the AzureAISearchHybridRetriever component.

:param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever.
:param filters: Filters applied when fetching documents from the Document Store.
Filters are applied during the hybrid search to ensure the Retriever returns
`top_k` matching documents.
:param top_k: Maximum number of documents to return.
:param filter_policy: Policy to determine how filters are applied.
:param kwargs: Additional keyword arguments to pass to the Azure AI's search endpoint.
Some of the supported parameters:
- `query_type`: A string indicating the type of query to perform. Possible values are
'simple','full' and 'semantic'.
- `semantic_configuration_name`: The name of semantic configuration to be used when
processing semantic queries.
For more information on parameters, see the
[official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/).
ttmenezes marked this conversation as resolved.
Show resolved Hide resolved
:raises TypeError: If the document store is not an instance of AzureAISearchDocumentStore.
:raises RuntimeError: If query or query_embedding are invalid, or if document store is not correctly configured.
"""
self._filters = filters or {}
self._top_k = top_k
self._document_store = document_store
self._filter_policy = (
filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
)
self._kwargs = kwargs

if not isinstance(document_store, AzureAISearchDocumentStore):
message = "document_store must be an instance of AzureAISearchDocumentStore"
raise TypeError(message)

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.

:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
filters=self._filters,
top_k=self._top_k,
document_store=self._document_store.to_dict(),
filter_policy=self._filter_policy.value,
**self._kwargs,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchHybridRetriever":
"""
Deserializes the component from a dictionary.

:param data:
Dictionary to deserialize from.

:returns:
Deserialized component.
"""
data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)

# Pipelines serialized with old versions of the component might not
# have the filter_policy field.
if "filter_policy" in data["init_parameters"]:
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(
self,
query: str,
query_embedding: List[float],
filters: Optional[Dict[str, Any]] = None,
top_k: Optional[int] = None,
):
"""Retrieve documents from the AzureAISearchDocumentStore.

:param query: Text of the query.
:param query_embedding: A list of floats representing the query embedding
:param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on
the `filter_policy` chosen at retriever initialization. See `__init__` method docstring for more
details.
:param top_k: The maximum number of documents to retrieve.
ttmenezes marked this conversation as resolved.
Show resolved Hide resolved
:raises RuntimeError: If an error occurs during the hybrid retrieval process.
:returns: A dictionary with the following keys:
- `documents`: A list of documents retrieved from the AzureAISearchDocumentStore.
"""

top_k = top_k or self._top_k
filters = filters or self._filters
if filters:
applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters)
normalized_filters = _normalize_filters(applied_filters)
else:
normalized_filters = ""

try:
docs = self._document_store._hybrid_retrieval(
query=query, query_embedding=query_embedding, filters=normalized_filters, top_k=top_k, **self._kwargs
)
except Exception as e:
msg = (
"An error occurred during the hybrid retrieval process from the AzureAISearchDocumentStore. "
"Ensure that the query and query_embedding are valid and the document store is correctly configured."
)
raise RuntimeError(msg) from e

return {"documents": docs}
Loading