Skip to content

Commit

Permalink
new azure retrievers
Browse files Browse the repository at this point in the history
  • Loading branch information
ttmenezes committed Nov 11, 2024
1 parent 4cfee2d commit 252d27d
Show file tree
Hide file tree
Showing 4 changed files with 291 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .embedding_retriever import AzureAISearchEmbeddingRetriever
from .bm25_retriever import AzureAISearchBM25Retriever
from .hybrid_retriever import AzureAISearchHybridRetriever

__all__ = ["AzureAISearchEmbeddingRetriever"]
__all__ = ["AzureAISearchEmbeddingRetriever", "AzureAISearchBM25Retriever", "AzureAISearchHybridRetriever"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import logging
from typing import Any, Dict, List, Optional, Union

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.document_stores.types import FilterPolicy
from haystack.document_stores.types.filter_policy import apply_filter_policy

from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters

logger = logging.getLogger(__name__)


@component
class AzureAISearchBM25Retriever:
"""
Retrieves documents from the AzureAISearchDocumentStore using BM25 retrieval.
Must be connected to the AzureAISearchDocumentStore to run.
"""

def __init__(
self,
*,
document_store: AzureAISearchDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
):
"""
Create the AzureAISearchBM25Retriever component.
:param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever.
:param filters: Filters applied when fetching documents from the Document Store.
Filters are applied during the BM25 search to ensure the Retriever returns
`top_k` matching documents.
:param top_k: Maximum number of documents to return.
:filter_policy: Policy to determine how filters are applied. Possible options:
"""
self._filters = filters or {}
self._top_k = top_k
self._document_store = document_store
self._filter_policy = (
filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
)

if not isinstance(document_store, AzureAISearchDocumentStore):
message = "document_store must be an instance of AzureAISearchDocumentStore"
raise Exception(message)

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
filters=self._filters,
top_k=self._top_k,
document_store=self._document_store.to_dict(),
filter_policy=self._filter_policy.value,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchBM25Retriever":
"""
Deserializes the component from a dictionary.
:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)

# Pipelines serialized with old versions of the component might not
# have the filter_policy field.
if "filter_policy" in data["init_parameters"]:
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
"""Retrieve documents from the AzureAISearchDocumentStore.
:param query: Text of the query.
:param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on
the `filter_policy` chosen at retriever initialization. See init method docstring for more
details.
:param top_k: the maximum number of documents to retrieve.
:returns: a dictionary with the following keys:
- `documents`: A list of documents retrieved from the AzureAISearchDocumentStore.
"""

top_k = top_k or self._top_k
if filters is not None:
applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters)
normalized_filters = normalize_filters(applied_filters)
else:
normalized_filters = ""

try:
docs = self._document_store._bm25_retrieval(
query=query,
filters=normalized_filters,
top_k=top_k,
)
except Exception as e:
raise e

return {"documents": docs}
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import logging
from typing import Any, Dict, List, Optional, Union

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.document_stores.types import FilterPolicy
from haystack.document_stores.types.filter_policy import apply_filter_policy

from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters

logger = logging.getLogger(__name__)


@component
class AzureAISearchHybridRetriever:
"""
Retrieves documents from the AzureAISearchDocumentStore using a hybrid (vector + BM25) retrieval.
Must be connected to the AzureAISearchDocumentStore to run.
"""

def __init__(
self,
*,
document_store: AzureAISearchDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
):
"""
Create the AzureAISearchHybridRetriever component.
:param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever.
:param filters: Filters applied when fetching documents from the Document Store.
Filters are applied during the hybrid search to ensure the Retriever returns
`top_k` matching documents.
:param top_k: Maximum number of documents to return.
:filter_policy: Policy to determine how filters are applied. Possible options:
"""
self._filters = filters or {}
self._top_k = top_k
self._document_store = document_store
self._filter_policy = (
filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
)

if not isinstance(document_store, AzureAISearchDocumentStore):
message = "document_store must be an instance of AzureAISearchDocumentStore"
raise Exception(message)

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
filters=self._filters,
top_k=self._top_k,
document_store=self._document_store.to_dict(),
filter_policy=self._filter_policy.value,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchHybridRetriever":
"""
Deserializes the component from a dictionary.
:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)

# Pipelines serialized with old versions of the component might not
# have the filter_policy field.
if "filter_policy" in data["init_parameters"]:
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(
self, query: str, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None
):
"""Retrieve documents from the AzureAISearchDocumentStore.
:param query: Text of the query.
:param query_embedding: floats representing the query embedding
:param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on
the `filter_policy` chosen at retriever initialization. See init method docstring for more
details.
:param top_k: the maximum number of documents to retrieve.
:returns: a dictionary with the following keys:
- `documents`: A list of documents retrieved from the AzureAISearchDocumentStore.
"""

top_k = top_k or self._top_k
if filters is not None:
applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters)
normalized_filters = normalize_filters(applied_filters)
else:
normalized_filters = ""

try:
docs = self._document_store._hybrid_retrieval(
query=query,
query_embedding=query_embedding,
filters=normalized_filters,
top_k=top_k,
)
except Exception as e:
raise e

return {"documents": docs}
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,55 @@ def _embedding_retrieval(
result = self.client.search(search_text=None, vector_queries=[vector_query], select=fields, filter=filters)
azure_docs = list(result)
return self._convert_search_result_to_documents(azure_docs)

def _bm25_retrieval(
self,
query: str,
top_k: int = 10,
fields: Optional[List[str]] = None,
filters: Optional[Dict[str, Any]] = None,
) -> List[Document]:
"""
Retrieves documents that are most similar to `query`, using the BM25 algorithm
This method is not meant to be part of the public interface of
`AzureAISearchDocumentStore` nor called directly.
`AzureAISearchBM25Retriever` uses this method directly and is the public interface for it.
:param query: Text of the query.
:param filters: Filters applied to the retrieved Documents. Defaults to None.
Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned.
:param top_k: Maximum number of Documents to return, defaults to 10
:raises ValueError: If `query` is an empty string
:returns: List of Document that are most similar to `query`
"""

if query is None:
msg = "query must not be None"
raise ValueError(msg)

result = self.client.search(search_text=query, select=fields, filter=filters, top=top_k, query_type="simple")
azure_docs = list(result)
return self._convert_search_result_to_documents(azure_docs)

def _hybrid_retrieval(
self,
query: str,
query_embedding: List[float],
top_k: int = 10,
fields: Optional[List[str]] = None,
filters: Optional[Dict[str, Any]] = None,
) -> List[Document]:

if query is None:
msg = "query must not be None"
raise ValueError(msg)
if not query_embedding:
msg = "query_embedding must be a non-empty list of floats"
raise ValueError(msg)

vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding")
result = self.client.search(search_text=query, vector_queries=[vector_query], select=fields, filter=filters, top=top_k, query_type="simple")
azure_docs = list(result)
return self._convert_search_result_to_documents(azure_docs)

0 comments on commit 252d27d

Please sign in to comment.