Updated code based on PR comments
Amnah199 committed Oct 31, 2024
1 parent 23d1e22 commit 3afc3b9
Showing 6 changed files with 39 additions and 219 deletions.
163 changes: 0 additions & 163 deletions integrations/azure_ai_search/.gitignore

This file was deleted.

3 changes: 1 addition & 2 deletions integrations/azure_ai_search/pyproject.toml
@@ -15,14 +15,13 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
"Development Status :: 4 - Beta",
"Programming Language :: Python",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["haystack-ai", "azure-search-documents>=11.5", "azure-identity", "torch>=1.11.0"]
dependencies = ["haystack-ai", "azure-search-documents>=11.5", "azure-identity"]

[project.urls]
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/azure_ai_search#readme"
@@ -25,7 +25,6 @@ def __init__(
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
- raise_on_failure: bool = True,
):
"""
Create the AzureAISearchEmbeddingRetriever component.
@@ -44,7 +43,6 @@ self._filter_policy = (
self._filter_policy = (
filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
)
- self._raise_on_failure = raise_on_failure

if not isinstance(document_store, AzureAISearchDocumentStore):
message = "document_store must be an instance of AzureAISearchDocumentStore"
@@ -113,13 +111,6 @@ def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] =
top_k=top_k,
)
except Exception as e:
- if self._raise_on_failure:
-     raise e
- logger.warning(
-     "An error occurred during embedding retrieval and will be ignored, returning empty results: %s",
-     str(e),
-     exc_info=True,
- )
- docs = []
+ raise e

return {"documents": docs}
@@ -73,7 +73,6 @@ def __init__(
embedding_dimension: int = 768,
metadata_fields: Optional[Dict[str, type]] = None,
vector_search_configuration: VectorSearch = None,
- create_index: bool = True,
**kwargs,
):
"""
@@ -117,7 +116,6 @@ def __init__(
self._dummy_vector = [-10.0] * self._embedding_dimension
self._metadata_fields = metadata_fields
self._vector_search_configuration = vector_search_configuration or DEFAULT_VECTOR_SEARCH
- self._create_index = create_index
self._kwargs = kwargs

@property
@@ -133,13 +131,13 @@ def client(self) -> SearchClient:
try:
if not self._index_client:
self._index_client = SearchIndexClient(resolved_endpoint, credential, **self._kwargs)
- if not self.index_exists(self._index_name):
+ if not self._index_exists(self._index_name):
# Create a new index if it does not exist
logger.debug(
"The index '%s' does not exist. A new index will be created.",
self._index_name,
)
- self.create_index(self._index_name)
+ self._create_index(self._index_name)
except (HttpResponseError, ClientAuthenticationError) as error:
msg = f"Failed to authenticate with Azure Search: {error}"
raise AzureAISearchDocumentStoreConfigError(msg) from error
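With the create_index flag gone from __init__ and index management made private, the index is now created implicitly the first time the search client is used. A short sketch of the resulting usage (index name and document content are placeholders):

# Sketch only: no public create_index() call is needed or available anymore.
from haystack import Document
from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore

store = AzureAISearchDocumentStore(index_name="my-index", embedding_dimension=768)
store.write_documents([Document(content="hello world")])  # first client access creates "my-index" if it is missing
print(store.count_documents())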
@@ -155,7 +153,7 @@ def client(self) -> SearchClient:

return self._client

- def create_index(self, index_name: str, **kwargs) -> None:
+ def _create_index(self, index_name: str, **kwargs) -> None:
"""
Creates a new search index.
:param index_name: Name of the index to create. If None, the index name from the constructor is used.
@@ -201,7 +199,6 @@ def to_dict(self) -> Dict[str, Any]:
azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint is not None else None,
api_key=self._api_key.to_dict() if self._api_key is not None else None,
index_name=self._index_name,
- create_index=self._create_index,
embedding_dimension=self._embedding_dimension,
metadata_fields=self._metadata_fields,
vector_search_configuration=self._vector_search_configuration.as_dict(),
@@ -225,14 +222,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchDocumentStore":
data["init_parameters"]["vector_search_configuration"] = VectorSearch.from_dict(vector_search_configuration)
return default_from_dict(cls, data)

- def count_documents(self, **kwargs: Any) -> int:
+ def count_documents(self) -> int:
"""
Returns how many documents are present in the search index.
- :param kwargs: additional keyword parameters.
:returns: list of retrieved documents.
"""
- return self.client.get_document_count(**kwargs)
+ return self.client.get_document_count()

def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> int:
"""
@@ -292,7 +288,7 @@ def delete_documents(self, document_ids: List[str]) -> None:
def get_documents_by_id(self, document_ids: List[str]) -> List[Document]:
return self._convert_search_result_to_documents(self._get_raw_documents_by_id(document_ids))

- def search_documents(self, search_text: Optional[str] = "*", top_k: Optional[int] = 10) -> List[Document]:
+ def search_documents(self, search_text: str = "*", top_k: int = 10) -> List[Document]:
"""
Returns all documents that match the provided search_text.
If search_text is None, returns all documents.
@@ -345,7 +341,7 @@ def _convert_search_result_to_documents(self, azure_docs: List[Dict[str, Any]])
documents.append(doc)
return documents

- def index_exists(self, index_name: Optional[str]) -> bool:
+ def _index_exists(self, index_name: Optional[str]) -> bool:
"""
Check if the index exists in the Azure AI Search service.
@@ -403,14 +399,19 @@ def _map_metadata_field_types(self, metadata: Dict[str, type]) -> Dict[str, str]

for key, value_type in metadata.items():

- # Azure Search index only allows field names starting with letters
- field_name = next((key[i:] for i, char in enumerate(key) if char.isalpha()), key)
+ if not key[0].isalpha():
+     msg = (
+         f"Azure Search index only allows field names starting with letters. "
+         f"Invalid key: {key} will be dropped."
+     )
+     logger.warning(msg)
+     continue

field_type = type_mapping.get(value_type)
if not field_type:
error_message = f"Unsupported field type for key '{field_name}': {value_type}"
error_message = f"Unsupported field type for key '{key}': {value_type}"
raise ValueError(error_message)
- metadata_field_mapping[field_name] = field_type
+ metadata_field_mapping[key] = field_type

return metadata_field_mapping
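The metadata mapping no longer rewrites invalid keys (the old code stripped leading non-letter characters); keys that do not start with a letter are now skipped with a warning. A hedged sketch of the effect on callers (field names are illustrative, assuming str is among the supported metadata field types):

# Sketch only: "1st_author" starts with a digit, so it is now dropped with a
# warning rather than silently renamed to "st_author" as before.
from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore

store = AzureAISearchDocumentStore(
    index_name="my-index",
    metadata_fields={"category": str, "1st_author": str},  # only "category" becomes an index field
)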

