Skip to content

Commit

Permalink
Merge pull request deepset-ai#34 from Anant/integrate_2.0_changes
Browse files Browse the repository at this point in the history
use new Document dataclass
  • Loading branch information
ElenaKusevska authored Dec 11, 2023
2 parents dcb1d9d + 42aa117 commit 3a64fe7
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 72 deletions.
15 changes: 15 additions & 0 deletions examples/example-2.0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from pathlib import Path

from haystack.pipeline_utils import build_indexing_pipeline

from astra_store.document_store import AstraDocumentStore

# Connect to a DataStax Astra DB collection as the document store.
# NOTE(review): the original comment claimed this is "a simple and lightweight
# in-memory document store" — that text was copied from a haystack example;
# AstraDocumentStore talks to a remote Astra DB instance, not local memory.
document_store = AstraDocumentStore()

# Build an indexing pipeline that converts, embeds, and writes PDFs and text
# files into the store, using the given sentence-transformers embedding model.
indexing_pipeline = build_indexing_pipeline(
    document_store=document_store, embedding_model="sentence-transformers/all-mpnet-base-v2"
)
# Index every file found in the test_files directory; `result` holds the
# pipeline's run output (e.g. how many documents were written).
result = indexing_pipeline.run(files=list(Path("../../test/test_files").iterdir()))
print(result)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pydantic==1.10.13
canals==0.9.0
haystack-ai
requests~=2.31.0
pytest~=7.4.3
pytest-cov
Expand Down
96 changes: 44 additions & 52 deletions src/astra_store/astra_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,33 +65,29 @@ def __init__(
def find_index(self):
find_query = {"findCollections": {"options": {"explain": True}}}
response = requests.request("POST", self.create_url, headers=self.request_header, data=json.dumps(find_query))
response.raise_for_status()
response_dict = json.loads(response.text)

if response.status_code == 200:
if "status" in response_dict:
collection_name_matches = list(
filter(lambda d: d['name'] == self.collection_name, response_dict["status"]["collections"])
)

if len(collection_name_matches) == 0:
logger.warning(
f"Astra collection {self.collection_name} not found under {self.keyspace_name}. Will be created."
)
return False
if "status" in response_dict:
collection_name_matches = list(
filter(lambda d: d['name'] == self.collection_name, response_dict["status"]["collections"])
)

collection_embedding_dim = collection_name_matches[0]["options"]["vector"]["dimension"]
if collection_embedding_dim != self.embedding_dim:
raise Exception(
f"Collection vector dimension is not valid, expected {self.embedding_dim}, "
f"found {collection_embedding_dim}"
)
if len(collection_name_matches) == 0:
logger.warning(
f"Astra collection {self.collection_name} not found under {self.keyspace_name}. Will be created."
)
return False

else:
raise Exception(f"status not in response: {response.text}")
collection_embedding_dim = collection_name_matches[0]["options"]["vector"]["dimension"]
if collection_embedding_dim != self.embedding_dim:
raise Exception(
f"Collection vector dimension is not valid, expected {self.embedding_dim}, "
f"found {collection_embedding_dim}"
)

else:
raise Exception(f"Astra DB not available. Status code: {response.status_code}, {response.text}")
# Retry or handle error better
raise Exception(f"status not in response: {response.text}")

return True

Expand Down Expand Up @@ -162,10 +158,10 @@ def _format_query_response(responses, include_metadata, include_values):
_id = response.pop("_id")
score = response.pop("$similarity") if "$similarity" in response else None
_values = response.pop("$vector") if "$vector" in response else None
text = response.pop("text") if "text" in response else None
text = response.pop("content") if "content" in response else None
values = _values if include_values else []
# TODO double check
metadata = response.pop("metadata") if "metadata" in response and include_metadata else dict()
metadata = response if include_metadata else dict()
rsp = Response(_id, text, values, metadata, score)
final_res.append(rsp)
return QueryResponse(final_res)
Expand All @@ -185,14 +181,12 @@ def find_documents(self, find_query):
headers=self.request_header,
data=query,
)
response.raise_for_status()
response_dict = json.loads(response.text)
if response.status_code == 200:
if "data" in response_dict and "documents" in response_dict["data"]:
return response_dict["data"]["documents"]
else:
logger.warning("No documents found", response_dict)
if "data" in response_dict and "documents" in response_dict["data"]:
return response_dict["data"]["documents"]
else:
raise Exception(f"Astra DB request error - status code: {response.status_code} response {response.text}")
logger.warning("No documents found", response_dict)

