
Commit

Add embedding retriever and tests
Amnah199 committed Oct 14, 2024
1 parent 3ebaa49 commit e5227a6
Showing 8 changed files with 324 additions and 123 deletions.
16 changes: 8 additions & 8 deletions integrations/azure_ai_search/pydoc/config.yml
@@ -2,10 +2,10 @@ loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../src]
modules: [
"haystack_integrations.components.retrievers.opensearch.bm25_retriever",
"haystack_integrations.components.retrievers.opensearch.embedding_retriever",
"haystack_integrations.document_stores.opensearch.document_store",
"haystack_integrations.document_stores.opensearch.filters",
"haystack_integrations.components.retrievers.azure_ai_search.bm25_retriever",
"haystack_integrations.components.retrievers.azure_ai_search.embedding_retriever",
"haystack_integrations.document_stores.azure_ai_search.document_store",
"haystack_integrations.document_stores.azure_ai_search.filters",
]
ignore_when_discovered: ["__init__"]
processors:
@@ -18,15 +18,15 @@ processors:
- type: crossref
renderer:
type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer
- excerpt: OpenSearch integration for Haystack
+ excerpt: Azure AI Search integration for Haystack
category_slug: integrations-api
- title: OpenSearch
- slug: integrations-opensearch
+ title: Azure AI Search
+ slug: integrations-azure_ai_search
order: 180
markdown:
descriptive_class_title: false
classdef_code_block: false
descriptive_module_title: true
add_method_class_prefix: true
add_member_class_prefix: false
- filename: _readme_opensearch.md
+ filename: _readme_azure_ai_search.md
4 changes: 2 additions & 2 deletions integrations/azure_ai_search/pyproject.toml
@@ -22,7 +22,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["haystack-ai", "azure-search-documents>=11.5"]
dependencies = ["haystack-ai", "azure-search-documents>=11.5", "azure-identity"]

[project.urls]
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/opensearch#readme"
@@ -154,5 +154,5 @@ minversion = "6.0"
markers = ["unit: unit tests", "integration: integration tests"]

[[tool.mypy.overrides]]
module = ["haystack.*", "haystack_integrations.*", "pytest.*", "azure-ai-search.*"]
module = ["haystack.*", "haystack_integrations.*", "pytest.*", "azure-ai-search.*", "azure.identity.*"]
ignore_missing_imports = true
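
The new `azure-identity` dependency and the `azure.identity.*` mypy override suggest the integration can authenticate with Azure AD credentials as well as an API key. The sketch below is not part of this commit; it only illustrates how `DefaultAzureCredential` from `azure-identity` is typically paired with the Azure AI Search SDK. The endpoint and index name are placeholders, and the fallback-to-AD behavior is an assumption not shown in this diff.

```python
# Minimal sketch (assumptions: placeholder endpoint/index, and that the integration
# falls back to Azure AD auth when AZURE_SEARCH_API_KEY is not set).
from azure.identity import DefaultAzureCredential
from azure.search.documents import SearchClient

# DefaultAzureCredential tries environment variables, managed identity, and `az login`, in order.
credential = DefaultAzureCredential()

client = SearchClient(
    endpoint="https://<your-search-service>.search.windows.net",  # placeholder
    index_name="default",
    credential=credential,
)
print(client.get_document_count())
```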

This file was deleted.

New file: haystack_integrations/components/retrievers/azure_ai_search/__init__.py
@@ -0,0 +1,3 @@
from .embedding_retriever import AzureAISearchEmbeddingRetriever

__all__ = ["AzureAISearchEmbeddingRetriever"]
New file: haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py
@@ -0,0 +1,126 @@
import logging
from typing import Any, Dict, List, Optional, Union

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.document_stores.types import FilterPolicy
from haystack.document_stores.types.filter_policy import apply_filter_policy

from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore

logger = logging.getLogger(__name__)


@component
class AzureAISearchEmbeddingRetriever:
"""
Retrieves documents from the AzureAISearchDocumentStore using a vector similarity metric.
Must be connected to the AzureAISearchDocumentStore to run.
"""

def __init__(
self,
*,
document_store: AzureAISearchDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
raise_on_failure: bool = True,
):
"""
Create the AzureAISearchEmbeddingRetriever component.
:param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever.
:param filters: Filters applied when fetching documents from the Document Store.
Filters are applied during the approximate kNN search to ensure the Retriever returns
`top_k` matching documents.
:param top_k: Maximum number of documents to return.
:param filter_policy: Policy to determine how filters are applied. Possible options are
    `replace` (runtime filters replace init filters) and `merge` (runtime filters are merged with init filters).
:param raise_on_failure: If `True`, exceptions raised during retrieval are re-raised;
    otherwise they are logged and an empty list of documents is returned.
"""
self._filters = filters or {}
self._top_k = top_k
self._document_store = document_store
self._filter_policy = (
filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
)
self._raise_on_failure = raise_on_failure

if not isinstance(document_store, AzureAISearchDocumentStore):
message = "document_store must be an instance of AzureAISearchDocumentStore"
raise TypeError(message)

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
filters=self._filters,
top_k=self._top_k,
document_store=self._document_store.to_dict(),
filter_policy=self._filter_policy.value,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchEmbeddingRetriever":
"""
Deserializes the component from a dictionary.
:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)

# Pipelines serialized with old versions of the component might not
# have the filter_policy field.
if "filter_policy" in data["init_parameters"]:
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
"""Retrieve documents from the AzureAISearchDocumentStore.
:param query_embedding: floats representing the query embedding
:param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on
the `filter_policy` chosen at retriever initialization. See init method docstring for more
details.
:param top_k: the maximum number of documents to retrieve.
:returns: a dictionary with the following keys:
- `documents`: A list of documents retrieved from the AzureAISearchDocumentStore.
"""
filters = apply_filter_policy(self._filter_policy, self._filters, filters)
top_k = top_k or self._top_k

docs: List[Document] = []

try:
docs = self._document_store._embedding_retrieval(
query_embedding=query_embedding,
filters=filters,
top_k=top_k,
)
except Exception as e:
if self._raise_on_failure:
raise e
else:
logger.warning(
"An error during embedding retrieval occurred and will be ignored by returning empty results: %s",
str(e),
exc_info=True,
)

return {"documents": docs}
File: haystack_integrations/document_stores/azure_ai_search/document_store.py
@@ -23,6 +23,7 @@
VectorSearchAlgorithmMetric,
VectorSearchProfile,
)
+ from azure.search.documents.models import VectorizedQuery
from haystack import default_from_dict, default_to_dict
from haystack.dataclasses import Document
from haystack.document_stores.errors import DuplicateDocumentError
@@ -61,7 +62,7 @@ def __init__(
api_key: Secret = Secret.from_env_var("AZURE_SEARCH_API_KEY", strict=False),
azure_endpoint: Secret = Secret.from_env_var("AZURE_SEARCH_SERVICE_ENDPOINT", strict=False),
index_name: str = "default",
- embedding_dimension: int = 768, # whats a better default value
+ embedding_dimension: int = 768,
metadata_fields: Optional[Dict[str, type]] = None,
vector_search_configuration: VectorSearch = None,
create_index: bool = True,
@@ -102,6 +103,7 @@ def __init__(
self._azure_endpoint = azure_endpoint
self._index_name = index_name
self._embedding_dimension = embedding_dimension
+ self._dummy_vector = [-10.0] * self._embedding_dimension
self._metadata_fields = metadata_fields
self._vector_search_configuration = vector_search_configuration or DEFAULT_VECTOR_SEARCH
self._create_index = create_index
@@ -149,6 +151,7 @@ def create_index(self, index_name: str, **kwargs) -> None:
name="embedding",
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
searchable=True,
+ hidden=False,
vector_search_dimensions=self._embedding_dimension,
vector_search_profile_name="default-vector-config",
),
@@ -218,19 +221,20 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
:return: the number of documents added to index.
"""

- if len(documents) > 0:
- if not isinstance(documents[0], Document):
- msg = "param 'documents' must contain a list of objects of type Document"
- raise ValueError(msg)

def _convert_input_document(documents: Document):
document_dict = asdict(documents)
if not isinstance(document_dict["id"], str):
msg = f"Document id {document_dict['id']} is not a string, "
raise Exception(msg)
index_document = self._default_index_mapping(document_dict)

return index_document

+ if len(documents) > 0:
+ if not isinstance(documents[0], Document):
+ msg = "param 'documents' must contain a list of objects of type Document"
+ raise ValueError(msg)

documents_to_write = []
for doc in documents:
try:
@@ -343,12 +347,10 @@ def _default_index_mapping(self, document: Dict[str, Any]) -> Dict[str, Any]:

keys_to_remove = ["dataframe", "blob", "sparse_embedding", "score"]
index_document = {k: v for k, v in document.items() if k not in keys_to_remove}

metadata = index_document.pop("meta", None)
for key, value in metadata.items():
index_document[key] = value
if index_document["embedding"] is None:
- self._dummy_vector = [-10.0] * self._embedding_dimension
index_document["embedding"] = self._dummy_vector

return index_document
@@ -376,3 +378,37 @@ def _map_metadata_field_types(self, metadata: Dict[str, type]) -> Dict[str, str]
metadata_field_mapping[key] = field_type

return metadata_field_mapping

def _embedding_retrieval(
self,
query_embedding: List[float],
*,
top_k: int = 10,
fields: Optional[List[str]] = None,
filters: Optional[Dict[str, Any]] = None, # TODO will be used in the future
) -> List[Document]:
"""
Retrieves documents that are most similar to the query embedding using a vector similarity metric.
It uses the vector configuration of the document store. By default it uses the HNSW algorithm with cosine similarity.
This method is not meant to be part of the public interface of
`AzureAISearchDocumentStore` nor called directly.
`AzureAISearchEmbeddingRetriever` uses this method directly and is the public interface for it.
:param query_embedding: Embedding of the query.
:param filters: Filters applied to the retrieved Documents. Defaults to None.
Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned.
:param top_k: Maximum number of Documents to return, defaults to 10
:raises ValueError: If `query_embedding` is an empty list
:returns: List of Document that are most similar to `query_embedding`
"""

if not query_embedding:
msg = "query_embedding must be a non-empty list of floats"
raise ValueError(msg)

vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding")
result = self.client.search(search_text=None, vector_queries=[vector_query], select=fields, top=top_k)
azure_docs = list(result)
return self._convert_search_result_to_documents(azure_docs)
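
To show how `_embedding_retrieval` is reached in practice, here is a hedged pipeline sketch (not part of this commit): a text embedder produces the query embedding and `AzureAISearchEmbeddingRetriever` forwards it to the document store. The `SentenceTransformersTextEmbedder` and its model are illustrative assumptions; any Haystack text embedder whose output dimension matches `embedding_dimension` would work.

```python
from haystack import Pipeline
from haystack.components.embedders import SentenceTransformersTextEmbedder

from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchEmbeddingRetriever
from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore

# all-MiniLM-L6-v2 produces 384-dimensional embeddings, so the index dimension must match.
document_store = AzureAISearchDocumentStore(index_name="default", embedding_dimension=384)

pipeline = Pipeline()
pipeline.add_component("embedder", SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"))
pipeline.add_component("retriever", AzureAISearchEmbeddingRetriever(document_store=document_store, top_k=10))
pipeline.connect("embedder.embedding", "retriever.query_embedding")

result = pipeline.run({"embedder": {"text": "What does the embedding retriever do?"}})
print(result["retriever"]["documents"])
```

The retriever then calls the document store's `_embedding_retrieval`, which wraps the query in a `VectorizedQuery` and issues a vector search against the index.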