Skip to content

Commit

Permalink
feat(qdrant): allow payload indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
lambda-science committed Mar 6, 2024
1 parent 2c6b218 commit 12a2a36
Showing 1 changed file with 21 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(
metadata: Optional[dict] = None,
write_batch_size: int = 100,
scroll_size: int = 10_000,
payload_field_to_index: Optional[List[dict]] = None,
):
super().__init__()

Expand Down Expand Up @@ -130,9 +131,10 @@ def __init__(
self.init_from = init_from
self.wait_result_from_api = wait_result_from_api
self.recreate_index = recreate_index
self.payload_field_to_index = payload_field_to_index

# Make sure the collection is properly set up
self._set_up_collection(index, embedding_dim, recreate_index, similarity)
self._set_up_collection(index, embedding_dim, recreate_index, similarity, payload_field_to_index)

self.embedding_dim = embedding_dim
self.content_field = content_field
Expand Down Expand Up @@ -334,19 +336,35 @@ def _get_distance(self, similarity: str) -> rest.Distance:
)
raise QdrantStoreError(msg) from ke

def _create_payload_index(self, collection_name: str, payload_field_to_index: Optional[List[dict]] = None):
"""
Create payload index for the collection if payload_field_to_index is provided
See: https://qdrant.tech/documentation/concepts/indexing/#payload-index
"""
if payload_field_to_index is not None:
for payload_index in payload_field_to_index:
self.client.create_payload_index(
collection_name=collection_name,
field_name=payload_index["field_name"],
field_schema=payload_index["field_schema"],
)

def _set_up_collection(
self,
collection_name: str,
embedding_dim: int,
recreate_collection: bool, # noqa: FBT001
similarity: str,
payload_field_to_index: Optional[List[dict]] = None,
):
distance = self._get_distance(similarity)

if recreate_collection:
# There is no need to verify the current configuration of that
# collection. It might be just recreated again.
self._recreate_collection(collection_name, distance, embedding_dim)
# Create Payload index if payload_field_to_index is provided
self._create_payload_index(collection_name, payload_field_to_index)
return

try:
Expand All @@ -361,6 +379,8 @@ def _set_up_collection(
# with the remote server UnexpectedResponse / RpcError is raised.
# Until that's unified, we need to catch both.
self._recreate_collection(collection_name, distance, embedding_dim)
# Create Payload index if payload_field_to_index is provided
self._create_payload_index(collection_name, payload_field_to_index)
return

current_distance = collection_info.config.params.vectors.distance
Expand Down

0 comments on commit 12a2a36

Please sign in to comment.