def get_documents(self, ids: List[str], batch_size: int = 20) -> QueryResponse:
document_batch = []
Expand All @@ -216,19 +210,17 @@ def insert(self, documents: List[Dict]):
headers=self.request_header,
data=query,
)
response.raise_for_status()
response_dict = json.loads(response.text)

if response.status_code == 200:
inserted_ids = (
response_dict["status"]["insertedIds"]
if "status" in response_dict and "insertedIds" in response_dict["status"]
else []
)
if "errors" in response_dict:
logger.error(response_dict["errors"])
return inserted_ids
else:
raise Exception(f"Astra DB request error - status code: {response.status_code} response {response.text}")
inserted_ids = (
response_dict["status"]["insertedIds"]
if "status" in response_dict and "insertedIds" in response_dict["status"]
else []
)
if "errors" in response_dict:
logger.error(response_dict["errors"])
return inserted_ids

def update_document(self, document: Dict, id_key: str):
document_id = document.pop(id_key)
Expand All @@ -247,18 +239,16 @@ def update_document(self, document: Dict, id_key: str):
headers=self.request_header,
data=query,
)
response.raise_for_status()
response_dict = json.loads(response.text)
document[id_key] = document_id

if response.status_code == 200:
if "status" in response_dict and "errors" not in response_dict:
if "matchedCount" in response_dict["status"] and "modifiedCount" in response_dict["status"]:
if response_dict["status"]["matchedCount"] == 1 and response_dict["status"]["modifiedCount"] == 1:
return True
logger.warning(f"Documents {document_id} not updated in Astra {response.text}")
return False
else:
raise Exception(f"Astra DB request error - status code: {response.status_code} response {response.text}")
if "status" in response_dict and "errors" not in response_dict:
if "matchedCount" in response_dict["status"] and "modifiedCount" in response_dict["status"]:
if response_dict["status"]["matchedCount"] == 1 and response_dict["status"]["modifiedCount"] == 1:
return True
logger.warning(f"Documents {document_id} not updated in Astra {response.text}")
return False

