Skip to content

Commit

Permalink
refactor: Qdrant Query API (#1025)
Browse files: browse the repository at this point in the history
  • Loading branch information
Anush008 authored Aug 27, 2024
1 parent 8544b56 commit 8edb9e8
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 63 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,7 +15,6 @@
from qdrant_client import grpc
from qdrant_client.http import models as rest
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.hybrid.fusion import reciprocal_rank_fusion
from tqdm import tqdm

from .converters import (
Expand Down Expand Up @@ -537,20 +536,18 @@ def _query_by_sparse(
qdrant_filters = convert_filters_to_qdrant(filters)
query_indices = query_sparse_embedding.indices
query_values = query_sparse_embedding.values
points = self.client.search(
points = self.client.query_points(
collection_name=self.index,
query_vector=rest.NamedSparseVector(
name=SPARSE_VECTORS_NAME,
vector=rest.SparseVector(
indices=query_indices,
values=query_values,
),
query=rest.SparseVector(
indices=query_indices,
values=query_values,
),
using=SPARSE_VECTORS_NAME,
query_filter=qdrant_filters,
limit=top_k,
with_vectors=return_embedding,
score_threshold=score_threshold,
)
).points
results = [
convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=self.use_sparse_embeddings)
for point in points
Expand Down Expand Up @@ -588,17 +585,15 @@ def _query_by_embedding(
"""
qdrant_filters = convert_filters_to_qdrant(filters)

points = self.client.search(
points = self.client.query_points(
collection_name=self.index,
query_vector=rest.NamedVector(
name=DENSE_VECTORS_NAME if self.use_sparse_embeddings else "",
vector=query_embedding,
),
query=query_embedding,
using=DENSE_VECTORS_NAME if self.use_sparse_embeddings else None,
query_filter=qdrant_filters,
limit=top_k,
with_vectors=return_embedding,
score_threshold=score_threshold,
)
).points
results = [
convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=self.use_sparse_embeddings)
for point in points
Expand Down Expand Up @@ -655,46 +650,34 @@ def _query_hybrid(

qdrant_filters = convert_filters_to_qdrant(filters)

sparse_request = rest.SearchRequest(
vector=rest.NamedSparseVector(
name=SPARSE_VECTORS_NAME,
vector=rest.SparseVector(
indices=query_sparse_embedding.indices,
values=query_sparse_embedding.values,
),
),
filter=qdrant_filters,
limit=top_k,
with_payload=True,
with_vector=return_embedding,
score_threshold=score_threshold,
)

dense_request = rest.SearchRequest(
vector=rest.NamedVector(
name=DENSE_VECTORS_NAME,
vector=query_embedding,
),
filter=qdrant_filters,
limit=top_k,
with_payload=True,
with_vector=return_embedding,
)

try:
dense_request_response, sparse_request_response = self.client.search_batch(
collection_name=self.index, requests=[dense_request, sparse_request]
)
points = self.client.query_points(
collection_name=self.index,
prefetch=[
rest.Prefetch(
query=rest.SparseVector(
indices=query_sparse_embedding.indices,
values=query_sparse_embedding.values,
),
using=SPARSE_VECTORS_NAME,
filter=qdrant_filters,
),
rest.Prefetch(
query=query_embedding,
using=DENSE_VECTORS_NAME,
filter=qdrant_filters,
),
],
query=rest.FusionQuery(fusion=rest.Fusion.RRF),
limit=top_k,
score_threshold=score_threshold,
with_payload=True,
with_vectors=return_embedding,
).points
except Exception as e:
msg = "Error during hybrid search"
raise QdrantStoreError(msg) from e

try:
points = reciprocal_rank_fusion(responses=[dense_request_response, sparse_request_response], limit=top_k)
except Exception as e:
msg = "Error while applying Reciprocal Rank Fusion"
raise QdrantStoreError(msg) from e

results = [convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=True) for point in points]

return results
Expand Down
14 changes: 1 addition & 13 deletions integrations/qdrant/tests/test_document_store.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -114,19 +114,7 @@ def test_query_hybrid_search_batch_failure(self):
sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33])
embedding = [0.1] * 768

with patch.object(document_store.client, "search_batch", side_effect=Exception("search_batch error")):
with patch.object(document_store.client, "query_points", side_effect=Exception("query_points")):

with pytest.raises(QdrantStoreError):
document_store._query_hybrid(query_sparse_embedding=sparse_embedding, query_embedding=embedding)

@patch("haystack_integrations.document_stores.qdrant.document_store.reciprocal_rank_fusion")
def test_query_hybrid_reciprocal_rank_fusion_failure(self, mocked_fusion):
document_store = QdrantDocumentStore(location=":memory:", use_sparse_embeddings=True)

sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33])
embedding = [0.1] * 768

mocked_fusion.side_effect = Exception("reciprocal_rank_fusion error")

with pytest.raises(QdrantStoreError):
document_store._query_hybrid(query_sparse_embedding=sparse_embedding, query_embedding=embedding)

0 comments on commit 8edb9e8

Please sign in to comment.