Make Document Stores initially skip SparseEmbedding (#606)
anakin87 authored Mar 20, 2024
1 parent 2195623 commit b4ff369
Showing 8 changed files with 94 additions and 7 deletions.
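
All eight stores apply the same guard before handing a document to the backend: pop the `sparse_embedding` field out of the serialized document and emit one warning per affected document. A condensed, self-contained sketch of that shared pattern (the function name and flat dict layout are illustrative, not any one store's code):

import logging

logger = logging.getLogger(__name__)

def _skip_sparse_embedding(document_dict: dict) -> dict:
    """Drop the unsupported `sparse_embedding` field, warning if it was set."""
    if "sparse_embedding" in document_dict:
        sparse_embedding = document_dict.pop("sparse_embedding", None)
        if sparse_embedding:
            logger.warning(
                "Document %s has the `sparse_embedding` field set, "
                "but storing sparse embeddings is not currently supported. "
                "The `sparse_embedding` field will be ignored.",
                document_dict["id"],
            )
    return document_dict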

Astra document store
@@ -195,6 +195,15 @@ def _convert_input_document(document: Union[dict, Document]):
         document_dict["dataframe"] = document_dict.pop("dataframe").to_json()
     if embedding := document_dict.pop("embedding", []):
         document_dict["$vector"] = embedding
+    if "sparse_embedding" in document_dict:
+        sparse_embedding = document_dict.pop("sparse_embedding", None)
+        if sparse_embedding:
+            logger.warning(
+                "Document %s has the `sparse_embedding` field set, "
+                "but storing sparse embeddings in Astra is not currently supported. "
+                "The `sparse_embedding` field will be ignored.",
+                document_dict["_id"],
+            )
 
     return document_dict

Chroma document store
@@ -189,6 +189,14 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
             if doc.embedding is not None:
                 data["embeddings"] = [doc.embedding]
 
+            if hasattr(doc, "sparse_embedding") and doc.sparse_embedding is not None:
+                logger.warning(
+                    "Document %s has the `sparse_embedding` field set, "
+                    "but storing sparse embeddings in Chroma is not currently supported. "
+                    "The `sparse_embedding` field will be ignored.",
+                    doc.id,
+                )
+
             self._collection.add(**data)
 
         return len(documents)
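
Unlike the dict-based stores, Chroma (and Pinecone below) checks the `sparse_embedding` attribute on the `Document` object itself, since it assembles its payload field by field rather than serializing the whole document. From the caller's side the behavior is the same everywhere; a hedged usage sketch, assuming Haystack's `Document` and `SparseEmbedding` dataclasses and an already-constructed store instance named `store`:

from haystack.dataclasses import Document, SparseEmbedding

doc = Document(
    content="sparse retrieval test",
    sparse_embedding=SparseEmbedding(indices=[0, 7], values=[0.4, 0.6]),
)
store.write_documents([doc])  # warns, then writes the document without the field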

Elasticsearch document store
@@ -216,16 +216,30 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
             policy = DuplicatePolicy.FAIL
 
         action = "index" if policy == DuplicatePolicy.OVERWRITE else "create"
-        documents_written, errors = helpers.bulk(
-            client=self._client,
-            actions=(
+
+        elasticsearch_actions = []
+        for doc in documents:
+            doc_dict = doc.to_dict()
+            if "sparse_embedding" in doc_dict:
+                sparse_embedding = doc_dict.pop("sparse_embedding", None)
+                if sparse_embedding:
+                    logger.warning(
+                        "Document %s has the `sparse_embedding` field set, "
+                        "but storing sparse embeddings in Elasticsearch is not currently supported. "
+                        "The `sparse_embedding` field will be ignored.",
+                        doc.id,
+                    )
+            elasticsearch_actions.append(
                 {
                     "_op_type": action,
                     "_id": doc.id,
-                    "_source": doc.to_dict(),
+                    "_source": doc_dict,
                 }
-                for doc in documents
-            ),
+            )
+
+        documents_written, errors = helpers.bulk(
+            client=self._client,
+            actions=elasticsearch_actions,
             refresh="wait_for",
             index=self._index,
             raise_on_error=False,
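
This hunk is larger than the others because the skip logic is made of statements, which cannot live inside the generator expression previously passed to `helpers.bulk`; the change therefore materializes the actions into an explicit list first. Since `helpers.bulk` accepts any iterable of actions, indexing behavior is otherwise unchanged.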

MongoDB Atlas document store
@@ -170,7 +170,19 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
         if policy == DuplicatePolicy.NONE:
             policy = DuplicatePolicy.FAIL
 
-        mongo_documents = [doc.to_dict(flatten=False) for doc in documents]
+        mongo_documents = []
+        for doc in documents:
+            doc_dict = doc.to_dict(flatten=False)
+            if "sparse_embedding" in doc_dict:
+                sparse_embedding = doc_dict.pop("sparse_embedding", None)
+                if sparse_embedding:
+                    logger.warning(
+                        "Document %s has the `sparse_embedding` field set, "
+                        "but storing sparse embeddings in MongoDB Atlas is not currently supported. "
+                        "The `sparse_embedding` field will be ignored.",
+                        doc.id,
+                    )
+            mongo_documents.append(doc_dict)
         operations: List[Union[UpdateOne, InsertOne, ReplaceOne]]
         written_docs = len(documents)

Pgvector document store
@@ -415,6 +415,16 @@ def _from_haystack_to_pg_documents(documents: List[Document]) -> List[Dict[str, Any]]:
             db_document["dataframe"] = Jsonb(db_document["dataframe"]) if db_document["dataframe"] else None
             db_document["meta"] = Jsonb(db_document["meta"])
 
+            if "sparse_embedding" in db_document:
+                sparse_embedding = db_document.pop("sparse_embedding", None)
+                if sparse_embedding:
+                    logger.warning(
+                        "Document %s has the `sparse_embedding` field set, "
+                        "but storing sparse embeddings in Pgvector is not currently supported. "
+                        "The `sparse_embedding` field will be ignored.",
+                        db_document["id"],
+                    )
+
             db_documents.append(db_document)
 
         return db_documents

Pinecone document store
@@ -292,6 +292,13 @@ def _convert_documents_to_pinecone_format(self, documents: List[Document]) -> List[Dict[str, Any]]:
                     "objects in Pinecone is not supported. "
                     "The content of the `blob` field will be ignored."
                 )
+            if hasattr(document, "sparse_embedding") and document.sparse_embedding is not None:
+                logger.warning(
+                    "Document %s has the `sparse_embedding` field set, "
+                    "but storing sparse embeddings in Pinecone is not currently supported. "
+                    "The `sparse_embedding` field will be ignored.",
+                    document.id,
+                )
 
             documents_for_pinecone.append(doc_for_pinecone)
         return documents_for_pinecone

Qdrant converter (HaystackToQdrant)
@@ -1,9 +1,12 @@
+import logging
 import uuid
 from typing import List, Union
 
 from haystack.dataclasses import Document
 from qdrant_client.http import models as rest
 
+logger = logging.getLogger(__name__)
+
 
 class HaystackToQdrant:
     """A converter from Haystack to Qdrant types."""
@@ -22,6 +25,17 @@ def documents_to_batch(
             vector = payload.pop(embedding_field) or {}
             _id = self.convert_id(payload.get("id"))
 
+            # TODO: remove as soon as we introduce the support for sparse embeddings in Qdrant
+            if "sparse_embedding" in payload:
+                sparse_embedding = payload.pop("sparse_embedding", None)
+                if sparse_embedding:
+                    logger.warning(
+                        "Document %s has the `sparse_embedding` field set, "
+                        "but storing sparse embeddings in Qdrant is not currently supported. "
+                        "The `sparse_embedding` field will be ignored.",
+                        payload["id"],
+                    )
+
             point = rest.PointStruct(
                 payload=payload,
                 vector=vector,

Weaviate document store
@@ -4,6 +4,7 @@
 import base64
 import datetime
 import json
+import logging
 from dataclasses import asdict
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -21,6 +22,8 @@
 from ._filters import convert_filters
 from .auth import AuthCredentials
 
+logger = logging.getLogger(__name__)
+
 Number = Union[int, float]
 TimeoutType = Union[Tuple[Number, Number], Number]
@@ -224,6 +227,16 @@ def _to_data_object(self, document: Document) -> Dict[str, Any]:
         # The embedding vector is stored separately from the rest of the data
         del data["embedding"]
 
+        if "sparse_embedding" in data:
+            sparse_embedding = data.pop("sparse_embedding", None)
+            if sparse_embedding:
+                logger.warning(
+                    "Document %s has the `sparse_embedding` field set, "
+                    "but storing sparse embeddings in Weaviate is not currently supported. "
+                    "The `sparse_embedding` field will be ignored.",
+                    data["_original_id"],
+                )
+
         return data
 
     def _to_document(self, data: DataObject[Dict[str, Any], None]) -> Document:
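
To see the new behavior end to end, a hedged test sketch (pytest-style; `store` stands in for any of the eight stores above, and the `SparseEmbedding` import path is assumed):

import logging

from haystack.dataclasses import Document, SparseEmbedding

def test_write_documents_skips_sparse_embedding(caplog, store):
    doc = Document(
        content="hello",
        sparse_embedding=SparseEmbedding(indices=[0], values=[1.0]),
    )
    with caplog.at_level(logging.WARNING):
        store.write_documents([doc])
    # The document is written, but the sparse embedding is dropped with a warning.
    assert "sparse embeddings" in caplog.text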
