Skip to content

Commit

Permalink
feat: defer the database connection to when it's needed (#802)
Browse files Browse the repository at this point in the history
* feat: defer the database connection to when it's needed

* remove unneeded noqa

* fix fixture

* trigger the connection before asserting

* trigger connection

* also make serialization lazy

* remove copypasta leftovers
  • Loading branch information
masci authored Jun 10, 2024
1 parent 5bc08df commit f70664d
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -139,49 +139,72 @@ def __init__(
:param grpc_secure:
Whether to use a secure channel for the underlying gRPC API.
"""
self._url = url
self._auth_client_secret = auth_client_secret
self._additional_headers = additional_headers
self._embedded_options = embedded_options
self._additional_config = additional_config
self._grpc_port = grpc_port
self._grpc_secure = grpc_secure
self._client = None
self._collection = None
# Store the connection settings dictionary
self._collection_settings = collection_settings or {
"class": "Default",
"invertedIndexConfig": {"indexNullState": True},
"properties": DOCUMENT_COLLECTION_PROPERTIES,
}
self._clean_connection_settings()

def _clean_connection_settings(self):
    """
    Normalize ``self._collection_settings`` in place.

    The collection name stored under ``"class"`` falls back to ``"Default"`` when the key is
    missing or holds an empty value, and its first character is upper-cased. ``"properties"``
    is set to ``DOCUMENT_COLLECTION_PROPERTIES`` only when the key is missing.
    """
    settings = self._collection_settings

    # An empty string or None counts as "not set": indexing it below would raise instead of
    # producing a usable collection name.
    class_name = settings.get("class") or "Default"
    settings["class"] = class_name[0].upper() + class_name[1:]

    # Caller-provided properties are kept untouched; the default is only filled in when absent.
    if "properties" not in settings:
        settings["properties"] = DOCUMENT_COLLECTION_PROPERTIES

@property
def client(self):
    """
    Return the Weaviate client, connecting to the database on first access.

    The first call builds a ``weaviate.WeaviateClient`` from the settings stored on the
    instance, connects it, and creates the collection described by
    ``self._collection_settings`` if it does not exist yet. The connected client is cached in
    ``self._client`` and returned unchanged by later calls.

    :returns: The connected client.
    :raises Exception: Whatever the Weaviate client raises while connecting or preparing the
        collection. Nothing is cached in that case, so the next access tries again.
    """
    if self._client is not None:
        return self._client

    # proxies, timeout_config, trust_env are part of additional_config now
    # startup_period has been removed
    client = weaviate.WeaviateClient(
        connection_params=(
            weaviate.connect.base.ConnectionParams.from_url(
                url=self._url, grpc_port=self._grpc_port, grpc_secure=self._grpc_secure
            )
            if self._url
            else None
        ),
        auth_client_secret=self._auth_client_secret.resolve_value() if self._auth_client_secret else None,
        additional_config=self._additional_config,
        additional_headers=self._additional_headers,
        embedded_options=self._embedded_options,
        skip_init_checks=False,
    )

    try:
        client.connect()

        # Test connection, it will raise an exception if it fails.
        client.collections._get_all(simple=True)

        class_name = self._collection_settings["class"]
        if not client.collections.exists(class_name):
            client.collections.create_from_dict(self._collection_settings)
    except Exception:
        # Never cache (or leave open) a client that did not become usable; otherwise every
        # later access would hand back the broken instance instead of retrying.
        client.close()
        raise

    self._client = client
    return self._client

if not self._client.collections.exists(collection_settings["class"]):
self._client.collections.create_from_dict(collection_settings)
@property
def collection(self):
    """
    Return the collection handle, fetching it through ``self.client`` on first access.

    The handle is looked up by the name stored under ``"class"`` in
    ``self._collection_settings`` and cached in ``self._collection`` for later calls.

    :returns: The object returned by ``client.collections.get`` for the configured class.
    """
    # Identity check on purpose: the cached handle must be reused even if the object would
    # evaluate as false (for example because it defines ``__len__`` and is currently empty).
    if self._collection is not None:
        return self._collection

    # Going through the ``client`` property is what triggers the deferred connection.
    client = self.client
    self._collection = client.collections.get(self._collection_settings["class"])
    return self._collection

def to_dict(self) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -230,7 +253,7 @@ def count_documents(self) -> int:
"""
Returns the number of documents present in the DocumentStore.
"""
total = self._collection.aggregate.over_all(total_count=True).total_count
total = self.collection.aggregate.over_all(total_count=True).total_count
return total if total else 0

def _to_data_object(self, document: Document) -> Dict[str, Any]:
Expand Down Expand Up @@ -302,16 +325,16 @@ def _to_document(self, data: DataObject[Dict[str, Any], None]) -> Document:
return Document.from_dict(document_data)

def _query(self) -> List[Dict[str, Any]]:
    """
    Fetch every object of the collection, vectors included.

    The property names are read from the collection configuration and passed as
    ``return_properties`` so that all of them are returned.

    NOTE(review): the value returned is whatever ``collection.iterator`` produces, and only the
    call creating it is guarded below; confirm whether query errors can also surface later,
    while the result is being consumed.

    :returns: The result of ``collection.iterator`` for all properties.
    :raises DocumentStoreError: If Weaviate reports a query error.
    """
    # Always go through the ``collection`` property so the deferred connection is established.
    collection = self.collection
    properties = [p.name for p in collection.config.get().properties]
    try:
        result = collection.iterator(include_vector=True, return_properties=properties)
    except weaviate.exceptions.WeaviateQueryError as e:
        msg = f"Failed to query documents in Weaviate. Error: {e.message}"
        raise DocumentStoreError(msg) from e
    return result

def _query_with_filters(self, filters: Dict[str, Any]) -> List[Dict[str, Any]]:
properties = [p.name for p in self._collection.config.get().properties]
properties = [p.name for p in self.collection.config.get().properties]
# When querying with filters we need to paginate using limit and offset as using
# a cursor with after is not possible. See the official docs:
# https://weaviate.io/developers/weaviate/api/graphql/additional-operators#cursor-with-after
Expand All @@ -327,7 +350,7 @@ def _query_with_filters(self, filters: Dict[str, Any]) -> List[Dict[str, Any]]:
# Keep querying until we get all documents matching the filters
while partial_result is None or len(partial_result.objects) == DEFAULT_QUERY_LIMIT:
try:
partial_result = self._collection.query.fetch_objects(
partial_result = self.collection.query.fetch_objects(
filters=convert_filters(filters),
include_vector=True,
limit=DEFAULT_QUERY_LIMIT,
Expand Down Expand Up @@ -365,19 +388,19 @@ def _batch_write(self, documents: List[Document]) -> int:
Raises in case of errors.
"""

with self._client.batch.dynamic() as batch:
with self.client.batch.dynamic() as batch:
for doc in documents:
if not isinstance(doc, Document):
msg = f"Expected a Document, got '{type(doc)}' instead."
raise ValueError(msg)

batch.add_object(
properties=self._to_data_object(doc),
collection=self._collection.name,
collection=self.collection.name,
uuid=generate_uuid5(doc.id),
vector=doc.embedding,
)
if failed_objects := self._client.batch.failed_objects:
if failed_objects := self.client.batch.failed_objects:
# We fallback to use the UUID if the _original_id is not present, this is just to be
mapped_objects = {}
for obj in failed_objects:
Expand Down Expand Up @@ -413,12 +436,12 @@ def _write(self, documents: List[Document], policy: DuplicatePolicy) -> int:
msg = f"Expected a Document, got '{type(doc)}' instead."
raise ValueError(msg)

if policy == DuplicatePolicy.SKIP and self._collection.data.exists(uuid=generate_uuid5(doc.id)):
if policy == DuplicatePolicy.SKIP and self.collection.data.exists(uuid=generate_uuid5(doc.id)):
# This Document already exists, we skip it
continue

try:
self._collection.data.insert(
self.collection.data.insert(
uuid=generate_uuid5(doc.id),
properties=self._to_data_object(doc),
vector=doc.embedding,
Expand Down Expand Up @@ -454,13 +477,13 @@ def delete_documents(self, document_ids: List[str]) -> None:
:param document_ids: The object_ids to delete.
"""
weaviate_ids = [generate_uuid5(doc_id) for doc_id in document_ids]
self._collection.data.delete_many(where=weaviate.classes.query.Filter.by_id().contains_any(weaviate_ids))
self.collection.data.delete_many(where=weaviate.classes.query.Filter.by_id().contains_any(weaviate_ids))

def _bm25_retrieval(
self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None
) -> List[Document]:
properties = [p.name for p in self._collection.config.get().properties]
result = self._collection.query.bm25(
properties = [p.name for p in self.collection.config.get().properties]
result = self.collection.query.bm25(
query=query,
filters=convert_filters(filters) if filters else None,
limit=top_k,
Expand All @@ -484,8 +507,8 @@ def _embedding_retrieval(
msg = "Can't use 'distance' and 'certainty' parameters together"
raise ValueError(msg)

properties = [p.name for p in self._collection.config.get().properties]
result = self._collection.query.near_vector(
properties = [p.name for p in self.collection.config.get().properties]
result = self.collection.query.near_vector(
near_vector=query_embedding,
distance=distance,
certainty=certainty,
Expand Down
17 changes: 13 additions & 4 deletions integrations/weaviate/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@
)


@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.WeaviateClient")
def test_init_is_lazy(_mock_client):
    """Constructing the document store alone must not instantiate the Weaviate client."""
    WeaviateDocumentStore()
    _mock_client.assert_not_called()


@pytest.mark.integration
class TestWeaviateDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest, FilterDocumentsTest):
@pytest.fixture
Expand All @@ -57,7 +63,7 @@ def document_store(self, request) -> WeaviateDocumentStore:
collection_settings=collection_settings,
)
yield store
store._client.collections.delete(collection_settings["class"])
store.client.collections.delete(collection_settings["class"])

@pytest.fixture
def filterable_docs(self) -> List[Document]:
Expand Down Expand Up @@ -150,12 +156,12 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do
assert received_meta.get(key) == expected_meta.get(key)

@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.WeaviateClient")
def test_init(self, mock_weaviate_client_class, monkeypatch):
def test_connection(self, mock_weaviate_client_class, monkeypatch):
mock_client = MagicMock()
mock_client.collections.exists.return_value = False
mock_weaviate_client_class.return_value = mock_client
monkeypatch.setenv("WEAVIATE_API_KEY", "my_api_key")
WeaviateDocumentStore(
ds = WeaviateDocumentStore(
collection_settings={"class": "My_collection"},
auth_client_secret=AuthApiKey(),
additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"},
Expand All @@ -170,8 +176,11 @@ def test_init(self, mock_weaviate_client_class, monkeypatch):
),
)

# Verify client is created with correct parameters
# Trigger the actual database connection by accessing the `client` property so we
# can assert the setup was good
_ = ds.client

# Verify client is created with correct parameters
mock_weaviate_client_class.assert_called_once_with(
auth_client_secret=AuthApiKey().resolve_value(),
connection_params=None,
Expand Down

0 comments on commit f70664d

Please sign in to comment.