def delete(
self,
Expand All @@ -278,16 +268,18 @@ def delete(
headers=self.request_header,
data=json.dumps(query),
)
response.raise_for_status()
return response

def count_documents(self) -> int:
"""
Returns how many documents are present in the document store.
"""
count = requests.request(
response = requests.request(
"POST",
self.request_url,
headers=self.request_header,
data=json.dumps({"countDocuments": {}}),
).json()["status"]["count"]
return count
)
response.raise_for_status()
return response.json()["status"]["count"]
26 changes: 12 additions & 14 deletions src/astra_store/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from typing import Any, Dict, List, Optional, Union

import pandas as pd
from haystack.dataclasses.document import Document
from haystack.document_stores.errors import (
from haystack.dataclasses import Document
from haystack.document_stores import (
DuplicateDocumentError,
DuplicatePolicy,
MissingDocumentError,
)
from haystack.document_stores.protocol import DuplicatePolicy
from pydantic import validate_arguments
from sentence_transformers import SentenceTransformer

Expand Down Expand Up @@ -139,7 +139,7 @@ def _convert_input_document(document: Union[dict, Document]):
data = document
else:
raise ValueError(f"Unsupported type for documents, documents is of type {type(document)}.")
meta = data.pop("metadata")
meta = data.pop("meta")
document_dict = {**data, **meta}
if "id" in document_dict:
if "_id" not in document_dict:
Expand All @@ -150,14 +150,14 @@ def _convert_input_document(document: Union[dict, Document]):
)
if "dataframe" in document_dict and document_dict["dataframe"] is not None:
document_dict["dataframe"] = document_dict.pop("dataframe").to_json()
if "text" in document_dict and document_dict["text"] is not None:
if "content" in document_dict and document_dict["content"] is not None:
if "embedding" in document_dict.keys():
if document_dict["embedding"] == None:
if document_dict["embedding"] is None:
document_dict.pop("embedding")
else:
document_dict["$vector"] = document_dict.pop("embedding")
if embed == True:
document_dict["$vector"] = self.embeddings.encode(document_dict["text"]).tolist()
if embed:
document_dict["$vector"] = self.embeddings.encode(document_dict["content"]).tolist()
else:
document_dict["$vector"] = None

Expand Down Expand Up @@ -272,10 +272,10 @@ def _get_result_to_documents(results) -> List[Document]:
documents = []
for match in results.matches:
document = Document(
text=match.text,
content=match.text,
id=match.id,
embedding=match.values,
metadata=match.metadata,
meta=match.metadata,
score=match.score,
)
documents.append(document)
Expand Down Expand Up @@ -376,13 +376,11 @@ def delete_documents(self, document_ids: List[str] = None, delete_all: Optional[
Deletes all documents with a matching document_ids from the document store.
Fails with `MissingDocumentError` if no document with this id is present in the store.
:param document_ids: the document_ids to delete
:param delete_all: delete all documents
:param document_ids: the document_ids to delete.
:param delete_all: delete all documents.
"""
response = self.index.delete(ids=document_ids, delete_all=delete_all)
response_dict = json.loads(response.text)

if response.status_code != 200:
raise Exception("Error querying Astra DB")
if response_dict["status"]["deletedCount"] == 0 and document_ids is not None:
raise MissingDocumentError(f"Document {document_ids} does not exist")
76 changes: 70 additions & 6 deletions tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from typing import List

import pytest
from haystack.dataclasses.document import Document
from haystack.document_stores.protocol import DuplicatePolicy
from haystack import Document
from haystack.document_stores import DuplicatePolicy, MissingDocumentError
from haystack.testing.document_store import DocumentStoreBaseTests

from astra_store.document_store import AstraDocumentStore


Expand All @@ -18,7 +19,7 @@ class TestDocumentStore(DocumentStoreBaseTests):
"""

@pytest.fixture
def docstore(self) -> AstraDocumentStore:
def document_store(self) -> AstraDocumentStore:
"""
This is the most basic requirement for the child class: provide
an instance of this document store so the base class can use it.
Expand Down Expand Up @@ -47,10 +48,73 @@ def docstore(self) -> AstraDocumentStore:
return astra_store

@pytest.fixture(autouse=True)
def run_before_and_after_tests(self, docstore: AstraDocumentStore):
def run_before_and_after_tests(self, document_store: AstraDocumentStore):
"""
Cleaning up document store
"""
docstore.delete_documents(delete_all=True)
assert docstore.count_documents() == 0
document_store.delete_documents(delete_all=True)
assert document_store.count_documents() == 0

def assert_documents_are_equal(self, received: List[Document], expected: List[Document]):
"""
Assert that two lists of Documents are equal.
This is used in every test, if a Document Store implementation has a different behaviour
it should override this method.
This can happen for example when the Document Store sets a score to returned Documents.
Since we can't know what the score will be, we can't compare the Documents reliably.
"""
import operator

received.sort(key=operator.attrgetter('content'))
assert received == expected

def test_delete_documents_non_existing_document(self, document_store: AstraDocumentStore):
"""
Test delete_documents() doesn't delete any Document when called with non existing id.
"""
doc = Document(content="test doc")
document_store.write_documents([doc])
assert document_store.count_documents() == 1

with pytest.raises(MissingDocumentError):
document_store.delete_documents(["non_existing_id"])

# No Document has been deleted
assert document_store.count_documents() == 1

# @pytest.mark.skip(reason="Unsupported filter operator not in.")
# def test_comparison_not_in(self, document_store, filterable_docs):
# pass
#
# @pytest.mark.skip(reason="Unsupported filter operator not in.")
# def test_comparison_not_in_with_with_non_list(self, document_store, filterable_docs):
# pass
#
# @pytest.mark.skip(reason="Unsupported filter operator not in.")
# def test_comparison_not_in_with_with_non_list_iterable(self, document_store, filterable_docs):
# pass

# @pytest.mark.skip(reason="Unsupported filter operator $gt.")
# def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs):
# pass

# @pytest.mark.skip(reason="Unsupported filter operator $gt.")
# def test_comparison_greater_than_with_string(self, document_store, filterable_docs):
# pass

# @pytest.mark.skip(reason="Unsupported filter operator $gt.")
# def test_comparison_greater_than_with_dataframe(self, document_store, filterable_docs):
# pass

# @pytest.mark.skip(reason="Unsupported filter operator $gt.")
# def test_comparison_greater_than_with_list(self, document_store, filterable_docs):
# pass

# @pytest.mark.skip(reason="Unsupported filter operator $gt.")
# def test_comparison_greater_than_with_none(self, document_store, filterable_docs):
# pass

# @pytest.mark.skip(reason="Unsupported filter operator $gte.")
# def test_comparison_greater_than_equal(self, document_store, filterable_docs):
# pass

0 comments on commit 3a64fe7

Please sign in to comment.