Skip to content

Commit

Permalink
feat(Qdrant): allow payload indexing + on disk vectors (#553)
Browse files Browse the repository at this point in the history
* feat(qdrant): allow payload indexing

* fix(qdrant): fix tests

* fix(qdrant): lint fix

* fix(qdrant): black formatting

* fix(qdrant): rename `payload_field_to_index` to `payload_fields_to_index`
  • Loading branch information
lambda-science authored Mar 7, 2024
1 parent 1032b2c commit 810ad84
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
path: Optional[str] = None,
index: str = "Document",
embedding_dim: int = 768,
on_disk: bool = False, # noqa: FBT001, FBT002
content_field: str = "content",
name_field: str = "name",
embedding_field: str = "embedding",
Expand All @@ -84,6 +85,7 @@ def __init__(
metadata: Optional[dict] = None,
write_batch_size: int = 100,
scroll_size: int = 10_000,
payload_fields_to_index: Optional[List[dict]] = None,
):
super().__init__()

Expand Down Expand Up @@ -130,11 +132,13 @@ def __init__(
self.init_from = init_from
self.wait_result_from_api = wait_result_from_api
self.recreate_index = recreate_index
self.payload_fields_to_index = payload_fields_to_index

# Make sure the collection is properly set up
self._set_up_collection(index, embedding_dim, recreate_index, similarity)
self._set_up_collection(index, embedding_dim, recreate_index, similarity, on_disk, payload_fields_to_index)

self.embedding_dim = embedding_dim
self.on_disk = on_disk
self.content_field = content_field
self.name_field = name_field
self.embedding_field = embedding_field
Expand Down Expand Up @@ -334,19 +338,36 @@ def _get_distance(self, similarity: str) -> rest.Distance:
)
raise QdrantStoreError(msg) from ke

def _create_payload_index(self, collection_name: str, payload_fields_to_index: Optional[List[dict]] = None):
    """
    Register a payload index on the collection for every requested field.

    Nothing happens when `payload_fields_to_index` is None. Otherwise each entry is read as a dict
    holding a "field_name" and a "field_schema" key; both values are forwarded unchanged to the
    client's `create_payload_index` call, one call per entry, in list order.
    See: https://qdrant.tech/documentation/concepts/indexing/#payload-index

    :param collection_name: Name of the collection the indexes are created on.
    :param payload_fields_to_index: Index specifications, or None to skip index creation.
    """
    if payload_fields_to_index is None:
        return

    for field_spec in payload_fields_to_index:
        # Look up both keys before calling the client, so a malformed entry fails
        # with KeyError without issuing a request for that entry.
        field_name = field_spec["field_name"]
        field_schema = field_spec["field_schema"]
        self.client.create_payload_index(
            collection_name=collection_name,
            field_name=field_name,
            field_schema=field_schema,
        )

def _set_up_collection(
self,
collection_name: str,
embedding_dim: int,
recreate_collection: bool, # noqa: FBT001
similarity: str,
on_disk: bool = False, # noqa: FBT001, FBT002
payload_fields_to_index: Optional[List[dict]] = None,
):
distance = self._get_distance(similarity)

if recreate_collection:
# There is no need to verify the current configuration of that
# collection. It might be just recreated again.
self._recreate_collection(collection_name, distance, embedding_dim)
self._recreate_collection(collection_name, distance, embedding_dim, on_disk)
# Create Payload index if payload_fields_to_index is provided
self._create_payload_index(collection_name, payload_fields_to_index)
return

try:
Expand All @@ -360,7 +381,9 @@ def _set_up_collection(
# Qdrant local raises ValueError if the collection is not found, but
# with the remote server UnexpectedResponse / RpcError is raised.
# Until that's unified, we need to catch both.
self._recreate_collection(collection_name, distance, embedding_dim)
self._recreate_collection(collection_name, distance, embedding_dim, on_disk)
# Create Payload index if payload_fields_to_index is provided
self._create_payload_index(collection_name, payload_fields_to_index)
return

current_distance = collection_info.config.params.vectors.distance
Expand All @@ -384,11 +407,12 @@ def _set_up_collection(
)
raise ValueError(msg)

def _recreate_collection(self, collection_name: str, distance, embedding_dim: int):
def _recreate_collection(self, collection_name: str, distance, embedding_dim: int, on_disk: bool): # noqa: FBT001
self.client.recreate_collection(
collection_name=collection_name,
vectors_config=rest.VectorParams(
size=embedding_dim,
on_disk=on_disk,
distance=distance,
),
shard_number=self.shard_number,
Expand Down
6 changes: 6 additions & 0 deletions integrations/qdrant/tests/test_dict_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def test_to_dict():
"path": None,
"index": "test",
"embedding_dim": 768,
"on_disk": False,
"content_field": "content",
"name_field": "name",
"embedding_field": "embedding",
Expand All @@ -42,6 +43,7 @@ def test_to_dict():
"metadata": {},
"write_batch_size": 100,
"scroll_size": 10000,
"payload_fields_to_index": None,
},
}

Expand All @@ -57,6 +59,7 @@ def test_from_dict():
"location": ":memory:",
"index": "test",
"embedding_dim": 768,
"on_disk": False,
"content_field": "content",
"name_field": "name",
"embedding_field": "embedding",
Expand All @@ -72,6 +75,7 @@ def test_from_dict():
"metadata": {},
"write_batch_size": 1000,
"scroll_size": 10000,
"payload_fields_to_index": None,
},
}
)
Expand All @@ -82,6 +86,7 @@ def test_from_dict():
document_store.content_field == "content",
document_store.name_field == "name",
document_store.embedding_field == "embedding",
document_store.on_disk is False,
document_store.similarity == "cosine",
document_store.return_embedding is False,
document_store.progress_bar,
Expand All @@ -101,5 +106,6 @@ def test_from_dict():
document_store.write_batch_size == 1000,
document_store.scroll_size == 10000,
document_store.api_key == Secret.from_env_var("ENV_VAR", strict=False),
document_store.payload_fields_to_index is None,
]
)
2 changes: 2 additions & 0 deletions integrations/qdrant/tests/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def test_to_dict(self):
"path": None,
"index": "test",
"embedding_dim": 768,
"on_disk": False,
"content_field": "content",
"name_field": "name",
"embedding_field": "embedding",
Expand All @@ -62,6 +63,7 @@ def test_to_dict(self):
"metadata": {},
"write_batch_size": 100,
"scroll_size": 10000,
"payload_fields_to_index": None,
},
},
"filters": None,
Expand Down

0 comments on commit 810ad84

Please sign in to comment.