-
Notifications
You must be signed in to change notification settings - Fork 126
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
324 additions
and
123 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
105 changes: 0 additions & 105 deletions
105
...ch/src/haystack_integrations/components/azure_ai_search/retrievers/embedding_retriever.py
This file was deleted.
Oops, something went wrong.
3 changes: 3 additions & 0 deletions
3
...ure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .embedding_retriever import AzureAISearchEmbeddingRetriever | ||
|
||
__all__ = ["AzureAISearchEmbeddingRetriever"] |
126 changes: 126 additions & 0 deletions
126
...ch/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import logging | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
from haystack import Document, component, default_from_dict, default_to_dict | ||
from haystack.document_stores.types import FilterPolicy | ||
from haystack.document_stores.types.filter_policy import apply_filter_policy | ||
|
||
from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@component
class AzureAISearchEmbeddingRetriever:
    """
    Retrieves documents from the AzureAISearchDocumentStore using a vector similarity metric.

    Must be connected to the AzureAISearchDocumentStore to run.
    """

    def __init__(
        self,
        *,
        document_store: AzureAISearchDocumentStore,
        filters: Optional[Dict[str, Any]] = None,
        top_k: int = 10,
        filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
        raise_on_failure: bool = True,
    ):
        """
        Create the AzureAISearchEmbeddingRetriever component.

        :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever.
        :param filters: Filters applied when fetching documents from the Document Store.
            Filters are applied during the approximate kNN search to ensure the Retriever returns
            `top_k` matching documents.
        :param top_k: Maximum number of documents to return.
        :param filter_policy: Policy to determine how the init-time `filters` are combined with the
            filters passed to `run`. Either a `FilterPolicy` member or its string value
            (for example `FilterPolicy.REPLACE`, the default); strings are converted with
            `FilterPolicy.from_str`.
        :param raise_on_failure: If `True`, errors raised by the document store during retrieval are
            re-raised. If `False`, they are logged as a warning and an empty result is returned.
        :raises TypeError: If `document_store` is not an instance of AzureAISearchDocumentStore.
        """
        # Validate before storing any state so a failed construction leaves nothing half-initialised.
        # TypeError is a subclass of Exception, so callers that catch Exception keep working.
        if not isinstance(document_store, AzureAISearchDocumentStore):
            message = "document_store must be an instance of AzureAISearchDocumentStore"
            raise TypeError(message)

        self._filters = filters or {}
        self._top_k = top_k
        self._document_store = document_store
        self._filter_policy = (
            filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
        )
        self._raise_on_failure = raise_on_failure

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            filters=self._filters,
            top_k=self._top_k,
            document_store=self._document_store.to_dict(),
            filter_policy=self._filter_policy.value,
            # Serialized so that a to_dict/from_dict round trip keeps a non-default setting.
            raise_on_failure=self._raise_on_failure,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchEmbeddingRetriever":
        """
        Deserializes the component from a dictionary.

        Note that `data["init_parameters"]` is updated in place: the serialized document store
        (and the filter policy, when present) are replaced by their deserialized objects.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        init_params = data["init_parameters"]
        init_params["document_store"] = AzureAISearchDocumentStore.from_dict(init_params["document_store"])

        # Pipelines serialized with old versions of the component might not
        # have the filter_policy field.
        if "filter_policy" in init_params:
            init_params["filter_policy"] = FilterPolicy.from_str(init_params["filter_policy"])
        return default_from_dict(cls, data)

    @component.output_types(documents=List[Document])
    def run(
        self,
        query_embedding: List[float],
        filters: Optional[Dict[str, Any]] = None,
        top_k: Optional[int] = None,
    ) -> Dict[str, List[Document]]:
        """
        Retrieve documents from the AzureAISearchDocumentStore.

        :param query_embedding: floats representing the query embedding
        :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on
            the `filter_policy` chosen at retriever initialization. See init method docstring for more
            details.
        :param top_k: the maximum number of documents to retrieve. A falsy value (`None` or `0`)
            falls back to the `top_k` given at initialization.
        :returns: a dictionary with the following keys:
            - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore.
              Empty if retrieval failed and `raise_on_failure` is `False`.
        """
        filters = apply_filter_policy(self._filter_policy, self._filters, filters)
        if filters is None:
            filters = self._filters
        # `or` already covers the `None` case, so no separate None check is needed afterwards.
        top_k = top_k or self._top_k

        docs: List[Document] = []
        try:
            docs = self._document_store._embedding_retrieval(
                query_embedding=query_embedding,
                filters=filters,
                top_k=top_k,
            )
        except Exception as e:
            if self._raise_on_failure:
                # Bare raise keeps the original traceback intact.
                raise
            logger.warning(
                "An error during embedding retrieval occurred and will be ignored by returning empty results: %s",
                str(e),
                exc_info=True,
            )

        return {"documents": docs}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.