From 3eb73d7e76e9519e40e09a66f0ab0099d28717a0 Mon Sep 17 00:00:00 2001 From: Stefano Lottini Date: Fri, 27 Sep 2024 10:23:02 +0200 Subject: [PATCH] Thorough rewrite and optimization of integration tests (#82) * complete removal of SomeEmbeddings for tests * test_vectorstore_autodetect brought to rationality * test_graphvectorstore is nice now * halfway through test_vectorstore.py * tests of from_ methods of vectorstore are now good * most test_vectorstore done; missing only indexing and coreclients_init * all of graph/vectorstores brought to order * graph/vstore tests mostly hcd-compatible (wip) * wip on fixing the hcd/apikey header thing * completed rewrite of tests for graph/vectorstores * chat message histories tested nicely * deep restructuring of the caches testing * test_document_loaders under control * further improvement test document loader * with test_storage it seems everything is done now. * tiny docstr edit * make openai key into a fixture to heal compile test * clean info on IT prereqs --- libs/astradb/codespell_ignore_words.txt | 1 - .../langchain_astradb/graph_vectorstores.py | 2 +- .../langchain_astradb/utils/mmr_traversal.py | 2 +- .../astradb/langchain_astradb/vectorstores.py | 11 +- libs/astradb/testing.env.sample | 37 +- libs/astradb/tests/conftest.py | 57 +- .../tests/integration_tests/.env.example | 5 - .../tests/integration_tests/conftest.py | 501 ++++- .../tests/integration_tests/test_cache.py | 229 ++ .../tests/integration_tests/test_caches.py | 393 ---- .../test_chat_message_histories.py | 141 +- .../test_document_loaders.py | 202 +- .../test_graphvectorstore.py | 275 +-- .../integration_tests/test_semantic_cache.py | 265 +++ .../tests/integration_tests/test_storage.py | 643 +++--- .../integration_tests/test_vectorstore.py | 1302 +++++++++++ .../test_vectorstore_autodetect.py | 245 +-- .../test_vectorstore_ddl_tests.py | 538 +++++ .../integration_tests/test_vectorstores.py | 1951 ----------------- .../tests/unit_tests/test_vectorstores.py | 4 +- 20 files changed, 3554 insertions(+), 3250 deletions(-) delete mode 100644 libs/astradb/tests/integration_tests/.env.example create mode 100644 libs/astradb/tests/integration_tests/test_cache.py delete mode 100644 libs/astradb/tests/integration_tests/test_caches.py create mode 100644 libs/astradb/tests/integration_tests/test_semantic_cache.py create mode 100644 libs/astradb/tests/integration_tests/test_vectorstore.py create mode 100644 libs/astradb/tests/integration_tests/test_vectorstore_ddl_tests.py delete mode 100644 libs/astradb/tests/integration_tests/test_vectorstores.py diff --git a/libs/astradb/codespell_ignore_words.txt b/libs/astradb/codespell_ignore_words.txt index 0b3a7cd..e69de29 100644 --- a/libs/astradb/codespell_ignore_words.txt +++ b/libs/astradb/codespell_ignore_words.txt @@ -1 +0,0 @@ -Haa diff --git a/libs/astradb/langchain_astradb/graph_vectorstores.py b/libs/astradb/langchain_astradb/graph_vectorstores.py index d796915..8a9e9bc 100644 --- a/libs/astradb/langchain_astradb/graph_vectorstores.py +++ b/libs/astradb/langchain_astradb/graph_vectorstores.py @@ -323,7 +323,7 @@ def mmr_traversal_search( # noqa: C901 def get_adjacent(tags: set[str]) -> Iterable[_Edge]: targets: dict[str, _Edge] = {} - # TODO: Would be better parralelized + # TODO: Would be better parallelized for tag in tags: m_filter = (metadata_filter or {}).copy() m_filter[self.link_from_metadata_key] = tag diff --git a/libs/astradb/langchain_astradb/utils/mmr_traversal.py b/libs/astradb/langchain_astradb/utils/mmr_traversal.py index 
ad77ccf..04b3dea 100644 --- a/libs/astradb/langchain_astradb/utils/mmr_traversal.py +++ b/libs/astradb/langchain_astradb/utils/mmr_traversal.py @@ -117,7 +117,7 @@ def __init__( # List of the candidates. self.candidates = [] - # ND array of the candidate embeddings. + # numpy n-dimensional array of the candidate embeddings. self.candidate_embeddings = np.ndarray((0, self.dimensions), dtype=np.float32) self.best_score = NEG_INF diff --git a/libs/astradb/langchain_astradb/vectorstores.py b/libs/astradb/langchain_astradb/vectorstores.py index 2ef136d..c807742 100644 --- a/libs/astradb/langchain_astradb/vectorstores.py +++ b/libs/astradb/langchain_astradb/vectorstores.py @@ -460,8 +460,6 @@ def __init__( This is useful when the service is configured for the collection, but no corresponding secret is stored within Astra's key management system. - This parameter cannot be provided without - specifying ``collection_vector_service_options``. content_field: name of the field containing the textual content in the documents when saved on Astra DB. For vectorize collections, this cannot be specified; for non-vectorize collection, defaults @@ -473,7 +471,7 @@ def __init__( Please understand the limitations of this method and get some understanding of your data before passing ``"*"`` for this parameter. ignore_invalid_documents: if False (default), exceptions are raised - when a document is found on the Astra DB collectin that does + when a document is found on the Astra DB collection that does not have the expected shape. If set to True, such results from the database are ignored and a warning is issued. Note that in this case a similarity search may end up returning fewer @@ -824,11 +822,10 @@ async def adelete( raise ValueError(msg) _max_workers = concurrency or self.bulk_delete_concurrency - return all( - await gather_with_concurrency( - _max_workers, *[self.adelete_by_document_id(doc_id) for doc_id in ids] - ) + await gather_with_concurrency( + _max_workers, *[self.adelete_by_document_id(doc_id) for doc_id in ids] ) + return True def delete_collection(self) -> None: """Completely delete the collection from the database. diff --git a/libs/astradb/testing.env.sample b/libs/astradb/testing.env.sample index cb502da..7ae2960 100644 --- a/libs/astradb/testing.env.sample +++ b/libs/astradb/testing.env.sample @@ -1,13 +1,24 @@ -export ASTRA_DB_APPLICATION_TOKEN="AstraCS:aaabbbccc..." -export ASTRA_DB_API_ENDPOINT="https://0123...-region.apps.astra.datastax.com" -export ASTRA_DB_KEYSPACE="default_keyspace" -# Optional (mostly for HCD and such): -# export ASTRA_DB_ENVIRONMENT="..." - -# required to test vectorize with SHARED_SECRET. Comment on HCD and such. -export SHARED_SECRET_NAME_OPENAI="NAME_SUPPLIED_IN_ASTRA_KMS" -# required to test vectorize with HEADER -export OPENAI_API_KEY="sk-aaabbbccc..." - -# change to "1" if nvidia server-side embeddings are available for the DB -export NVIDIA_VECTORIZE_AVAILABLE="0" +# ASTRA DB SETUP + +ASTRA_DB_API_ENDPOINT=https://your_astra_db_id-your_region.apps.astra.datastax.com +ASTRA_DB_APPLICATION_TOKEN=AstraCS:your_astra_db_application_token +# ASTRA_DB_KEYSPACE=your_astra_db_keyspace +# ASTRA_DB_ENVIRONMENT="prod" + +SHARED_SECRET_NAME_OPENAI="key_name_on_astra_kms" +OPENAI_API_KEY="..." 
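(Note on loading the variables above: the integration-test conftest further down in this patch reads a dotenv file before checking for the required names, via its _load_env() helper. A minimal equivalent sketch, assuming the python-dotenv package and a testing.env file placed next to this sample, both details being assumptions here, not part of the patch:)

    import os
    from pathlib import Path

    from dotenv import load_dotenv  # python-dotenv

    # Load testing.env if present, then verify the two mandatory variables.
    dotenv_path = Path(__file__).parent / "testing.env"  # assumed location
    if dotenv_path.exists():
        load_dotenv(dotenv_path)
    assert "ASTRA_DB_APPLICATION_TOKEN" in os.environ
    assert "ASTRA_DB_API_ENDPOINT" in os.environ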
+ + +### For testing on HCD, SHARED_SECRET vectorize is not exercised; the settings will look something like: # # # # ASTRA_DB_APPLICATION_TOKEN="Cassandra:Y2Fzc2FuZHJh:Y2Fzc2FuZHJh" # ASTRA_DB_API_ENDPOINT="http://localhost:8181" # ASTRA_DB_KEYSPACE="keyspace_created_by_the_ci_for_testing" # ASTRA_DB_ENVIRONMENT="hcd" # # OPENAI_API_KEY="..." # # # diff --git a/libs/astradb/tests/conftest.py b/libs/astradb/tests/conftest.py index aa01bdd..767023b 100644 --- a/libs/astradb/tests/conftest.py +++ b/libs/astradb/tests/conftest.py @@ -5,34 +5,14 @@ from __future__ import annotations import json +from typing import TYPE_CHECKING, Any from langchain_core.embeddings import Embeddings +from langchain_core.language_models import LLM +from typing_extensions import override - -class SomeEmbeddings(Embeddings): - """Turn a sentence into an embedding vector in some way. - Not important how. It is deterministic is all that counts. - """ - - def __init__(self, dimension: int) -> None: - self.dimension = dimension - - def embed_documents(self, texts: list[str]) -> list[list[float]]: - return [self.embed_query(txt) for txt in texts] - - async def aembed_documents(self, texts: list[str]) -> list[list[float]]: - return self.embed_documents(texts) - - def embed_query(self, text: str) -> list[float]: - unnormed0 = [ord(c) for c in text[: self.dimension]] - unnormed = (unnormed0 + [1] + [0] * (self.dimension - 1 - len(unnormed0)))[ - : self.dimension - ] - norm = sum(x * x for x in unnormed) ** 0.5 - return [x / norm for x in unnormed] - - async def aembed_query(self, text: str) -> list[float]: - return self.embed_query(text) +if TYPE_CHECKING: + from langchain_core.callbacks import CallbackManagerForLLMRun class ParserEmbeddings(Embeddings): @@ -61,3 +41,30 @@ def embed_query(self, text: str) -> list[float]: async def aembed_query(self, text: str) -> list[float]: return self.embed_query(text) + + +class IdentityLLM(LLM): + num_calls: int = 0 + + @property + @override + def _llm_type(self) -> str: + return "fake" + + @override + def _call( + self, + prompt: str, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> str: + self.num_calls += 1 + if stop is not None: + return f"STOP<{prompt.upper()}>" + return prompt + + @property + @override + def _identifying_params(self) -> dict[str, Any]: + return {} diff --git a/libs/astradb/tests/integration_tests/.env.example b/libs/astradb/tests/integration_tests/.env.example deleted file mode 100644 index 4259d87..0000000 --- a/libs/astradb/tests/integration_tests/.env.example +++ /dev/null @@ -1,5 +0,0 @@ -# astra db -ASTRA_DB_API_ENDPOINT=https://your_astra_db_id-your_region.apps.astra.datastax.com -ASTRA_DB_APPLICATION_TOKEN=AstraCS:your_astra_db_application_token -# ASTRA_DB_KEYSPACE=your_astra_db_namespace -# ASTRA_DB_SKIP_COLLECTION_DELETIONS=true diff --git a/libs/astradb/tests/integration_tests/conftest.py b/libs/astradb/tests/integration_tests/conftest.py index 9cea801..7abed8f 100644 --- a/libs/astradb/tests/integration_tests/conftest.py +++ b/libs/astradb/tests/integration_tests/conftest.py @@ -1,20 +1,88 @@ +""" +Integration tests on Astra DB. + +Required to run these tests: + - a recent `astrapy` Python package available + - an Astra DB instance + - the two environment variables set: + export ASTRA_DB_API_ENDPOINT="https://<db-id>-us-east1.apps.astra.datastax.com" + export ASTRA_DB_APPLICATION_TOKEN="AstraCS:........." 
+ - optionally this as well (otherwise defaults are used): + export ASTRA_DB_KEYSPACE="my_keyspace" + - optionally (if not on prod): + export ASTRA_DB_ENVIRONMENT="dev" # or similar + - an OpenAI key name on KMS for SHARED_SECRET vectorize mode, associated with the DB: + export SHARED_SECRET_NAME_OPENAI="the_api_key_name_in_Astra_KMS" + - an OpenAI key for the vectorize test (in HEADER mode): + export OPENAI_API_KEY="..." + +Please refer to testing.env.sample. """ from __future__ import annotations import os from pathlib import Path -from typing import TypedDict +from typing import TYPE_CHECKING, Iterable, TypedDict import pytest -from astrapy import Database +from astrapy import DataAPIClient +from astrapy.authentication import StaticTokenProvider from astrapy.db import AstraDB from astrapy.info import CollectionVectorServiceOptions +from langchain_astradb.utils.astradb import SetupMode +from langchain_astradb.vectorstores import DEFAULT_INDEXING_OPTIONS, AstraDBVectorStore +from tests.conftest import IdentityLLM, ParserEmbeddings + +if TYPE_CHECKING: + from astrapy import Collection, Database + from langchain_core.embeddings import Embeddings + from langchain_core.language_models import LLM + + # Getting the absolute path of the current file's directory ABS_PATH = (Path(__file__)).parent # Getting the absolute path of the project's root directory PROJECT_DIR = Path(ABS_PATH).parent.parent +# Long-lasting collection names (default-indexed for vectorstores) +COLLECTION_NAME_D2 = "lc_test_d2_euclidean" +COLLECTION_NAME_VZ = "lc_test_vz_euclidean" +# (All-indexed) for general-purpose and autodetect +COLLECTION_NAME_IDXALL_D2 = "lc_test_d2_idxall_euclidean" +COLLECTION_NAME_IDXALL_VZ = "lc_test_vz_idxall_euclidean" +# non-vector all-indexed collection +COLLECTION_NAME_IDXALL = "lc_test_idxall" +# non-vector store-like-indexed collection +COLLECTION_NAME_IDXID = "lc_test_idxid" +# Function-lived collection names: +# (all-indexed) for autodetect: +EPHEMERAL_COLLECTION_NAME_IDXALL_D2 = "lc_test_d2_idxall_euclidean" +# of generic use for vectorstores +EPHEMERAL_COLLECTION_NAME_D2 = "lc_test_d2_cosine_short" +EPHEMERAL_COLLECTION_NAME_VZ = "lc_test_vz_cosine_short" +# for KMS (aka shared_secret) vectorize setup (vectorstores) +EPHEMERAL_COLLECTION_NAME_VZ_KMS = "lc_test_vz_kms_short" +# indexing-related collection names (function-lived) (vectorstores) +EPHEMERAL_CUSTOM_IDX_NAME_D2 = "lc_test_custom_idx_d2_short" +EPHEMERAL_DEFAULT_IDX_NAME_D2 = "lc_test_default_idx_d2_short" +EPHEMERAL_LEGACY_IDX_NAME_D2 = "lc_test_legacy_idx_d2_short" +# indexing-related collection names (function-lived) (storage) +EPHEMERAL_CUSTOM_IDX_NAME = "lc_test_custom_idx_short" +EPHEMERAL_LEGACY_IDX_NAME = "lc_test_legacy_idx_short" + +# autodetect assets +CUSTOM_CONTENT_KEY = "xcontent" +LONG_TEXT = "This is the textual content field in the doc." 
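(The ParserEmbeddings test double, imported from tests/conftest.py in the hunk above but with its body outside the lines shown, gives deterministic vectors so the tests can assert exact similarity scores. A minimal sketch of such a class, on the assumption that each test text spells out its own embedding as a JSON list, e.g. "[1.0, 2.0]" for the d=2 fixtures; the real body may differ in details:)

    import json

    from langchain_core.embeddings import Embeddings

    class ParserEmbeddings(Embeddings):
        """Parse each text as a JSON list of `dimension` floats."""

        def __init__(self, dimension: int) -> None:
            self.dimension = dimension

        def embed_documents(self, texts: list[str]) -> list[list[float]]:
            return [self.embed_query(txt) for txt in texts]

        def embed_query(self, text: str) -> list[float]:
            try:
                vals = json.loads(text)
            except json.JSONDecodeError:
                # deterministic fallback for non-JSON inputs
                return [0.0] * self.dimension
            if not isinstance(vals, list) or len(vals) != self.dimension:
                raise ValueError(f"Cannot parse embedding from: {text!r}")
            return [float(v) for v in vals]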
+# vectorstore-related utilities/constants +INCOMPATIBLE_INDEXING_MSG = "is detected as having the following indexing policy" +LEGACY_INDEXING_MSG = "is detected as having indexing turned on for all fields" +# similarity threshold definitions +EUCLIDEAN_MIN_SIM_UNIT_VECTORS = 0.2 +MATCH_EPSILON = 0.0001 + # Loading the .env file if it exists def _load_env() -> None: @@ -25,7 +93,7 @@ def _load_env() -> None: load_dotenv(dotenv_path) -def _has_env_vars() -> bool: +def astra_db_env_vars_available() -> bool: return all( [ "ASTRA_DB_APPLICATION_TOKEN" in os.environ, @@ -34,11 +102,48 @@ def _has_env_vars() -> bool: ) +_load_env() + + +OPENAI_VECTORIZE_OPTIONS_HEADER = CollectionVectorServiceOptions( + provider="openai", + model_name="text-embedding-3-small", +) + +OPENAI_SHARED_SECRET_KEY_NAME = os.environ.get("SHARED_SECRET_NAME_OPENAI") +OPENAI_VECTORIZE_OPTIONS_KMS: CollectionVectorServiceOptions | None +if OPENAI_SHARED_SECRET_KEY_NAME: + OPENAI_VECTORIZE_OPTIONS_KMS = CollectionVectorServiceOptions( + provider="openai", + model_name="text-embedding-3-small", + authentication={ + "providerKey": OPENAI_SHARED_SECRET_KEY_NAME, + }, + ) +else: + OPENAI_VECTORIZE_OPTIONS_KMS = None + + class AstraDBCredentials(TypedDict): token: str api_endpoint: str namespace: str | None - environment: str | None + environment: str + + +@pytest.fixture(scope="session") +def openai_api_key() -> str: + return os.environ["OPENAI_API_KEY"] + + +@pytest.fixture(scope="session") +def embedding_d2() -> Embeddings: + return ParserEmbeddings(dimension=2) + + +@pytest.fixture +def test_llm() -> LLM: + return IdentityLLM() @pytest.fixture(scope="session") @@ -47,22 +152,43 @@ def astra_db_credentials() -> AstraDBCredentials: "token": os.environ["ASTRA_DB_APPLICATION_TOKEN"], "api_endpoint": os.environ["ASTRA_DB_API_ENDPOINT"], "namespace": os.environ.get("ASTRA_DB_KEYSPACE"), - "environment": os.environ.get("ASTRA_DB_ENVIRONMENT"), + "environment": os.environ.get("ASTRA_DB_ENVIRONMENT", "prod"), } @pytest.fixture(scope="session") -def database(astra_db_credentials: AstraDBCredentials) -> Database: - return Database( - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], +def is_astra_db(astra_db_credentials: AstraDBCredentials) -> bool: + return astra_db_credentials["environment"].lower() in { + "prod", + "test", + "dev", + } + + +@pytest.fixture(scope="session") +def database( + *, + is_astra_db: bool, + astra_db_credentials: AstraDBCredentials, +) -> Database: + client = DataAPIClient(environment=astra_db_credentials["environment"]) + db = client.get_database( + astra_db_credentials["api_endpoint"], + token=StaticTokenProvider(astra_db_credentials["token"]), namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], ) + if not is_astra_db: + if astra_db_credentials["namespace"] is None: + msg = "Cannot test on non-Astra without a namespace set." 
+ raise ValueError(msg) + db.get_database_admin().create_namespace(astra_db_credentials["namespace"]) + + return db @pytest.fixture(scope="session") def core_astra_db(astra_db_credentials: AstraDBCredentials) -> AstraDB: + """An instance of the 'core' (pre-1.0, legacy) astrapy database.""" return AstraDB( token=astra_db_credentials["token"], api_endpoint=astra_db_credentials["api_endpoint"], @@ -70,25 +196,348 @@ def core_astra_db(astra_db_credentials: AstraDBCredentials) -> AstraDB: ) -_load_env() +@pytest.fixture(scope="module") +def collection_d2( + database: Database, +) -> Iterable[Collection]: + """A general-purpose D=2(Euclidean) collection for per-test reuse.""" + collection = database.create_collection( + COLLECTION_NAME_D2, + dimension=2, + check_exists=False, + indexing=DEFAULT_INDEXING_OPTIONS, + metric="euclidean", + ) + yield collection + collection.drop() -OPENAI_VECTORIZE_OPTIONS = CollectionVectorServiceOptions( - provider="openai", - model_name="text-embedding-3-small", - authentication={ - "providerKey": f"{os.environ.get('SHARED_SECRET_NAME_OPENAI', '')}", - }, -) +@pytest.fixture +def empty_collection_d2( + collection_d2: Collection, +) -> Collection: + """A per-test-function empty d=2(Euclidean) collection.""" + collection_d2.delete_many({}) + return collection_d2 -OPENAI_VECTORIZE_OPTIONS_HEADER = CollectionVectorServiceOptions( - provider="openai", - model_name="text-embedding-3-small", -) + +@pytest.fixture +def vector_store_d2( + empty_collection_d2: Collection, # noqa: ARG001 + astra_db_credentials: AstraDBCredentials, + embedding_d2: Embeddings, +) -> AstraDBVectorStore: + """A fresh vector store on a d=2(Euclidean) collection.""" + return AstraDBVectorStore( + embedding=embedding_d2, + collection_name=COLLECTION_NAME_D2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + ) -NVIDIA_VECTORIZE_OPTIONS = CollectionVectorServiceOptions( - provider="nvidia", - model_name="NV-Embed-QA", -) +@pytest.fixture +def vector_store_d2_stringtoken( + empty_collection_d2: Collection, # noqa: ARG001 + astra_db_credentials: AstraDBCredentials, + embedding_d2: Embeddings, +) -> AstraDBVectorStore: + """ + A fresh vector store on a d=2(Euclidean) collection, + but initialized with a token string instead of a TokenProvider. + """ + return AstraDBVectorStore( + embedding=embedding_d2, + collection_name=COLLECTION_NAME_D2, + token=astra_db_credentials["token"], + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + ) + + +@pytest.fixture +def ephemeral_collection_cleaner_d2( + database: Database, +) -> Iterable[str]: + """ + A nominal fixture to ensure the ephemeral collection is deleted + after the test function has finished. + """ + + yield EPHEMERAL_COLLECTION_NAME_D2 + + if EPHEMERAL_COLLECTION_NAME_D2 in database.list_collection_names(): + database.drop_collection(EPHEMERAL_COLLECTION_NAME_D2) + + +@pytest.fixture(scope="module") +def collection_idxall( + database: Database, +) -> Iterable[Collection]: + """ + A general-purpose collection for per-test reuse. + This one has default indexing (i.e. all fields are covered). 
+ """ + collection = database.create_collection( + COLLECTION_NAME_IDXALL, + check_exists=False, + ) + yield collection + + collection.drop() + + +@pytest.fixture +def empty_collection_idxall( + collection_idxall: Collection, +) -> Collection: + """ + A per-test-function empty collection. + This one has default indexing (i.e. all fields are covered). + """ + collection_idxall.delete_many({}) + return collection_idxall + + +@pytest.fixture(scope="module") +def collection_idxid( + database: Database, +) -> Iterable[Collection]: + """ + A general-purpose collection for per-test reuse. + This one has id-only indexing (i.e. for Storage classes). + """ + collection = database.create_collection( + COLLECTION_NAME_IDXID, + indexing={"allow": ["_id"]}, + check_exists=False, + ) + yield collection + + collection.drop() + + +@pytest.fixture(scope="module") +def collection_idxall_d2( + database: Database, +) -> Iterable[Collection]: + """ + A general-purpose D=2(Euclidean) collection for per-test reuse. + This one has default indexing (i.e. all fields are covered). + """ + collection = database.create_collection( + COLLECTION_NAME_IDXALL_D2, + dimension=2, + check_exists=False, + metric="euclidean", + ) + yield collection + + collection.drop() + + +@pytest.fixture +def empty_collection_idxall_d2( + collection_idxall_d2: Collection, +) -> Collection: + """ + A per-test-function empty d=2(Euclidean) collection. + This one has default indexing (i.e. all fields are covered). + """ + collection_idxall_d2.delete_many({}) + return collection_idxall_d2 + + +@pytest.fixture +def vector_store_idxall_d2( + empty_collection_idxall_d2: Collection, # noqa: ARG001 + astra_db_credentials: AstraDBCredentials, + embedding_d2: Embeddings, +) -> AstraDBVectorStore: + """A fresh vector store on a d=2(Euclidean) collection.""" + return AstraDBVectorStore( + embedding=embedding_d2, + collection_name=COLLECTION_NAME_IDXALL_D2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + collection_indexing_policy={"allow": ["*"]}, + setup_mode=SetupMode.OFF, + ) + + +@pytest.fixture +def ephemeral_collection_cleaner_idxall_d2( + database: Database, +) -> Iterable[str]: + """ + A nominal fixture to ensure the ephemeral collection is deleted + after the test function has finished. 
+ """ + + yield EPHEMERAL_COLLECTION_NAME_IDXALL_D2 + + if EPHEMERAL_COLLECTION_NAME_IDXALL_D2 in database.list_collection_names(): + database.drop_collection(EPHEMERAL_COLLECTION_NAME_IDXALL_D2) + + +@pytest.fixture(scope="module") +def collection_vz( + openai_api_key: str, + database: Database, +) -> Iterable[Collection]: + """A general-purpose $vectorize collection for per-test reuse.""" + collection = database.create_collection( + COLLECTION_NAME_VZ, + dimension=16, + check_exists=False, + indexing=DEFAULT_INDEXING_OPTIONS, + metric="euclidean", + service=OPENAI_VECTORIZE_OPTIONS_HEADER, + embedding_api_key=openai_api_key, + ) + yield collection + + collection.drop() + + +@pytest.fixture +def empty_collection_vz( + collection_vz: Collection, +) -> Collection: + """A per-test-function empty $vecorize collection.""" + collection_vz.delete_many({}) + return collection_vz + + +@pytest.fixture +def vector_store_vz( + astra_db_credentials: AstraDBCredentials, + openai_api_key: str, + empty_collection_vz: Collection, # noqa: ARG001 +) -> AstraDBVectorStore: + """A fresh vector store on a $vectorize collection.""" + return AstraDBVectorStore( + collection_name=COLLECTION_NAME_VZ, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_HEADER, + collection_embedding_api_key=openai_api_key, + ) + + +@pytest.fixture +def ephemeral_collection_cleaner_vz( + database: Database, +) -> Iterable[str]: + """ + A nominal fixture to ensure the ephemeral vectorize collection is deleted + after the test function has finished. + """ + + yield EPHEMERAL_COLLECTION_NAME_VZ + + if EPHEMERAL_COLLECTION_NAME_VZ in database.list_collection_names(): + database.drop_collection(EPHEMERAL_COLLECTION_NAME_VZ) + + +@pytest.fixture(scope="module") +def collection_idxall_vz( + openai_api_key: str, + database: Database, +) -> Iterable[Collection]: + """ + A general-purpose $vectorize collection for per-test reuse. + This one has default indexing (i.e. all fields are covered). + """ + collection = database.create_collection( + COLLECTION_NAME_IDXALL_VZ, + dimension=16, + check_exists=False, + metric="euclidean", + service=OPENAI_VECTORIZE_OPTIONS_HEADER, + embedding_api_key=openai_api_key, + ) + yield collection + + collection.drop() + + +@pytest.fixture +def empty_collection_idxall_vz( + collection_idxall_vz: Collection, +) -> Collection: + """ + A per-test-function empty $vecorize collection. + This one has default indexing (i.e. all fields are covered). 
+ """ + collection_idxall_vz.delete_many({}) + return collection_idxall_vz + + +@pytest.fixture +def vector_store_idxall_vz( + openai_api_key: str, + empty_collection_idxall_vz: Collection, # noqa: ARG001 + astra_db_credentials: AstraDBCredentials, +) -> AstraDBVectorStore: + """A fresh vector store on a d=2(Euclidean) collection.""" + return AstraDBVectorStore( + collection_name=COLLECTION_NAME_IDXALL_VZ, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + collection_indexing_policy={"allow": ["*"]}, + setup_mode=SetupMode.OFF, + collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_HEADER, + collection_embedding_api_key=openai_api_key, + ) + + +@pytest.fixture +def ephemeral_indexing_collections_cleaner( + database: Database, +) -> Iterable[list[str]]: + """ + A nominal fixture to ensure the ephemeral collections for indexing testing + are deleted after the test function has finished. + """ + + collection_names = [ + EPHEMERAL_CUSTOM_IDX_NAME_D2, + EPHEMERAL_DEFAULT_IDX_NAME_D2, + EPHEMERAL_LEGACY_IDX_NAME_D2, + EPHEMERAL_CUSTOM_IDX_NAME, + EPHEMERAL_LEGACY_IDX_NAME, + ] + yield collection_names + + for collection_name in collection_names: + if collection_name in database.list_collection_names(): + database.drop_collection(collection_name) + + +@pytest.fixture +def ephemeral_collection_cleaner_vz_kms( + database: Database, +) -> Iterable[str]: + """ + A nominal fixture to ensure the ephemeral vectorize collection with KMS + is deleted after the test function has finished. + """ + + yield EPHEMERAL_COLLECTION_NAME_VZ_KMS + + if EPHEMERAL_COLLECTION_NAME_VZ_KMS in database.list_collection_names(): + database.drop_collection(EPHEMERAL_COLLECTION_NAME_VZ_KMS) diff --git a/libs/astradb/tests/integration_tests/test_cache.py b/libs/astradb/tests/integration_tests/test_cache.py new file mode 100644 index 0000000..9e95ab6 --- /dev/null +++ b/libs/astradb/tests/integration_tests/test_cache.py @@ -0,0 +1,229 @@ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING + +import pytest +from astrapy.authentication import StaticTokenProvider +from langchain_core.globals import get_llm_cache, set_llm_cache +from langchain_core.outputs import Generation, LLMResult + +from langchain_astradb import AstraDBCache +from langchain_astradb.utils.astradb import SetupMode + +from .conftest import ( + COLLECTION_NAME_IDXALL, + AstraDBCredentials, + astra_db_env_vars_available, +) + +if TYPE_CHECKING: + from astrapy import Collection + from astrapy.db import AstraDB + + from .conftest import IdentityLLM + + +@pytest.fixture +def astradb_cache( + astra_db_credentials: AstraDBCredentials, + empty_collection_idxall: Collection, # noqa: ARG001 +) -> AstraDBCache: + return AstraDBCache( + collection_name=COLLECTION_NAME_IDXALL, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + ) + + +@pytest.fixture +async def astradb_cache_async( + astra_db_credentials: AstraDBCredentials, + empty_collection_idxall: Collection, # noqa: ARG001 +) -> AstraDBCache: + return AstraDBCache( + collection_name=COLLECTION_NAME_IDXALL, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + 
environment=astra_db_credentials["environment"], + setup_mode=SetupMode.ASYNC, + ) + + +@pytest.mark.skipif( + not astra_db_env_vars_available(), reason="Missing Astra DB env. vars" +) +class TestAstraDBCache: + def test_cache_crud_sync( + self, + astradb_cache: AstraDBCache, + ) -> None: + """Tests for basic cache CRUD, not involving an LLM.""" + gens0 = [Generation(text="gen0_text")] + gens1 = [Generation(text="gen1_text")] + assert astradb_cache.lookup("prompt0", "llm_string") is None + + astradb_cache.update("prompt0", "llm_string", gens0) + astradb_cache.update("prompt1", "llm_string", gens1) + assert astradb_cache.lookup("prompt0", "llm_string") == gens0 + assert astradb_cache.lookup("prompt1", "llm_string") == gens1 + + astradb_cache.delete("prompt0", "llm_string") + assert astradb_cache.lookup("prompt0", "llm_string") is None + assert astradb_cache.lookup("prompt1", "llm_string") == gens1 + + astradb_cache.clear() + assert astradb_cache.lookup("prompt0", "llm_string") is None + assert astradb_cache.lookup("prompt1", "llm_string") is None + + async def test_cache_crud_async( + self, + astradb_cache_async: AstraDBCache, + ) -> None: + """ + Tests for basic cache CRUD, not involving an LLM. + Async version. + """ + gens0 = [Generation(text="gen0_text")] + gens1 = [Generation(text="gen1_text")] + assert await astradb_cache_async.alookup("prompt0", "llm_string") is None + + await astradb_cache_async.aupdate("prompt0", "llm_string", gens0) + await astradb_cache_async.aupdate("prompt1", "llm_string", gens1) + assert await astradb_cache_async.alookup("prompt0", "llm_string") == gens0 + assert await astradb_cache_async.alookup("prompt1", "llm_string") == gens1 + + await astradb_cache_async.adelete("prompt0", "llm_string") + assert await astradb_cache_async.alookup("prompt0", "llm_string") is None + assert await astradb_cache_async.alookup("prompt1", "llm_string") == gens1 + + await astradb_cache_async.aclear() + assert await astradb_cache_async.alookup("prompt0", "llm_string") is None + assert await astradb_cache_async.alookup("prompt1", "llm_string") is None + + def test_cache_through_llm_sync( + self, + test_llm: IdentityLLM, + astradb_cache: AstraDBCache, + ) -> None: + """Tests for cache as used with a (mock) LLM.""" + gens0 = [Generation(text="gen0_text")] + set_llm_cache(astradb_cache) + + params = {"stop": None, **test_llm.dict()} + llm_string = str(sorted(params.items())) + + assert test_llm.num_calls == 0 + + # inject cache entry, check no LLM call is done + get_llm_cache().update("prompt0", llm_string, gens0) + output = test_llm.generate(["prompt0"]) + expected_output = LLMResult( + generations=[gens0], + llm_output={}, + ) + assert test_llm.num_calls == 0 + assert output == expected_output + + # check *one* new call for a new prompt, even if 'generate' repeated + test_llm.generate(["prompt1"]) + test_llm.generate(["prompt1"]) + test_llm.generate(["prompt1"]) + assert test_llm.num_calls == 1 + + # remove the cache and check a new LLM call is actually made + astradb_cache.delete_through_llm("prompt1", test_llm, stop=None) + test_llm.generate(["prompt1"]) + test_llm.generate(["prompt1"]) + assert test_llm.num_calls == 2 + + async def test_cache_through_llm_async( + self, + test_llm: IdentityLLM, + astradb_cache_async: AstraDBCache, + ) -> None: + """Tests for cache as used with a (mock) LLM, async version""" + gens0 = [Generation(text="gen0_text")] + set_llm_cache(astradb_cache_async) + + params = {"stop": None, **test_llm.dict()} + llm_string = str(sorted(params.items())) + + 
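    # (Aside: LangChain's LLM-layer cache is keyed on the (prompt, llm_string)
    # pair, and llm_string is derived from the model's identifying parameters
    # together with the stop words -- which is why the test rebuilds it as
    # str(sorted(params.items())) above: the injected entry must match the key
    # that agenerate() computes internally for the lookup to hit.)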
assert test_llm.num_calls == 0 + + # inject cache entry, check no LLM call is done + await get_llm_cache().aupdate("prompt0", llm_string, gens0) + output = await test_llm.agenerate(["prompt0"]) + expected_output = LLMResult( + generations=[gens0], + llm_output={}, + ) + assert test_llm.num_calls == 0 + assert output == expected_output + + # check *one* new call for a new prompt, even if 'generate' repeated + await test_llm.agenerate(["prompt1"]) + await test_llm.agenerate(["prompt1"]) + await test_llm.agenerate(["prompt1"]) + assert test_llm.num_calls == 1 + + # remove the cache and check a new LLM call is actually made + await astradb_cache_async.adelete_through_llm("prompt1", test_llm, stop=None) + await test_llm.agenerate(["prompt1"]) + await test_llm.agenerate(["prompt1"]) + assert test_llm.num_calls == 2 + + @pytest.mark.skipif( + os.environ.get("ASTRA_DB_ENVIRONMENT", "prod").upper() != "PROD", + reason="Can run on Astra DB production environment only", + ) + def test_cache_coreclients_init_sync( + self, + core_astra_db: AstraDB, + astradb_cache: AstraDBCache, + ) -> None: + """A deprecation warning from passing a (core) AstraDB, but it works.""" + gens0 = [Generation(text="gen0_text")] + astradb_cache.update("prompt0", "llm_string", gens0) + # create an equivalent cache with core AstraDB in init + with pytest.warns(DeprecationWarning) as rec_warnings: + cache_init_core = AstraDBCache( + collection_name=COLLECTION_NAME_IDXALL, + astra_db_client=core_astra_db, + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + assert cache_init_core.lookup("prompt0", "llm_string") == gens0 + + @pytest.mark.skipif( + os.environ.get("ASTRA_DB_ENVIRONMENT", "prod").upper() != "PROD", + reason="Can run on Astra DB production environment only", + ) + async def test_cache_coreclients_init_async( + self, + core_astra_db: AstraDB, + astradb_cache_async: AstraDBCache, + ) -> None: + """ + A deprecation warning from passing a (core) AstraDB, but it works. + Async version. + """ + gens0 = [Generation(text="gen0_text")] + await astradb_cache_async.aupdate("prompt0", "llm_string", gens0) + # create an equivalent cache with core AstraDB in init + with pytest.warns(DeprecationWarning) as rec_warnings: + cache_init_core = AstraDBCache( + collection_name=COLLECTION_NAME_IDXALL, + astra_db_client=core_astra_db, + setup_mode=SetupMode.ASYNC, + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + assert await cache_init_core.alookup("prompt0", "llm_string") == gens0 diff --git a/libs/astradb/tests/integration_tests/test_caches.py b/libs/astradb/tests/integration_tests/test_caches.py deleted file mode 100644 index 62ef6a3..0000000 --- a/libs/astradb/tests/integration_tests/test_caches.py +++ /dev/null @@ -1,393 +0,0 @@ -"""Test AstraDB caches. Requires an Astra DB vector instance. - -Required to run this test: - - a recent `astrapy` Python package available - - an Astra DB instance; - - the two environment variables set: - export ASTRA_DB_API_ENDPOINT="https://-us-east1.apps.astra.datastax.com" - export ASTRA_DB_APPLICATION_TOKEN="AstraCS:........." 
- - optionally this as well (otherwise defaults are used): - export ASTRA_DB_KEYSPACE="my_keyspace" -""" - -from __future__ import annotations - -import os -from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, Mapping, Optional, cast - -import pytest -from langchain_core.embeddings import Embeddings -from langchain_core.globals import get_llm_cache, set_llm_cache -from langchain_core.language_models import LLM -from langchain_core.outputs import Generation, LLMResult -from typing_extensions import override - -from langchain_astradb import AstraDBCache, AstraDBSemanticCache -from langchain_astradb.utils.astradb import SetupMode - -from .conftest import AstraDBCredentials, _has_env_vars - -if TYPE_CHECKING: - from astrapy.db import AstraDB - from langchain_core.caches import BaseCache - from langchain_core.callbacks import CallbackManagerForLLMRun - - -class FakeEmbeddings(Embeddings): - """Fake embeddings functionality for testing.""" - - @override - def embed_documents(self, texts: list[str]) -> list[list[float]]: - """Return simple embeddings. - Embeddings encode each text as its index. - """ - return [[1.0] * 9 + [float(i)] for i in range(len(texts))] - - @override - async def aembed_documents(self, texts: list[str]) -> list[list[float]]: - return self.embed_documents(texts) - - @override - def embed_query(self, text: str) -> list[float]: - """Return constant query embeddings. - Embeddings are identical to embed_documents(texts)[0]. - Distance to each text will be that text's index, - as it was passed to embed_documents. - """ - return [1.0] * 9 + [0.0] - - @override - async def aembed_query(self, text: str) -> list[float]: - return self.embed_query(text) - - -class FakeLLM(LLM): - """Fake LLM wrapper for testing purposes.""" - - queries: Optional[Mapping] = None # noqa: UP007 - sequential_responses: Optional[bool] = False # noqa: UP007 - response_index: int = 0 - - @override - def get_num_tokens(self, text: str) -> int: - """Return number of tokens.""" - return len(text.split()) - - @property - @override - def _llm_type(self) -> str: - """Return type of llm.""" - return "fake" - - @override - def _call( - self, - prompt: str, - stop: list[str] | None = None, - run_manager: CallbackManagerForLLMRun | None = None, - **kwargs: Any, - ) -> str: - if self.sequential_responses: - return self._get_next_response_in_sequence - if self.queries is not None: - return self.queries[prompt] - return "foo" if stop is None else "bar" - - @property - @override - def _identifying_params(self) -> dict[str, Any]: - return {} - - @property - def _get_next_response_in_sequence(self) -> str: - queries = cast(Mapping, self.queries) - response = queries[list(queries.keys())[self.response_index]] - self.response_index = self.response_index + 1 - return response - - -@pytest.fixture(scope="module") -def astradb_cache(astra_db_credentials: AstraDBCredentials) -> Iterator[AstraDBCache]: - cache = AstraDBCache( - collection_name="lc_integration_test_cache", - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - yield cache - cache.collection.drop() - - -@pytest.fixture -async def async_astradb_cache( - astra_db_credentials: AstraDBCredentials, -) -> AsyncIterator[AstraDBCache]: - cache = AstraDBCache( - collection_name="lc_integration_test_cache_async", - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - 
namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - setup_mode=SetupMode.ASYNC, - ) - yield cache - await cache.async_collection.drop() - - -@pytest.fixture(scope="module") -def astradb_semantic_cache( - astra_db_credentials: AstraDBCredentials, -) -> Iterator[AstraDBSemanticCache]: - fake_embe = FakeEmbeddings() - sem_cache = AstraDBSemanticCache( - collection_name="lc_integration_test_sem_cache", - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - embedding=fake_embe, - ) - yield sem_cache - sem_cache.collection.drop() - - -@pytest.fixture -async def async_astradb_semantic_cache( - astra_db_credentials: AstraDBCredentials, -) -> AsyncIterator[AstraDBSemanticCache]: - fake_embe = FakeEmbeddings() - sem_cache = AstraDBSemanticCache( - collection_name="lc_integration_test_sem_cache_async", - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - embedding=fake_embe, - setup_mode=SetupMode.ASYNC, - ) - yield sem_cache - sem_cache.collection.drop() - - -@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") -class TestAstraDBCaches: - def test_astradb_cache_sync(self, astradb_cache: AstraDBCache) -> None: - self.do_cache_test(FakeLLM(), astradb_cache, "foo") - - async def test_astradb_cache_async(self, async_astradb_cache: AstraDBCache) -> None: - await self.ado_cache_test(FakeLLM(), async_astradb_cache, "foo") - - def test_astradb_semantic_cache_sync( - self, astradb_semantic_cache: AstraDBSemanticCache - ) -> None: - llm = FakeLLM() - self.do_cache_test(llm, astradb_semantic_cache, "bar") - output = llm.generate(["bar"]) # 'fizz' is erased away now - assert output != LLMResult( - generations=[[Generation(text="fizz")]], - llm_output={}, - ) - astradb_semantic_cache.clear() - - async def test_astradb_semantic_cache_async( - self, async_astradb_semantic_cache: AstraDBSemanticCache - ) -> None: - llm = FakeLLM() - await self.ado_cache_test(llm, async_astradb_semantic_cache, "bar") - output = await llm.agenerate(["bar"]) # 'fizz' is erased away now - assert output != LLMResult( - generations=[[Generation(text="fizz")]], - llm_output={}, - ) - await async_astradb_semantic_cache.aclear() - - @staticmethod - def do_cache_test(llm: LLM, cache: BaseCache, prompt: str) -> None: - set_llm_cache(cache) - params = llm.dict() - params["stop"] = None - llm_string = str(sorted(params.items())) - get_llm_cache().update("foo", llm_string, [Generation(text="fizz")]) - output = llm.generate([prompt]) - expected_output = LLMResult( - generations=[[Generation(text="fizz")]], - llm_output={}, - ) - assert output == expected_output - # clear the cache - cache.clear() - - @staticmethod - async def ado_cache_test(llm: LLM, cache: BaseCache, prompt: str) -> None: - set_llm_cache(cache) - params = llm.dict() - params["stop"] = None - llm_string = str(sorted(params.items())) - await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")]) - output = await llm.agenerate([prompt]) - expected_output = LLMResult( - generations=[[Generation(text="fizz")]], - llm_output={}, - ) - assert output == expected_output - # clear the cache - await cache.aclear() - - @pytest.mark.skipif( - os.environ.get("ASTRA_DB_ENVIRONMENT", "prod").upper() != "PROD", - reason="Can run on Astra DB 
prod only", - ) - def test_cache_coreclients_init_sync( - self, - astra_db_credentials: AstraDBCredentials, - core_astra_db: AstraDB, - ) -> None: - """A deprecation warning from passing a (core) AstraDB, but it works.""" - collection_name = "lc_test_cache_coreclsync" - test_gens = [Generation(text="ret_val0123")] - try: - cache_init_ok = AstraDBCache( - collection_name=collection_name, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - ) - cache_init_ok.update("pr", "llms", test_gens) - # create an equivalent cache with core AstraDB in init - with pytest.warns(DeprecationWarning) as rec_warnings: - cache_init_core = AstraDBCache( - collection_name=collection_name, - astra_db_client=core_astra_db, - ) - f_rec_warnings = [ - wrn - for wrn in rec_warnings - if issubclass(wrn.category, DeprecationWarning) - ] - assert len(f_rec_warnings) == 1 - assert cache_init_core.lookup("pr", "llms") == test_gens - finally: - cache_init_ok.astra_env.database.drop_collection(collection_name) - - @pytest.mark.skipif( - os.environ.get("ASTRA_DB_ENVIRONMENT", "prod").upper() != "PROD", - reason="Can run on Astra DB prod only", - ) - async def test_cache_coreclients_init_async( - self, - astra_db_credentials: AstraDBCredentials, - core_astra_db: AstraDB, - ) -> None: - """A deprecation warning from passing a (core) AstraDB, but it works.""" - collection_name = "lc_test_cache_coreclasync" - test_gens = [Generation(text="ret_val4567")] - try: - cache_init_ok = AstraDBCache( - collection_name=collection_name, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - setup_mode=SetupMode.ASYNC, - ) - await cache_init_ok.aupdate("pr", "llms", test_gens) - # create an equivalent cache with core AstraDB in init - with pytest.warns(DeprecationWarning) as rec_warnings: - cache_init_core = AstraDBCache( - collection_name=collection_name, - astra_db_client=core_astra_db, - setup_mode=SetupMode.ASYNC, - ) - f_rec_warnings = [ - wrn - for wrn in rec_warnings - if issubclass(wrn.category, DeprecationWarning) - ] - assert len(f_rec_warnings) == 1 - assert await cache_init_core.alookup("pr", "llms") == test_gens - finally: - await cache_init_ok.astra_env.async_database.drop_collection( - collection_name - ) - - @pytest.mark.skipif( - os.environ.get("ASTRA_DB_ENVIRONMENT", "prod").upper() != "PROD", - reason="Can run on Astra DB prod only", - ) - def test_semcache_coreclients_init_sync( - self, - astra_db_credentials: AstraDBCredentials, - core_astra_db: AstraDB, - ) -> None: - """A deprecation warning from passing a (core) AstraDB, but it works.""" - fake_embe = FakeEmbeddings() - collection_name = "lc_test_cache_coreclsync" - test_gens = [Generation(text="ret_val0123")] - try: - cache_init_ok = AstraDBSemanticCache( - collection_name=collection_name, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - embedding=fake_embe, - ) - cache_init_ok.update("pr", "llms", test_gens) - # create an equivalent cache with core AstraDB in init - with pytest.warns(DeprecationWarning) as rec_warnings: - cache_init_core = AstraDBSemanticCache( - collection_name=collection_name, - astra_db_client=core_astra_db, - embedding=fake_embe, - ) - f_rec_warnings = [ - wrn - for wrn in rec_warnings - if issubclass(wrn.category, DeprecationWarning) - ] - assert len(f_rec_warnings) == 1 - assert 
cache_init_core.lookup("pr", "llms") == test_gens - finally: - cache_init_ok.astra_env.database.drop_collection(collection_name) - - @pytest.mark.skipif( - os.environ.get("ASTRA_DB_ENVIRONMENT", "prod").upper() != "PROD", - reason="Can run on Astra DB prod only", - ) - async def test_semcache_coreclients_init_async( - self, - astra_db_credentials: AstraDBCredentials, - core_astra_db: AstraDB, - ) -> None: - """A deprecation warning from passing a (core) AstraDB, but it works.""" - fake_embe = FakeEmbeddings() - collection_name = "lc_test_cache_coreclasync" - test_gens = [Generation(text="ret_val4567")] - try: - cache_init_ok = AstraDBSemanticCache( - collection_name=collection_name, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - setup_mode=SetupMode.ASYNC, - embedding=fake_embe, - ) - await cache_init_ok.aupdate("pr", "llms", test_gens) - # create an equivalent cache with core AstraDB in init - with pytest.warns(DeprecationWarning) as rec_warnings: - cache_init_core = AstraDBSemanticCache( - collection_name=collection_name, - astra_db_client=core_astra_db, - setup_mode=SetupMode.ASYNC, - embedding=fake_embe, - ) - f_rec_warnings = [ - wrn - for wrn in rec_warnings - if issubclass(wrn.category, DeprecationWarning) - ] - assert len(f_rec_warnings) == 1 - assert await cache_init_core.alookup("pr", "llms") == test_gens - finally: - await cache_init_ok.astra_env.async_database.drop_collection( - collection_name - ) diff --git a/libs/astradb/tests/integration_tests/test_chat_message_histories.py b/libs/astradb/tests/integration_tests/test_chat_message_histories.py index a758a76..a8bc6ec 100644 --- a/libs/astradb/tests/integration_tests/test_chat_message_histories.py +++ b/libs/astradb/tests/integration_tests/test_chat_message_histories.py @@ -1,7 +1,7 @@ import os -from typing import AsyncIterable, Iterable import pytest +from astrapy import Collection from astrapy.db import AstraDB from langchain.memory import ConversationBufferMemory from langchain_core.messages import AIMessage, HumanMessage @@ -11,33 +11,36 @@ ) from langchain_astradb.utils.astradb import SetupMode -from .conftest import AstraDBCredentials, _has_env_vars +from .conftest import ( + COLLECTION_NAME_IDXALL, + AstraDBCredentials, + astra_db_env_vars_available, +) @pytest.fixture def history1( astra_db_credentials: AstraDBCredentials, -) -> Iterable[AstraDBChatMessageHistory]: - history1 = AstraDBChatMessageHistory( + empty_collection_idxall: Collection, # noqa: ARG001 +) -> AstraDBChatMessageHistory: + return AstraDBChatMessageHistory( session_id="session-test-1", - collection_name="langchain_cmh_test", + collection_name=COLLECTION_NAME_IDXALL, token=astra_db_credentials["token"], api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], ) - yield history1 - history1.collection.drop() @pytest.fixture def history2( - history1: AstraDBChatMessageHistory, astra_db_credentials: AstraDBCredentials, + history1: AstraDBChatMessageHistory, # noqa: ARG001 ) -> AstraDBChatMessageHistory: return AstraDBChatMessageHistory( session_id="session-test-2", - collection_name=history1.collection_name, + collection_name=COLLECTION_NAME_IDXALL, token=astra_db_credentials["token"], api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], @@ -46,34 +49,32 @@ def history2( # no two createCollection calls at once are issued: 
setup_mode=SetupMode.OFF, ) - # no deletion here, this is riding on history1 @pytest.fixture async def async_history1( astra_db_credentials: AstraDBCredentials, -) -> AsyncIterable[AstraDBChatMessageHistory]: - history1 = AstraDBChatMessageHistory( + history1: AstraDBChatMessageHistory, # noqa: ARG001 +) -> AstraDBChatMessageHistory: + return AstraDBChatMessageHistory( session_id="async-session-test-1", - collection_name="langchain_cmh_test", + collection_name=COLLECTION_NAME_IDXALL, token=astra_db_credentials["token"], api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], - setup_mode=SetupMode.ASYNC, + setup_mode=SetupMode.OFF, ) - yield history1 - await history1.async_collection.drop() @pytest.fixture async def async_history2( - history1: AstraDBChatMessageHistory, astra_db_credentials: AstraDBCredentials, + history1: AstraDBChatMessageHistory, # noqa: ARG001 ) -> AstraDBChatMessageHistory: return AstraDBChatMessageHistory( session_id="async-session-test-2", - collection_name=history1.collection_name, + collection_name=COLLECTION_NAME_IDXALL, token=astra_db_credentials["token"], api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], @@ -82,10 +83,11 @@ async def async_history2( # no two createCollection calls at once are issued: setup_mode=SetupMode.OFF, ) - # no deletion here, this is riding on history1 -@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") +@pytest.mark.skipif( + not astra_db_env_vars_available(), reason="Missing Astra DB env. vars" +) class TestAstraDBChatMessageHistories: def test_memory_with_message_store( self, history1: AstraDBChatMessageHistory @@ -200,79 +202,70 @@ async def test_memory_separate_session_ids_async( @pytest.mark.skipif( os.environ.get("ASTRA_DB_ENVIRONMENT", "prod").upper() != "PROD", - reason="Can run on Astra DB prod only", + reason="Can run on Astra DB production environment only", ) - def test_chatms_coreclients_init_sync( + def test_chatmsh_coreclients_init_sync( self, astra_db_credentials: AstraDBCredentials, core_astra_db: AstraDB, + empty_collection_idxall: Collection, # noqa: ARG002 ) -> None: """A deprecation warning from passing a (core) AstraDB, but it works.""" - collection_name = "lc_test_cmh_coreclsync" test_messages = [AIMessage(content="Meow.")] - try: - chatmh_init_ok = AstraDBChatMessageHistory( + chatmh_init_ok = AstraDBChatMessageHistory( + session_id="gattini", + collection_name=COLLECTION_NAME_IDXALL, + token=astra_db_credentials["token"], + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + setup_mode=SetupMode.OFF, + ) + chatmh_init_ok.add_messages(test_messages) + # create an equivalent cache with core AstraDB in init + with pytest.warns(DeprecationWarning) as rec_warnings: + chatmh_init_core = AstraDBChatMessageHistory( + collection_name=COLLECTION_NAME_IDXALL, session_id="gattini", - collection_name=collection_name, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], + astra_db_client=core_astra_db, ) - chatmh_init_ok.add_messages(test_messages) - # create an equivalent cache with core AstraDB in init - with pytest.warns(DeprecationWarning) as rec_warnings: - chatmh_init_core = AstraDBChatMessageHistory( - collection_name=collection_name, - session_id="gattini", - astra_db_client=core_astra_db, - ) - f_rec_warnings = [ - wrn - for 
wrn in rec_warnings - if issubclass(wrn.category, DeprecationWarning) - ] - assert len(f_rec_warnings) == 1 - assert chatmh_init_core.messages == test_messages - finally: - chatmh_init_ok.astra_env.collection.drop() + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + assert chatmh_init_core.messages == test_messages @pytest.mark.skipif( os.environ.get("ASTRA_DB_ENVIRONMENT", "prod").upper() != "PROD", - reason="Can run on Astra DB prod only", + reason="Can run on Astra DB production environment only", ) - async def test_chatms_coreclients_init_async( + async def test_chatmsh_coreclients_init_async( self, astra_db_credentials: AstraDBCredentials, core_astra_db: AstraDB, + empty_collection_idxall: Collection, # noqa: ARG002 ) -> None: """A deprecation warning from passing a (core) AstraDB, but it works.""" - collection_name = "lc_test_cmh_coreclasync" test_messages = [AIMessage(content="Ameow.")] - try: - chatmh_init_ok = AstraDBChatMessageHistory( + chatmh_init_ok = AstraDBChatMessageHistory( + session_id="gattini", + collection_name=COLLECTION_NAME_IDXALL, + token=astra_db_credentials["token"], + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + setup_mode=SetupMode.OFF, + ) + await chatmh_init_ok.aadd_messages(test_messages) + # create an equivalent cache with core AstraDB in init + with pytest.warns(DeprecationWarning) as rec_warnings: + chatmh_init_core = AstraDBChatMessageHistory( + collection_name=COLLECTION_NAME_IDXALL, session_id="gattini", - collection_name=collection_name, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], + astra_db_client=core_astra_db, setup_mode=SetupMode.ASYNC, ) - await chatmh_init_ok.aadd_messages(test_messages) - # create an equivalent cache with core AstraDB in init - with pytest.warns(DeprecationWarning) as rec_warnings: - chatmh_init_core = AstraDBChatMessageHistory( - collection_name=collection_name, - session_id="gattini", - astra_db_client=core_astra_db, - setup_mode=SetupMode.ASYNC, - ) - # cleaning out 'spurious' "unclosed socket/transport..." warnings - f_rec_warnings = [ - wrn - for wrn in rec_warnings - if issubclass(wrn.category, DeprecationWarning) - ] - assert len(f_rec_warnings) == 1 - assert await chatmh_init_core.aget_messages() == test_messages - finally: - await chatmh_init_ok.astra_env.async_collection.drop() + # cleaning out 'spurious' "unclosed socket/transport..." warnings + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + assert await chatmh_init_core.aget_messages() == test_messages diff --git a/libs/astradb/tests/integration_tests/test_document_loaders.py b/libs/astradb/tests/integration_tests/test_document_loaders.py index efe62fb..193d46e 100644 --- a/libs/astradb/tests/integration_tests/test_document_loaders.py +++ b/libs/astradb/tests/integration_tests/test_document_loaders.py @@ -1,75 +1,59 @@ -"""Test of Astra DB document loader class `AstraDBLoader` - -Required to run this test: - - a recent `astrapy` Python package available - - an Astra DB instance; - - the two environment variables set: - export ASTRA_DB_API_ENDPOINT="https://-us-east1.apps.astra.datastax.com" - export ASTRA_DB_APPLICATION_TOKEN="AstraCS:........." 
- - optionally this as well (otherwise defaults are used): - export ASTRA_DB_KEYSPACE="my_keyspace" -""" +"""Test of Astra DB document loader class `AstraDBLoader`""" from __future__ import annotations import json import os -import uuid -from typing import TYPE_CHECKING, AsyncIterator, Iterator +from typing import TYPE_CHECKING import pytest +from astrapy.authentication import StaticTokenProvider from langchain_astradb import AstraDBLoader -from .conftest import AstraDBCredentials, _has_env_vars +from .conftest import ( + COLLECTION_NAME_IDXALL, + AstraDBCredentials, + astra_db_env_vars_available, +) if TYPE_CHECKING: from astrapy import AsyncCollection, Collection, Database from astrapy.db import AstraDB -@pytest.fixture -def collection(database: Database) -> Iterator[Collection]: - collection_name = f"lc_test_loader_{str(uuid.uuid4()).split('-')[0]}" - collection = database.create_collection(collection_name) - collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20) - collection.insert_many( - [{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4 +@pytest.fixture(scope="module") +def document_loader_collection( + collection_idxall: Collection, +) -> Collection: + collection_idxall.delete_many({}) + collection_idxall.insert_many( + [{"foo": "bar", "baz": "qux"}] * 24 + [{"foo": "bar2", "baz": "qux"}] * 4 ) - - yield collection - - collection.drop() + return collection_idxall @pytest.fixture -async def async_collection(database: Database) -> AsyncIterator[AsyncCollection]: - adatabase = database.to_async() - collection_name = f"lc_test_loader_{str(uuid.uuid4()).split('-')[0]}" - collection = await adatabase.create_collection(collection_name) - await collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20) - await collection.insert_many( - [{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4 - ) - - yield collection +async def async_document_loader_collection( + collection_idxall: Collection, +) -> AsyncCollection: + return collection_idxall.to_async() - await collection.drop() - -@pytest.mark.requires("astrapy") -@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") +@pytest.mark.skipif( + not astra_db_env_vars_available(), reason="Missing Astra DB env. 
vars" +) class TestAstraDB: def test_astradb_loader_prefetched_sync( self, - collection: Collection, astra_db_credentials: AstraDBCredentials, + document_loader_collection: Collection, # noqa: ARG002 ) -> None: """Using 'prefetched' should give a warning but work nonetheless.""" with pytest.warns(UserWarning) as rec_warnings: loader = AstraDBLoader( - collection.name, - token=astra_db_credentials["token"], + COLLECTION_NAME_IDXALL, + token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], @@ -86,14 +70,15 @@ def test_astradb_loader_prefetched_sync( docs = loader.load() assert len(docs) == 22 - def test_astradb_loader_sync( + def test_astradb_loader_base_sync( self, - collection: Collection, astra_db_credentials: AstraDBCredentials, + database: Database, + document_loader_collection: Collection, # noqa: ARG002 ) -> None: loader = AstraDBLoader( - collection.name, - token=astra_db_credentials["token"], + COLLECTION_NAME_IDXALL, + token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], @@ -102,8 +87,8 @@ def test_astradb_loader_sync( filter_criteria={"foo": "bar"}, ) docs = loader.load() - assert len(docs) == 22 + ids = set() for doc in docs: content = json.loads(doc.page_content) @@ -112,19 +97,32 @@ def test_astradb_loader_sync( assert content["_id"] not in ids ids.add(content["_id"]) assert doc.metadata == { - "namespace": collection.namespace, - "api_endpoint": collection.database.api_endpoint, - "collection": collection.name, + "namespace": database.namespace, + "api_endpoint": astra_db_credentials["api_endpoint"], + "collection": COLLECTION_NAME_IDXALL, } + loader2 = AstraDBLoader( + COLLECTION_NAME_IDXALL, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + projection={"foo": 1}, + limit=22, + filter_criteria={"foo": "bar2"}, + ) + docs2 = loader2.load() + assert len(docs2) == 4 + def test_page_content_mapper_sync( self, - collection: Collection, astra_db_credentials: AstraDBCredentials, + document_loader_collection: Collection, # noqa: ARG002 ) -> None: loader = AstraDBLoader( - collection.name, - token=astra_db_credentials["token"], + COLLECTION_NAME_IDXALL, + token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], @@ -139,12 +137,12 @@ def test_page_content_mapper_sync( def test_metadata_mapper_sync( self, - collection: Collection, astra_db_credentials: AstraDBCredentials, + document_loader_collection: Collection, # noqa: ARG002 ) -> None: loader = AstraDBLoader( - collection.name, - token=astra_db_credentials["token"], + COLLECTION_NAME_IDXALL, + token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], @@ -159,14 +157,15 @@ def test_metadata_mapper_sync( async def test_astradb_loader_prefetched_async( self, - async_collection: AsyncCollection, astra_db_credentials: AstraDBCredentials, + database: Database, + async_document_loader_collection: 
AsyncCollection, # noqa: ARG002 ) -> None: """Using 'prefetched' should give a warning but work nonetheless.""" with pytest.warns(UserWarning) as rec_warnings: loader = AstraDBLoader( - async_collection.name, - token=astra_db_credentials["token"], + COLLECTION_NAME_IDXALL, + token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], @@ -180,18 +179,44 @@ async def test_astradb_loader_prefetched_async( wrn for wrn in rec_warnings if issubclass(wrn.category, UserWarning) ] assert len(f_rec_warnings) == 1 - docs = await loader.aload() assert len(docs) == 22 - async def test_astradb_loader_async( + ids = set() + for doc in docs: + content = json.loads(doc.page_content) + assert content["foo"] == "bar" + assert "baz" not in content + assert content["_id"] not in ids + ids.add(content["_id"]) + assert doc.metadata == { + "namespace": database.namespace, + "api_endpoint": astra_db_credentials["api_endpoint"], + "collection": COLLECTION_NAME_IDXALL, + } + + loader2 = AstraDBLoader( + COLLECTION_NAME_IDXALL, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + projection={"foo": 1}, + limit=22, + filter_criteria={"foo": "bar2"}, + ) + docs2 = await loader2.aload() + assert len(docs2) == 4 + + async def test_astradb_loader_base_async( self, - async_collection: AsyncCollection, astra_db_credentials: AstraDBCredentials, + database: Database, + async_document_loader_collection: AsyncCollection, # noqa: ARG002 ) -> None: loader = AstraDBLoader( - async_collection.name, - token=astra_db_credentials["token"], + COLLECTION_NAME_IDXALL, + token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], @@ -210,19 +235,19 @@ async def test_astradb_loader_async( assert content["_id"] not in ids ids.add(content["_id"]) assert doc.metadata == { - "namespace": async_collection.namespace, - "api_endpoint": async_collection.database.api_endpoint, - "collection": async_collection.name, + "namespace": database.namespace, + "api_endpoint": astra_db_credentials["api_endpoint"], + "collection": COLLECTION_NAME_IDXALL, } async def test_page_content_mapper_async( self, - async_collection: AsyncCollection, astra_db_credentials: AstraDBCredentials, + async_document_loader_collection: AsyncCollection, # noqa: ARG002 ) -> None: loader = AstraDBLoader( - async_collection.name, - token=astra_db_credentials["token"], + COLLECTION_NAME_IDXALL, + token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], @@ -235,12 +260,12 @@ async def test_page_content_mapper_async( async def test_metadata_mapper_async( self, - async_collection: AsyncCollection, astra_db_credentials: AstraDBCredentials, + async_document_loader_collection: AsyncCollection, # noqa: ARG002 ) -> None: loader = AstraDBLoader( - async_collection.name, - token=astra_db_credentials["token"], + COLLECTION_NAME_IDXALL, + token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], 
environment=astra_db_credentials["environment"], @@ -258,13 +283,16 @@ async def test_metadata_mapper_async( def test_astradb_loader_coreclients_init( self, astra_db_credentials: AstraDBCredentials, - collection: Collection, core_astra_db: AstraDB, + document_loader_collection: Collection, # noqa: ARG002 ) -> None: - """A deprecation warning from passing a (core) AstraDB, but it works.""" + """ + A deprecation warning from passing a (core) AstraDB, but it works. + Note there is no sync/async here: this class always has SetupMode.OFF. + """ loader_init_ok = AstraDBLoader( - collection_name=collection.name, - token=astra_db_credentials["token"], + collection_name=COLLECTION_NAME_IDXALL, + token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], limit=1, @@ -273,7 +301,7 @@ def test_astradb_loader_coreclients_init( # create an equivalent loader with core AstraDB in init with pytest.warns(DeprecationWarning) as rec_warnings: loader_init_core = AstraDBLoader( - collection_name=collection.name, + collection_name=COLLECTION_NAME_IDXALL, astra_db_client=core_astra_db, limit=1, ) @@ -286,12 +314,12 @@ def test_astradb_loader_coreclients_init( def test_astradb_loader_findoptions_deprecation( self, astra_db_credentials: AstraDBCredentials, - collection: Collection, + document_loader_collection: Collection, # noqa: ARG002 ) -> None: """Test deprecation of 'find_options' and related warnings/errors.""" loader0 = AstraDBLoader( - collection_name=collection.name, - token=astra_db_credentials["token"], + collection_name=COLLECTION_NAME_IDXALL, + token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], @@ -301,8 +329,8 @@ def test_astradb_loader_findoptions_deprecation( with pytest.warns(DeprecationWarning) as rec_warnings: loader_lo = AstraDBLoader( - collection_name=collection.name, - token=astra_db_credentials["token"], + collection_name=COLLECTION_NAME_IDXALL, + token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], @@ -316,8 +344,8 @@ def test_astradb_loader_findoptions_deprecation( with pytest.raises(ValueError, match="Duplicate 'limit' directive supplied."): AstraDBLoader( - collection_name=collection.name, - token=astra_db_credentials["token"], + collection_name=COLLECTION_NAME_IDXALL, + token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], @@ -327,8 +355,8 @@ def test_astradb_loader_findoptions_deprecation( with pytest.warns(DeprecationWarning) as rec_warnings: loader_uo = AstraDBLoader( - collection_name=collection.name, - token=astra_db_credentials["token"], + collection_name=COLLECTION_NAME_IDXALL, + token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], diff --git a/libs/astradb/tests/integration_tests/test_graphvectorstore.py b/libs/astradb/tests/integration_tests/test_graphvectorstore.py index 570ce9c..1ff5a77 100644 --- a/libs/astradb/tests/integration_tests/test_graphvectorstore.py +++ 
b/libs/astradb/tests/integration_tests/test_graphvectorstore.py @@ -7,11 +7,9 @@ from __future__ import annotations -import os -from typing import TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING import pytest -from astrapy import DataAPIClient from astrapy.authentication import StaticTokenProvider from langchain_core.documents import Document from langchain_core.graph_vectorstores.base import Node @@ -19,72 +17,95 @@ from langchain_astradb.graph_vectorstores import AstraDBGraphVectorStore from langchain_astradb.utils.astradb import SetupMode -from tests.conftest import ParserEmbeddings -from .conftest import AstraDBCredentials, _has_env_vars +from .conftest import ( + COLLECTION_NAME_D2, + CUSTOM_CONTENT_KEY, + EPHEMERAL_COLLECTION_NAME_IDXALL_D2, + LONG_TEXT, + astra_db_env_vars_available, +) if TYPE_CHECKING: - from astrapy import Collection + from astrapy import Collection, Database from langchain_core.embeddings import Embeddings -# Faster testing (no actual collection deletions). Off by default (=full tests) -SKIP_COLLECTION_DELETE = ( - int(os.environ.get("ASTRA_DB_SKIP_COLLECTION_DELETIONS", "0")) != 0 -) - -GVS_NOVECTORIZE_COLLECTION = "lc_gvs_novectorize" -# for testing with autodetect -CUSTOM_CONTENT_KEY = "xcontent" -LONG_TEXT = "This is the textual content field in the doc." - - -@pytest.fixture(scope="session") -def provisioned_novectorize_collection( - astra_db_credentials: AstraDBCredentials, -) -> Iterable[Collection]: - """Provision a general-purpose collection for the no-vectorize tests.""" - client = DataAPIClient(environment=astra_db_credentials["environment"]) - database = client.get_database( - astra_db_credentials["api_endpoint"], - token=StaticTokenProvider(astra_db_credentials["token"]), - namespace=astra_db_credentials["namespace"], - ) - collection = database.create_collection( - GVS_NOVECTORIZE_COLLECTION, - dimension=2, - check_exists=False, - metric="euclidean", - ) - yield collection - - if not SKIP_COLLECTION_DELETE: - collection.drop() + from .conftest import AstraDBCredentials @pytest.fixture -def novectorize_empty_collection( - provisioned_novectorize_collection: Collection, -) -> Iterable[Collection]: - provisioned_novectorize_collection.delete_many({}) - yield provisioned_novectorize_collection +def graph_vector_store_docs() -> list[Document]: + """ + This is a set of Documents to pre-populate a graph vector store, + with entries placed in a certain way. - provisioned_novectorize_collection.delete_many({}) + Space of the entries (under Euclidean similarity): + A0 (*) + .... AL AR <.... + : | : + : | ^ : + v | . v + | : + TR | : BL + T0 --------------x-------------- B0 + TL | : BR + | : + | . + | . + | + FL FR + F0 -@pytest.fixture -def embedding() -> Embeddings: - return ParserEmbeddings(dimension=2) + the query point is meant to be at (*). + the A are bidirectionally with B + the A are outgoing to T + the A are incoming from F + The links are like: L with L, 0 with 0 and R with R. 
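+
+    For example (spelling the scheme out on one entry): AL carries
+    Link.bidir(kind="ab_example", tag="tag_l"), which matches the same
+    bidirectional link on BL, so a traversal can hop between them;
+    likewise A0 pairs with B0 and AR with BR.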
+ """ + + docs_a = [ + Document(page_content="[-1, 9]", metadata={"label": "AL"}), + Document(page_content="[0, 10]", metadata={"label": "A0"}), + Document(page_content="[1, 9]", metadata={"label": "AR"}), + ] + docs_b = [ + Document(page_content="[9, 1]", metadata={"label": "BL"}), + Document(page_content="[10, 0]", metadata={"label": "B0"}), + Document(page_content="[9, -1]", metadata={"label": "BR"}), + ] + docs_f = [ + Document(page_content="[1, -9]", metadata={"label": "BL"}), + Document(page_content="[0, -10]", metadata={"label": "B0"}), + Document(page_content="[-1, -9]", metadata={"label": "BR"}), + ] + docs_t = [ + Document(page_content="[-9, -1]", metadata={"label": "TL"}), + Document(page_content="[-10, 0]", metadata={"label": "T0"}), + Document(page_content="[-9, 1]", metadata={"label": "TR"}), + ] + for doc_a, suffix in zip(docs_a, ["l", "0", "r"]): + add_links(doc_a, Link.bidir(kind="ab_example", tag=f"tag_{suffix}")) + add_links(doc_a, Link.outgoing(kind="at_example", tag=f"tag_{suffix}")) + add_links(doc_a, Link.incoming(kind="af_example", tag=f"tag_{suffix}")) + for doc_b, suffix in zip(docs_b, ["l", "0", "r"]): + add_links(doc_b, Link.bidir(kind="ab_example", tag=f"tag_{suffix}")) + for doc_t, suffix in zip(docs_t, ["l", "0", "r"]): + add_links(doc_t, Link.incoming(kind="at_example", tag=f"tag_{suffix}")) + for doc_f, suffix in zip(docs_f, ["l", "0", "r"]): + add_links(doc_f, Link.outgoing(kind="af_example", tag=f"tag_{suffix}")) + return docs_a + docs_b + docs_f + docs_t @pytest.fixture -def novectorize_empty_graph_store( - novectorize_empty_collection: Collection, # noqa: ARG001 +def graph_vector_store_d2( astra_db_credentials: AstraDBCredentials, - embedding: Embeddings, + empty_collection_d2: Collection, # noqa: ARG001 + embedding_d2: Embeddings, ) -> AstraDBGraphVectorStore: return AstraDBGraphVectorStore( - embedding=embedding, - collection_name=GVS_NOVECTORIZE_COLLECTION, + embedding=embedding_d2, + collection_name=COLLECTION_NAME_D2, token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], @@ -94,18 +115,34 @@ def novectorize_empty_graph_store( @pytest.fixture -def novectorize_autodetect_full_graph_store( +def populated_graph_vector_store_d2( + graph_vector_store_d2: AstraDBGraphVectorStore, + graph_vector_store_docs: list[Document], +) -> AstraDBGraphVectorStore: + graph_vector_store_d2.add_documents(graph_vector_store_docs) + return graph_vector_store_d2 + + +@pytest.fixture +def autodetect_populated_graph_vector_store_d2( astra_db_credentials: AstraDBCredentials, - novectorize_empty_collection: Collection, - embedding: Embeddings, - graph_docs: list[Document], + database: Database, + embedding_d2: Embeddings, + graph_vector_store_docs: list[Document], + ephemeral_collection_cleaner_idxall_d2: str, # noqa: ARG001 ) -> AstraDBGraphVectorStore: """ Pre-populate the collection and have (VectorStore)autodetect work on it, then create and return a GraphVectorStore, additionally filled with - the same (graph-)entries as for `novectorize_full_graph_store`. + the same (graph-)entries as for `populated_graph_vector_store_d2`. 
""" - novectorize_empty_collection.insert_many( + empty_collection_d2_idxall = database.create_collection( + EPHEMERAL_COLLECTION_NAME_IDXALL_D2, + dimension=2, + check_exists=False, + metric="euclidean", + ) + empty_collection_d2_idxall.insert_many( [ { CUSTOM_CONTENT_KEY: LONG_TEXT, @@ -128,8 +165,8 @@ def novectorize_autodetect_full_graph_store( ] ) gstore = AstraDBGraphVectorStore( - embedding=embedding, - collection_name=GVS_NOVECTORIZE_COLLECTION, + embedding=embedding_d2, + collection_name=EPHEMERAL_COLLECTION_NAME_IDXALL_D2, link_to_metadata_key="x_link_to_x", link_from_metadata_key="x_link_from_x", token=StaticTokenProvider(astra_db_credentials["token"]), @@ -139,14 +176,15 @@ def novectorize_autodetect_full_graph_store( content_field="*", autodetect_collection=True, ) - gstore.add_documents(graph_docs) + gstore.add_documents(graph_vector_store_docs) return gstore def assert_all_flat_docs(collection: Collection) -> None: """ - Check that after graph-insertions all docs in the store - still obey the underlying autodetected doc schema on DB. + Check that all docs in the store obey the underlying (flat) autodetected + doc schema on DB. + Useful for checking the store after graph-store-driven insertions. """ for doc in collection.find({}, projection={"*": True}): assert all(not isinstance(v, dict) for v in doc.values()) @@ -154,86 +192,15 @@ def assert_all_flat_docs(collection: Collection) -> None: assert isinstance(doc["$vector"], list) -@pytest.fixture -def graph_docs() -> list[Document]: - """ - This is a pre-populated graph vector store, - with entries placed in a certain way. - - Space of the entries (under Euclidean similarity): - - A0 (*) - .... AL AR <.... - : | : - : | ^ : - v | . v - | : - TR | : BL - T0 --------------x-------------- B0 - TL | : BR - | : - | . - | . - | - FL FR - F0 - - the query point is at (*). - the A are bidirectionally with B - the A are outgoing to T - the A are incoming from F - The links are like: L with L, 0 with 0 and R with R. 
- """ - - docs_a = [ - Document(page_content="[-1, 9]", metadata={"label": "AL"}), - Document(page_content="[0, 10]", metadata={"label": "A0"}), - Document(page_content="[1, 9]", metadata={"label": "AR"}), - ] - docs_b = [ - Document(page_content="[9, 1]", metadata={"label": "BL"}), - Document(page_content="[10, 0]", metadata={"label": "B0"}), - Document(page_content="[9, -1]", metadata={"label": "BR"}), - ] - docs_f = [ - Document(page_content="[1, -9]", metadata={"label": "BL"}), - Document(page_content="[0, -10]", metadata={"label": "B0"}), - Document(page_content="[-1, -9]", metadata={"label": "BR"}), - ] - docs_t = [ - Document(page_content="[-9, -1]", metadata={"label": "TL"}), - Document(page_content="[-10, 0]", metadata={"label": "T0"}), - Document(page_content="[-9, 1]", metadata={"label": "TR"}), - ] - for doc_a, suffix in zip(docs_a, ["l", "0", "r"]): - add_links(doc_a, Link.bidir(kind="ab_example", tag=f"tag_{suffix}")) - add_links(doc_a, Link.outgoing(kind="at_example", tag=f"tag_{suffix}")) - add_links(doc_a, Link.incoming(kind="af_example", tag=f"tag_{suffix}")) - for doc_b, suffix in zip(docs_b, ["l", "0", "r"]): - add_links(doc_b, Link.bidir(kind="ab_example", tag=f"tag_{suffix}")) - for doc_t, suffix in zip(docs_t, ["l", "0", "r"]): - add_links(doc_t, Link.incoming(kind="at_example", tag=f"tag_{suffix}")) - for doc_f, suffix in zip(docs_f, ["l", "0", "r"]): - add_links(doc_f, Link.outgoing(kind="af_example", tag=f"tag_{suffix}")) - return docs_a + docs_b + docs_f + docs_t - - -@pytest.fixture -def novectorize_full_graph_store( - novectorize_empty_graph_store: AstraDBGraphVectorStore, - graph_docs: list[Document], -) -> AstraDBGraphVectorStore: - novectorize_empty_graph_store.add_documents(graph_docs) - return novectorize_empty_graph_store - - -@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") +@pytest.mark.skipif( + not astra_db_env_vars_available(), reason="Missing Astra DB env. 
vars" +) class TestAstraDBGraphVectorStore: @pytest.mark.parametrize( ("store_name", "is_autodetected"), [ - ("novectorize_full_graph_store", False), - ("novectorize_autodetect_full_graph_store", True), + ("populated_graph_vector_store_d2", False), + ("autodetect_populated_graph_vector_store_d2", True), ], ids=["native_store", "autodetected_store"], ) @@ -258,8 +225,8 @@ def test_gvs_similarity_search( @pytest.mark.parametrize( ("store_name", "is_autodetected"), [ - ("novectorize_full_graph_store", False), - ("novectorize_autodetect_full_graph_store", True), + ("populated_graph_vector_store_d2", False), + ("autodetect_populated_graph_vector_store_d2", True), ], ids=["native_store", "autodetected_store"], ) @@ -283,8 +250,8 @@ def test_gvs_traversal_search( @pytest.mark.parametrize( ("store_name", "is_autodetected"), [ - ("novectorize_full_graph_store", False), - ("novectorize_autodetect_full_graph_store", True), + ("populated_graph_vector_store_d2", False), + ("autodetect_populated_graph_vector_store_d2", True), ], ids=["native_store", "autodetected_store"], ) @@ -315,15 +282,15 @@ def test_gvs_from_texts( self, *, astra_db_credentials: AstraDBCredentials, - novectorize_empty_collection: Collection, # noqa: ARG002 - embedding: Embeddings, + empty_collection_d2: Collection, # noqa: ARG002 + embedding_d2: Embeddings, ) -> None: g_store = AstraDBGraphVectorStore.from_texts( texts=["[1, 2]"], - embedding=embedding, + embedding=embedding_d2, metadatas=[{"md": 1}], ids=["x_id"], - collection_name=GVS_NOVECTORIZE_COLLECTION, + collection_name=COLLECTION_NAME_D2, token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], @@ -342,8 +309,8 @@ def test_gvs_from_documents_containing_ids( self, *, astra_db_credentials: AstraDBCredentials, - novectorize_empty_collection: Collection, # noqa: ARG002 - embedding: Embeddings, + empty_collection_d2: Collection, # noqa: ARG002 + embedding_d2: Embeddings, ) -> None: the_document = Document( page_content="[1, 2]", @@ -352,8 +319,8 @@ def test_gvs_from_documents_containing_ids( ) g_store = AstraDBGraphVectorStore.from_documents( documents=[the_document], - embedding=embedding, - collection_name=GVS_NOVECTORIZE_COLLECTION, + embedding=embedding_d2, + collection_name=COLLECTION_NAME_D2, token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], @@ -371,7 +338,7 @@ def test_gvs_from_documents_containing_ids( def test_gvs_add_nodes( self, *, - novectorize_empty_graph_store: AstraDBGraphVectorStore, + graph_vector_store_d2: AstraDBGraphVectorStore, ) -> None: links0 = [ Link(kind="kA", direction="out", tag="tA"), @@ -384,8 +351,8 @@ def test_gvs_add_nodes( Node(id="id0", text="[0, 2]", metadata={"m": 0}, links=links0), Node(text="[0, 1]", metadata={"m": 1}, links=links1), ] - novectorize_empty_graph_store.add_nodes(nodes) - hits = novectorize_empty_graph_store.similarity_search_by_vector([0, 3]) + graph_vector_store_d2.add_nodes(nodes) + hits = graph_vector_store_d2.similarity_search_by_vector([0, 3]) assert len(hits) == 2 assert hits[0].id == "id0" assert hits[0].page_content == "[0, 2]" diff --git a/libs/astradb/tests/integration_tests/test_semantic_cache.py b/libs/astradb/tests/integration_tests/test_semantic_cache.py new file mode 100644 index 0000000..58dcb2d --- /dev/null +++ b/libs/astradb/tests/integration_tests/test_semantic_cache.py @@ -0,0 +1,265 @@ +from __future__ 
import annotations + +import os +from typing import TYPE_CHECKING + +import pytest +from astrapy.authentication import StaticTokenProvider +from langchain_core.globals import get_llm_cache, set_llm_cache +from langchain_core.outputs import Generation, LLMResult + +from langchain_astradb import AstraDBSemanticCache +from langchain_astradb.utils.astradb import SetupMode + +from .conftest import ( + COLLECTION_NAME_IDXALL_D2, + AstraDBCredentials, + astra_db_env_vars_available, +) + +if TYPE_CHECKING: + from astrapy import Collection + from astrapy.db import AstraDB + from langchain_core.embeddings import Embeddings + + from .conftest import IdentityLLM + + +@pytest.fixture +def astradb_semantic_cache( + astra_db_credentials: AstraDBCredentials, + empty_collection_idxall_d2: Collection, # noqa: ARG001 + embedding_d2: Embeddings, +) -> AstraDBSemanticCache: + return AstraDBSemanticCache( + collection_name=COLLECTION_NAME_IDXALL_D2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + embedding=embedding_d2, + metric="euclidean", + ) + + +@pytest.fixture +async def astradb_semantic_cache_async( + astra_db_credentials: AstraDBCredentials, + empty_collection_idxall_d2: Collection, # noqa: ARG001 + embedding_d2: Embeddings, +) -> AstraDBSemanticCache: + return AstraDBSemanticCache( + collection_name=COLLECTION_NAME_IDXALL_D2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + embedding=embedding_d2, + metric="euclidean", + setup_mode=SetupMode.ASYNC, + ) + + +@pytest.mark.skipif( + not astra_db_env_vars_available(), reason="Missing Astra DB env. vars" +) +class TestAstraDBSemanticCache: + def test_semantic_cache_crud_sync( + self, + astradb_semantic_cache: AstraDBSemanticCache, + ) -> None: + """Tests for basic cache CRUD, not involving an LLM.""" + gens0 = [Generation(text="gen0_text")] + gens1 = [Generation(text="gen1_text")] + assert astradb_semantic_cache.lookup_with_id("[1,2]", "lms") is None + + astradb_semantic_cache.update("[0.999,2.001]", "lms", gens0) + astradb_semantic_cache.update("[2.999,4.001]", "lms", gens1) + + hit12 = astradb_semantic_cache.lookup_with_id("[1,2]", "lms") + assert hit12 is not None + assert hit12[1] == gens0 + hit34 = astradb_semantic_cache.lookup_with_id("[3,4]", "lms") + assert hit34 is not None + assert hit34[1] == gens1 + + astradb_semantic_cache.delete_by_document_id(hit12[0]) + assert astradb_semantic_cache.lookup_with_id("[1,2]", "lms") is None + hit34_b = astradb_semantic_cache.lookup_with_id("[3,4]", "lms") + assert hit34_b is not None + assert hit34_b[1] == gens1 + + astradb_semantic_cache.clear() + assert astradb_semantic_cache.lookup_with_id("[1,2]", "lms") is None + assert astradb_semantic_cache.lookup_with_id("[3,4]", "lms") is None + + async def test_semantic_cache_crud_async( + self, + astradb_semantic_cache_async: AstraDBSemanticCache, + ) -> None: + """Tests for basic cache CRUD, not involving an LLM. 
Async version""" + gens0 = [Generation(text="gen0_text")] + gens1 = [Generation(text="gen1_text")] + assert ( + await astradb_semantic_cache_async.alookup_with_id("[1,2]", "lms") is None + ) + + await astradb_semantic_cache_async.aupdate("[0.999,2.001]", "lms", gens0) + await astradb_semantic_cache_async.aupdate("[2.999,4.001]", "lms", gens1) + + hit12 = await astradb_semantic_cache_async.alookup_with_id("[1,2]", "lms") + assert hit12 is not None + assert hit12[1] == gens0 + hit34 = await astradb_semantic_cache_async.alookup_with_id("[3,4]", "lms") + assert hit34 is not None + assert hit34[1] == gens1 + + await astradb_semantic_cache_async.adelete_by_document_id(hit12[0]) + assert ( + await astradb_semantic_cache_async.alookup_with_id("[1,2]", "lms") is None + ) + hit34_b = await astradb_semantic_cache_async.alookup_with_id("[3,4]", "lms") + assert hit34_b is not None + assert hit34_b[1] == gens1 + + await astradb_semantic_cache_async.aclear() + assert ( + await astradb_semantic_cache_async.alookup_with_id("[1,2]", "lms") is None + ) + assert ( + await astradb_semantic_cache_async.alookup_with_id("[3,4]", "lms") is None + ) + + def test_semantic_cache_through_llm_sync( + self, + test_llm: IdentityLLM, + astradb_semantic_cache: AstraDBSemanticCache, + ) -> None: + """Tests for semantic cache as used with a (mock) LLM.""" + gens0 = [Generation(text="gen0_text")] + set_llm_cache(astradb_semantic_cache) + + params = {"stop": None, **test_llm.dict()} + llm_string = str(sorted(params.items())) + + assert test_llm.num_calls == 0 + + # inject cache entry, check no LLM call is done + get_llm_cache().update("[1,2]", llm_string, gens0) + output = test_llm.generate(["[0.999,2.001]"]) + expected_output = LLMResult( + generations=[gens0], + llm_output={}, + ) + assert test_llm.num_calls == 0 + assert output == expected_output + + # check *one* new call for a new prompt, even if 'generate' repeated + test_llm.generate(["[3,4]"]) + test_llm.generate(["[3,4]"]) + test_llm.generate(["[3,4]"]) + assert test_llm.num_calls == 1 + + # clear the cache and check a new LLM call is actually made + astradb_semantic_cache.clear() + test_llm.generate(["[3,4]"]) + test_llm.generate(["[3,4]"]) + assert test_llm.num_calls == 2 + + async def test_semantic_cache_through_llm_async( + self, + test_llm: IdentityLLM, + astradb_semantic_cache: AstraDBSemanticCache, + ) -> None: + """Tests for semantic cache as used with a (mock) LLM, async version.""" + gens0 = [Generation(text="gen0_text")] + set_llm_cache(astradb_semantic_cache) + + params = {"stop": None, **test_llm.dict()} + llm_string = str(sorted(params.items())) + + assert test_llm.num_calls == 0 + + # inject cache entry, check no LLM call is done + await get_llm_cache().aupdate("[1,2]", llm_string, gens0) + output = await test_llm.agenerate(["[0.999,2.001]"]) + expected_output = LLMResult( + generations=[gens0], + llm_output={}, + ) + assert test_llm.num_calls == 0 + assert output == expected_output + + # check *one* new call for a new prompt, even if 'generate' repeated + await test_llm.agenerate(["[3,4]"]) + await test_llm.agenerate(["[3,4]"]) + await test_llm.agenerate(["[3,4]"]) + assert test_llm.num_calls == 1 + + # clear the cache and check a new LLM call is actually made + await astradb_semantic_cache.aclear() + await test_llm.agenerate(["[3,4]"]) + await test_llm.agenerate(["[3,4]"]) + assert test_llm.num_calls == 2 + + @pytest.mark.skipif( + os.environ.get("ASTRA_DB_ENVIRONMENT", "prod").upper() != "PROD", + reason="Can run on Astra DB production environment 
only", + ) + def test_semcache_coreclients_init_sync( + self, + core_astra_db: AstraDB, + embedding_d2: Embeddings, + astradb_semantic_cache: AstraDBSemanticCache, + ) -> None: + """A deprecation warning from passing a (core) AstraDB, but it works.""" + gens0 = [Generation(text="gen0_text")] + astradb_semantic_cache.update("[0.999,2.001]", "llm_string", gens0) + # create an equivalent cache with core AstraDB in init + with pytest.warns(DeprecationWarning) as rec_warnings: + semantic_cache_init_core = AstraDBSemanticCache( + collection_name=COLLECTION_NAME_IDXALL_D2, + astra_db_client=core_astra_db, + embedding=embedding_d2, + metric="euclidean", + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + hit12 = semantic_cache_init_core.lookup_with_id("[1,2]", "llm_string") + assert hit12 is not None + assert hit12[1] == gens0 + + @pytest.mark.skipif( + os.environ.get("ASTRA_DB_ENVIRONMENT", "prod").upper() != "PROD", + reason="Can run on Astra DB production environment only", + ) + async def test_semcache_coreclients_init_async( + self, + core_astra_db: AstraDB, + embedding_d2: Embeddings, + astradb_semantic_cache_async: AstraDBSemanticCache, + ) -> None: + """ + A deprecation warning from passing a (core) AstraDB, but it works. + Async version. + """ + gens0 = [Generation(text="gen0_text")] + await astradb_semantic_cache_async.aupdate("[0.999,2.001]", "llm_string", gens0) + # create an equivalent cache with core AstraDB in init + with pytest.warns(DeprecationWarning) as rec_warnings: + semantic_cache_init_core = AstraDBSemanticCache( + collection_name=COLLECTION_NAME_IDXALL_D2, + astra_db_client=core_astra_db, + embedding=embedding_d2, + metric="euclidean", + setup_mode=SetupMode.ASYNC, + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + hit12 = await semantic_cache_init_core.alookup_with_id("[1,2]", "llm_string") + assert hit12 is not None + assert hit12[1] == gens0 diff --git a/libs/astradb/tests/integration_tests/test_storage.py b/libs/astradb/tests/integration_tests/test_storage.py index f3f48b0..faa7b73 100644 --- a/libs/astradb/tests/integration_tests/test_storage.py +++ b/libs/astradb/tests/integration_tests/test_storage.py @@ -6,129 +6,110 @@ from typing import TYPE_CHECKING import pytest +from astrapy.authentication import StaticTokenProvider from langchain_astradb.storage import AstraDBByteStore, AstraDBStore from langchain_astradb.utils.astradb import SetupMode -from .conftest import _has_env_vars +from .conftest import ( + COLLECTION_NAME_IDXID, + EPHEMERAL_CUSTOM_IDX_NAME, + EPHEMERAL_LEGACY_IDX_NAME, + AstraDBCredentials, + astra_db_env_vars_available, +) if TYPE_CHECKING: - from astrapy import Database + from astrapy import Collection, Database from astrapy.db import AstraDB -def init_store( - astra_db_credentials: dict[str, str | None], - collection_name: str, +@pytest.fixture +def astra_db_empty_store( + astra_db_credentials: AstraDBCredentials, + collection_idxid: Collection, ) -> AstraDBStore: - store = AstraDBStore( - collection_name=collection_name, - token=astra_db_credentials["token"], + collection_idxid.delete_many({}) + return AstraDBStore( + collection_name=COLLECTION_NAME_IDXID, + token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], + 
setup_mode=SetupMode.OFF, ) - store.mset([("key1", [0.1, 0.2]), ("key2", "value2")]) - return store -def init_bytestore( - astra_db_credentials: dict[str, str | None], - collection_name: str, -) -> AstraDBByteStore: - store = AstraDBByteStore( - collection_name=collection_name, - token=astra_db_credentials["token"], +@pytest.fixture +async def astra_db_empty_store_async( + astra_db_credentials: AstraDBCredentials, + collection_idxid: Collection, +) -> AstraDBStore: + collection_idxid.delete_many({}) + return AstraDBStore( + collection_name=COLLECTION_NAME_IDXID, + token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], + setup_mode=SetupMode.ASYNC, ) - store.mset([("key1", b"value1"), ("key2", b"value2")]) - return store -async def init_async_store( - astra_db_credentials: dict[str, str | None], collection_name: str -) -> AstraDBStore: - store = AstraDBStore( - collection_name=collection_name, - token=astra_db_credentials["token"], +@pytest.fixture +def astra_db_empty_byte_store( + astra_db_credentials: AstraDBCredentials, + collection_idxid: Collection, +) -> AstraDBByteStore: + collection_idxid.delete_many({}) + return AstraDBByteStore( + collection_name=COLLECTION_NAME_IDXID, + token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], - setup_mode=SetupMode.ASYNC, + setup_mode=SetupMode.OFF, ) - await store.amset([("key1", [0.1, 0.2]), ("key2", "value2")]) - return store -@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") +@pytest.mark.skipif( + not astra_db_env_vars_available(), reason="Missing Astra DB env. 
vars" +) class TestAstraDBStore: - def test_mget( - self, - astra_db_credentials: dict[str, str | None], - ) -> None: - """Test AstraDBStore mget method.""" - collection_name = "lc_test_store_mget" - try: - store = init_store(astra_db_credentials, collection_name) - assert store.mget(["key1", "key2"]) == [[0.1, 0.2], "value2"] - finally: - store.astra_env.database.drop_collection(collection_name) - - async def test_amget( + def test_store_crud_sync( self, - astra_db_credentials: dict[str, str | None], + astra_db_empty_store: AstraDBStore, ) -> None: - """Test AstraDBStore amget method.""" - collection_name = "lc_test_store_mget" - try: - store = await init_async_store(astra_db_credentials, collection_name) - assert await store.amget(["key1", "key2"]) == [[0.1, 0.2], "value2"] - finally: - await store.astra_env.async_database.drop_collection(collection_name) - - def test_mset( - self, - astra_db_credentials: dict[str, str | None], - ) -> None: - """Test that multiple keys can be set with AstraDBStore.""" - collection_name = "lc_test_store_mset" - try: - store = init_store(astra_db_credentials, collection_name) - result = store.collection.find_one({"_id": "key1"}) - assert (result or {})["value"] == [0.1, 0.2] - result = store.collection.find_one({"_id": "key2"}) - assert (result or {})["value"] == "value2" - finally: - store.astra_env.database.drop_collection(collection_name) - - async def test_amset( + """Test AstraDBStore mget/mset/mdelete method.""" + astra_db_empty_store.mset([("key1", [0.1, 0.2]), ("key2", "value2")]) + assert astra_db_empty_store.mget(["key1", "key2"]) == [[0.1, 0.2], "value2"] + astra_db_empty_store.mdelete(["key1", "key2"]) + assert astra_db_empty_store.mget(["key1", "key2"]) == [None, None] + + async def test_store_crud_async( self, - astra_db_credentials: dict[str, str | None], + astra_db_empty_store_async: AstraDBStore, ) -> None: - """Test that multiple keys can be set with AstraDBStore.""" - collection_name = "lc_test_store_mset" - try: - store = await init_async_store(astra_db_credentials, collection_name) - result = await store.async_collection.find_one({"_id": "key1"}) - assert (result or {})["value"] == [0.1, 0.2] - result = await store.async_collection.find_one({"_id": "key2"}) - assert (result or {})["value"] == "value2" - finally: - await store.astra_env.async_database.drop_collection(collection_name) - - def test_store_massive_mset_with_replace( + """Test AstraDBStore amget/amset/amdelete method. 
Async version.""" + await astra_db_empty_store_async.amset( + [("key1", [0.1, 0.2]), ("key2", "value2")] + ) + assert await astra_db_empty_store_async.amget(["key1", "key2"]) == [ + [0.1, 0.2], + "value2", + ] + await astra_db_empty_store_async.amdelete(["key1", "key2"]) + assert await astra_db_empty_store_async.amget(["key1", "key2"]) == [None, None] + + def test_store_massive_write_with_replace_sync( self, - astra_db_credentials: dict[str, str | None], + astra_db_empty_store: AstraDBStore, ) -> None: """Testing the insert-many-and-replace-some patterns thoroughly.""" full_size = 300 first_group_size = 150 second_group_slicer = [30, 100, 2] max_values_in_in = 100 - collection_name = "lc_test_store_massive_mset" - ids_and_texts = [ ( f"doc_{idx}", @@ -136,63 +117,49 @@ def test_store_massive_mset_with_replace( ) for idx in range(full_size) ] - try: - store = AstraDBStore( - collection_name=collection_name, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - # massive insertion on empty (zip and rezip for uniformity with later) - group0_ids, group0_texts = list(zip(*ids_and_texts[0:first_group_size])) - store.mset(list(zip(group0_ids, group0_texts))) - - # massive insertion with many overwrites scattered through - # (we change the text to later check on DB for successful update) - _s, _e, _st = second_group_slicer - group1_ids, group1_texts_pre = list( - zip( - *( - ids_and_texts[_s:_e:_st] - + ids_and_texts[first_group_size:full_size] - ) - ) + # massive insertion on empty (zip and rezip for uniformity with later) + group0_ids, group0_texts = list(zip(*ids_and_texts[0:first_group_size])) + astra_db_empty_store.mset(list(zip(group0_ids, group0_texts))) + + # massive insertion with many overwrites scattered through + # (we change the text to later check on DB for successful update) + _s, _e, _st = second_group_slicer + group1_ids, group1_texts_pre = list( + zip(*(ids_and_texts[_s:_e:_st] + ids_and_texts[first_group_size:full_size])) + ) + group1_texts = [txt.upper() for txt in group1_texts_pre] + astra_db_empty_store.mset(list(zip(group1_ids, group1_texts))) + + # final read (we want the IDs to do a full check) + expected_text_by_id = { + **dict(zip(group0_ids, group0_texts)), + **dict(zip(group1_ids, group1_texts)), + } + all_ids = [doc_id for doc_id, _ in ids_and_texts] + # The Data API can handle at most max_values_in_in entries, let's chunk + all_vals = [ + val + for chunk_start in range(0, full_size, max_values_in_in) + for val in astra_db_empty_store.mget( + all_ids[chunk_start : chunk_start + max_values_in_in] ) - group1_texts = [txt.upper() for txt in group1_texts_pre] - store.mset(list(zip(group1_ids, group1_texts))) - - # final read (we want the IDs to do a full check) - expected_text_by_id = { - **dict(zip(group0_ids, group0_texts)), - **dict(zip(group1_ids, group1_texts)), - } - all_ids = [doc_id for doc_id, _ in ids_and_texts] - # The Data API can handle at most max_values_in_in entries, let's chunk - all_vals = [ - val - for chunk_start in range(0, full_size, max_values_in_in) - for val in store.mget( - all_ids[chunk_start : chunk_start + max_values_in_in] - ) - ] - for val, doc_id in zip(all_vals, all_ids): - assert val == expected_text_by_id[doc_id] - finally: - store.astra_env.database.drop_collection(collection_name) + ] + for val, doc_id in zip(all_vals, all_ids): + assert val == expected_text_by_id[doc_id] - async def 
test_store_massive_amset_with_replace( + async def test_store_massive_write_with_replace_async( self, - astra_db_credentials: dict[str, str | None], + astra_db_empty_store_async: AstraDBStore, ) -> None: - """Testing the insert-many-and-replace-some patterns thoroughly.""" + """ + Testing the insert-many-and-replace-some patterns thoroughly. + Async version. + """ full_size = 300 first_group_size = 150 second_group_slicer = [30, 100, 2] max_values_in_in = 100 - collection_name = "lc_test_store_massive_amset" - ids_and_texts = [ ( f"doc_{idx}", @@ -201,167 +168,181 @@ async def test_store_massive_amset_with_replace( for idx in range(full_size) ] - try: - store = AstraDBStore( - collection_name=collection_name, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) + # massive insertion on empty (zip and rezip for uniformity with later) + group0_ids, group0_texts = list(zip(*ids_and_texts[0:first_group_size])) + await astra_db_empty_store_async.amset(list(zip(group0_ids, group0_texts))) - # massive insertion on empty (zip and rezip for uniformity with later) - group0_ids, group0_texts = list(zip(*ids_and_texts[0:first_group_size])) - await store.amset(list(zip(group0_ids, group0_texts))) - - # massive insertion with many overwrites scattered through - # (we change the text to later check on DB for successful update) - _s, _e, _st = second_group_slicer - group1_ids, group1_texts_pre = list( - zip( - *( - ids_and_texts[_s:_e:_st] - + ids_and_texts[first_group_size:full_size] - ) - ) + # massive insertion with many overwrites scattered through + # (we change the text to later check on DB for successful update) + _s, _e, _st = second_group_slicer + group1_ids, group1_texts_pre = list( + zip(*(ids_and_texts[_s:_e:_st] + ids_and_texts[first_group_size:full_size])) + ) + group1_texts = [txt.upper() for txt in group1_texts_pre] + await astra_db_empty_store_async.amset(list(zip(group1_ids, group1_texts))) + + # final read (we want the IDs to do a full check) + expected_text_by_id = { + **dict(zip(group0_ids, group0_texts)), + **dict(zip(group1_ids, group1_texts)), + } + all_ids = [doc_id for doc_id, _ in ids_and_texts] + # The Data API can handle at most max_values_in_in entries, let's chunk + all_vals = [ + val + for chunk_start in range(0, full_size, max_values_in_in) + for val in await astra_db_empty_store_async.amget( + all_ids[chunk_start : chunk_start + max_values_in_in] ) - group1_texts = [txt.upper() for txt in group1_texts_pre] - await store.amset(list(zip(group1_ids, group1_texts))) - - # final read (we want the IDs to do a full check) - expected_text_by_id = { - **dict(zip(group0_ids, group0_texts)), - **dict(zip(group1_ids, group1_texts)), - } - all_ids = [doc_id for doc_id, _ in ids_and_texts] - # The Data API can handle at most max_values_in_in entries, let's chunk - all_vals = [ - val - for chunk_start in range(0, full_size, max_values_in_in) - for val in await store.amget( - all_ids[chunk_start : chunk_start + max_values_in_in] - ) - ] - for val, doc_id in zip(all_vals, all_ids): - assert val == expected_text_by_id[doc_id] - finally: - store.astra_env.database.drop_collection(collection_name) + ] + for val, doc_id in zip(all_vals, all_ids): + assert val == expected_text_by_id[doc_id] - def test_mdelete( - self, - astra_db_credentials: dict[str, str | None], - ) -> None: - """Test that deletion works as expected.""" - collection_name = 
"lc_test_store_mdelete" - try: - store = init_store(astra_db_credentials, collection_name) - store.mdelete(["key1", "key2"]) - result = store.mget(["key1", "key2"]) - assert result == [None, None] - finally: - store.astra_env.database.drop_collection(collection_name) - - async def test_amdelete( + def test_store_yield_keys_sync( self, - astra_db_credentials: dict[str, str | None], + astra_db_empty_store: AstraDBStore, ) -> None: - """Test that deletion works as expected.""" - collection_name = "lc_test_store_mdelete" - try: - store = await init_async_store(astra_db_credentials, collection_name) - await store.amdelete(["key1", "key2"]) - result = await store.amget(["key1", "key2"]) - assert result == [None, None] - finally: - await store.astra_env.async_database.drop_collection(collection_name) - - def test_yield_keys( + """Test of yield_keys.""" + astra_db_empty_store.mset([("key1", [0.1, 0.2]), ("key2", "value2")]) + assert set(astra_db_empty_store.yield_keys()) == {"key1", "key2"} + assert set(astra_db_empty_store.yield_keys(prefix="key")) == {"key1", "key2"} + assert set(astra_db_empty_store.yield_keys(prefix="lang")) == set() + + async def test_store_yield_keys_async( self, - astra_db_credentials: dict[str, str | None], + astra_db_empty_store_async: AstraDBStore, ) -> None: - collection_name = "lc_test_store_yield_keys" - try: - store = init_store(astra_db_credentials, collection_name) - assert set(store.yield_keys()) == {"key1", "key2"} - assert set(store.yield_keys(prefix="key")) == {"key1", "key2"} - assert set(store.yield_keys(prefix="lang")) == set() - finally: - store.astra_env.database.drop_collection(collection_name) - - async def test_ayield_keys( + """Test of yield_keys, async version""" + await astra_db_empty_store_async.amset( + [("key1", [0.1, 0.2]), ("key2", "value2")] + ) + assert {k async for k in astra_db_empty_store_async.ayield_keys()} == { + "key1", + "key2", + } + assert { + k async for k in astra_db_empty_store_async.ayield_keys(prefix="key") + } == {"key1", "key2"} + assert { + k async for k in astra_db_empty_store_async.ayield_keys(prefix="lang") + } == set() + + def test_bytestore_crud_sync( self, - astra_db_credentials: dict[str, str | None], + astra_db_empty_byte_store: AstraDBByteStore, ) -> None: - collection_name = "lc_test_store_yield_keys" - try: - store = await init_async_store(astra_db_credentials, collection_name) - assert {key async for key in store.ayield_keys()} == {"key1", "key2"} - assert {key async for key in store.ayield_keys(prefix="key")} == { - "key1", - "key2", - } - assert {key async for key in store.ayield_keys(prefix="lang")} == set() - finally: - await store.astra_env.async_database.drop_collection(collection_name) - - def test_bytestore_mget( + """ + Test AstraDBByteStore mget/mset/mdelete method. + + Since this class shares most of its logic with AstraDBStore, + there's no need to test async nor the other methods/pathways. 
+ """ + astra_db_empty_byte_store.mset([("key1", b"value1"), ("key2", b"value2")]) + assert astra_db_empty_byte_store.mget(["key1", "key2"]) == [ + b"value1", + b"value2", + ] + astra_db_empty_byte_store.mdelete(["key1", "key2"]) + assert astra_db_empty_byte_store.mget(["key1", "key2"]) == [None, None] + + @pytest.mark.skipif( + os.environ.get("ASTRA_DB_ENVIRONMENT", "prod").upper() != "PROD", + reason="Can run on Astra DB production environment only", + ) + def test_store_coreclients_init_sync( self, - astra_db_credentials: dict[str, str | None], + core_astra_db: AstraDB, + astra_db_empty_store: AstraDBStore, ) -> None: - """Test AstraDBByteStore mget method.""" - collection_name = "lc_test_bytestore_mget" - try: - store = init_bytestore(astra_db_credentials, collection_name) - assert store.mget(["key1", "key2"]) == [b"value1", b"value2"] - finally: - store.astra_env.database.drop_collection(collection_name) - - def test_bytestore_mset( + """A deprecation warning from passing a (core) AstraDB, but it works.""" + astra_db_empty_store.mset([("key", "val123")]) + + # create an equivalent store with core AstraDB in init + with pytest.warns(DeprecationWarning) as rec_warnings: + store_init_core = AstraDBStore( + collection_name=COLLECTION_NAME_IDXID, + astra_db_client=core_astra_db, + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + assert store_init_core.mget(["key"]) == ["val123"] + + @pytest.mark.skipif( + os.environ.get("ASTRA_DB_ENVIRONMENT", "prod").upper() != "PROD", + reason="Can run on Astra DB production environment only", + ) + async def test_store_coreclients_init_async( self, - astra_db_credentials: dict[str, str | None], + core_astra_db: AstraDB, + astra_db_empty_store_async: AstraDBStore, ) -> None: - """Test that multiple keys can be set with AstraDBByteStore.""" - collection_name = "lc_test_bytestore_mset" - try: - store = init_bytestore(astra_db_credentials, collection_name) - result = store.collection.find_one({"_id": "key1"}) - assert (result or {})["value"] == "dmFsdWUx" - result = store.collection.find_one({"_id": "key2"}) - assert (result or {})["value"] == "dmFsdWUy" - finally: - store.astra_env.database.drop_collection(collection_name) - - def test_indexing_detection( + """ + A deprecation warning from passing a (core) AstraDB, but it works. + Async version. 
+ """ + await astra_db_empty_store_async.amset([("key", "val123")]) + # create an equivalent store with core AstraDB in init + with pytest.warns(DeprecationWarning) as rec_warnings: + store_init_core = AstraDBStore( + collection_name=COLLECTION_NAME_IDXID, + astra_db_client=core_astra_db, + setup_mode=SetupMode.ASYNC, + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + assert await store_init_core.amget(["key"]) == ["val123"] + + def test_store_indexing_default_sync( self, - astra_db_credentials: dict[str, str | None], - database: Database, + astra_db_credentials: AstraDBCredentials, + astra_db_empty_store: AstraDBStore, # noqa: ARG002 + ephemeral_indexing_collections_cleaner: list[str], # noqa: ARG002 ) -> None: - """Test the behaviour against preexisting legacy collections.""" - database.create_collection("lc_test_legacy_store") - database.create_collection( - "lc_test_custom_store", indexing={"allow": ["my_field"]} - ) + """Test of default-indexing re-instantiation.""" AstraDBStore( - collection_name="lc_test_regular_store", - token=astra_db_credentials["token"], + collection_name=COLLECTION_NAME_IDXID, + token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], ) - # repeated instantiation must work - AstraDBStore( - collection_name="lc_test_regular_store", - token=astra_db_credentials["token"], + async def test_store_indexing_default_async( + self, + astra_db_credentials: AstraDBCredentials, + astra_db_empty_store_async: AstraDBStore, # noqa: ARG002 + ephemeral_indexing_collections_cleaner: list[str], # noqa: ARG002 + ) -> None: + """Test of default-indexing re-instantiation, async version""" + await AstraDBStore( + collection_name=COLLECTION_NAME_IDXID, + token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], + setup_mode=SetupMode.ASYNC, + ).amget(["some_key"]) + + def test_store_indexing_on_legacy_sync( + self, + astra_db_credentials: AstraDBCredentials, + database: Database, + ephemeral_indexing_collections_cleaner: list[str], # noqa: ARG002 + ) -> None: + """Test of instantiation against a legacy collection.""" + database.create_collection( + EPHEMERAL_LEGACY_IDX_NAME, + indexing=None, + check_exists=False, ) - # on a legacy collection must just give a warning with pytest.warns(UserWarning) as rec_warnings: AstraDBStore( - collection_name="lc_test_legacy_store", - token=astra_db_credentials["token"], + collection_name=EPHEMERAL_LEGACY_IDX_NAME, + token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], @@ -370,92 +351,76 @@ def test_indexing_detection( wrn for wrn in rec_warnings if issubclass(wrn.category, UserWarning) ] assert len(f_rec_warnings) == 1 - # on a custom collection must error - with pytest.raises( - ValueError, match="is detected as having the following indexing policy" - ): - AstraDBStore( - collection_name="lc_test_custom_store", - token=astra_db_credentials["token"], + + async def test_store_indexing_on_legacy_async( + self, + astra_db_credentials: AstraDBCredentials, + database: Database, + ephemeral_indexing_collections_cleaner: list[str], 
# noqa: ARG002
+    ) -> None:
+        """Test of instantiation against a legacy collection, async version."""
+        database.create_collection(
+            EPHEMERAL_LEGACY_IDX_NAME,
+            indexing=None,
+            check_exists=False,
+        )
+        with pytest.warns(UserWarning) as rec_warnings:
+            await AstraDBStore(
+                collection_name=EPHEMERAL_LEGACY_IDX_NAME,
+                token=StaticTokenProvider(astra_db_credentials["token"]),
                 api_endpoint=astra_db_credentials["api_endpoint"],
                 namespace=astra_db_credentials["namespace"],
                 environment=astra_db_credentials["environment"],
-            )
-
-        database.drop_collection("lc_test_legacy_store")
-        database.drop_collection("lc_test_custom_store")
-        database.drop_collection("lc_test_regular_store")
+                setup_mode=SetupMode.ASYNC,
+            ).amget(["some_key"])
+        f_rec_warnings = [
+            wrn for wrn in rec_warnings if issubclass(wrn.category, UserWarning)
+        ]
+        assert len(f_rec_warnings) == 1

-    @pytest.mark.skipif(
-        os.environ.get("ASTRA_DB_ENVIRONMENT", "prod").upper() != "PROD",
-        reason="Can run on Astra DB prod only",
-    )
-    def test_store_coreclients_init_sync
+    def test_store_indexing_on_custom_sync(
         self,
-        astra_db_credentials: dict[str, str | None],
-        core_astra_db: AstraDB,
+        astra_db_credentials: AstraDBCredentials,
+        database: Database,
+        ephemeral_indexing_collections_cleaner: list[str],  # noqa: ARG002
     ) -> None:
-        """A deprecation warning from passing a (core) AstraDB, but it works."""
-        collection_name = "lc_test_bytestore_coreclsync"
-        try:
-            store_init_ok = AstraDBStore(
-                collection_name=collection_name,
-                token=astra_db_credentials["token"],
+        """Test of instantiation against a custom-indexing collection."""
+        database.create_collection(
+            EPHEMERAL_CUSTOM_IDX_NAME,
+            indexing={"deny": ["useless", "forgettable"]},
+            check_exists=False,
+        )
+        with pytest.raises(
+            ValueError, match="is detected as having the following indexing policy"
+        ):
+            AstraDBStore(
+                collection_name=EPHEMERAL_CUSTOM_IDX_NAME,
+                token=StaticTokenProvider(astra_db_credentials["token"]),
                 api_endpoint=astra_db_credentials["api_endpoint"],
                 namespace=astra_db_credentials["namespace"],
+                environment=astra_db_credentials["environment"],
             )
-            store_init_ok.mset([("key", "val123")])
-            # create an equivalent store with core AstraDB in init
-            with pytest.warns(DeprecationWarning) as rec_warnings:
-                store_init_core = AstraDBStore(
-                    collection_name=collection_name,
-                    astra_db_client=core_astra_db,
-                )
-            f_rec_warnings = [
-                wrn
-                for wrn in rec_warnings
-                if issubclass(wrn.category, DeprecationWarning)
-            ]
-            assert len(f_rec_warnings) == 1
-            assert store_init_core.mget(["key"]) == ["val123"]
-        finally:
-            store_init_ok.astra_env.database.drop_collection(collection_name)

-    @pytest.mark.skipif(
-        os.environ.get("ASTRA_DB_ENVIRONMENT", "prod").upper() != "PROD",
-        reason="Can run on Astra DB prod only",
-    )
-    async def test_store_coreclients_init_async
+    async def test_store_indexing_on_custom_async(
         self,
-        astra_db_credentials: dict[str, str | None],
-        core_astra_db: AstraDB,
+        astra_db_credentials: AstraDBCredentials,
+        database: Database,
+        ephemeral_indexing_collections_cleaner: list[str],  # noqa: ARG002
     ) -> None:
-        """A deprecation warning from passing a (core) AstraDB, but it works."""
-        collection_name = "lc_test_bytestore_coreclasync"
-        try:
-            store_init_ok = AstraDBStore(
-                collection_name=collection_name,
-                token=astra_db_credentials["token"],
+        """Test of instantiation against a custom-indexing collection, async."""
+        database.create_collection(
+            EPHEMERAL_CUSTOM_IDX_NAME,
+            indexing={"deny": ["useless", "forgettable"]},
+            check_exists=False,
+
) + with pytest.raises( + ValueError, match="is detected as having the following indexing policy" + ): + await AstraDBStore( + collection_name=EPHEMERAL_CUSTOM_IDX_NAME, + token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], setup_mode=SetupMode.ASYNC, - ) - await store_init_ok.amset([("key", "val123")]) - # create an equivalent store with core AstraDB in init - with pytest.warns(DeprecationWarning) as rec_warnings: - store_init_core = AstraDBStore( - collection_name=collection_name, - astra_db_client=core_astra_db, - setup_mode=SetupMode.ASYNC, - ) - f_rec_warnings = [ - wrn - for wrn in rec_warnings - if issubclass(wrn.category, DeprecationWarning) - ] - assert len(f_rec_warnings) == 1 - assert await store_init_core.amget(["key"]) == ["val123"] - finally: - await store_init_ok.astra_env.async_database.drop_collection( - collection_name - ) + ).amget(["some_key"]) diff --git a/libs/astradb/tests/integration_tests/test_vectorstore.py b/libs/astradb/tests/integration_tests/test_vectorstore.py new file mode 100644 index 0000000..dfc9c99 --- /dev/null +++ b/libs/astradb/tests/integration_tests/test_vectorstore.py @@ -0,0 +1,1302 @@ +"""Test of Astra DB vector store class `AstraDBVectorStore`.""" + +from __future__ import annotations + +import json +import math +import os +from typing import TYPE_CHECKING, Any + +import pytest +from astrapy.authentication import StaticTokenProvider +from langchain_core.documents import Document + +from langchain_astradb.utils.astradb import SetupMode +from langchain_astradb.vectorstores import AstraDBVectorStore + +from .conftest import ( + COLLECTION_NAME_D2, + EPHEMERAL_COLLECTION_NAME_D2, + EUCLIDEAN_MIN_SIM_UNIT_VECTORS, + MATCH_EPSILON, + OPENAI_VECTORIZE_OPTIONS_HEADER, + astra_db_env_vars_available, +) + +if TYPE_CHECKING: + from astrapy import Collection + from astrapy.db import AstraDB + from langchain_core.embeddings import Embeddings + + from .conftest import AstraDBCredentials + + +@pytest.mark.skipif( + not astra_db_env_vars_available(), reason="Missing Astra DB env. 
vars" +) +class TestAstraDBVectorStore: + @pytest.mark.parametrize( + ("is_vectorize", "page_contents", "collection_fixture_name"), + [ + ( + False, + [ + "[1,2]", + "[3,4]", + "[5,6]", + "[7,8]", + "[9,10]", + "[11,12]", + ], + "empty_collection_d2", + ), + ( + True, + [ + "Dogs 1", + "Cats 3", + "Giraffes 5", + "Spiders 7", + "Pycnogonids 9", + "Rabbits 11", + ], + "empty_collection_vz", + ), + ], + ids=["nonvectorize_store", "vectorize_store"], + ) + def test_astradb_vectorstore_from_texts_sync( + self, + *, + astra_db_credentials: AstraDBCredentials, + openai_api_key: str, + embedding_d2: Embeddings, + is_vectorize: bool, + page_contents: list[str], + collection_fixture_name: str, + request: pytest.FixtureRequest, + ) -> None: + """from_texts methods and the associated warnings.""" + collection: Collection = request.getfixturevalue(collection_fixture_name) + init_kwargs: dict[str, Any] + if is_vectorize: + init_kwargs = { + "collection_vector_service_options": OPENAI_VECTORIZE_OPTIONS_HEADER, + "collection_embedding_api_key": openai_api_key, + } + else: + init_kwargs = {"embedding": embedding_d2} + + v_store = AstraDBVectorStore.from_texts( + texts=page_contents[0:2], + metadatas=[{"m": 1}, {"m": 3}], + ids=["ft1", "ft3"], + collection_name=collection.name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + **init_kwargs, + ) + search_results_triples_0 = v_store.similarity_search_with_score_id( + page_contents[1], + k=1, + ) + assert len(search_results_triples_0) == 1 + res_doc_0, _, res_id_0 = search_results_triples_0[0] + assert res_doc_0.page_content == page_contents[1] + assert res_doc_0.metadata == {"m": 3} + assert res_id_0 == "ft3" + + # testing additional kwargs & from_text-specific kwargs + with pytest.warns(UserWarning): + # unknown kwargs going to the constructor through _from_kwargs + AstraDBVectorStore.from_texts( + texts=page_contents[2:4], + metadatas=[{"m": 5}, {"m": 7}], + ids=["ft5", "ft7"], + collection_name=collection.name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + number_of_wizards=123, + name_of_river="Thames", + **init_kwargs, + ) + search_results_triples_1 = v_store.similarity_search_with_score_id( + page_contents[3], + k=1, + ) + assert len(search_results_triples_1) == 1 + res_doc_1, _, res_id_1 = search_results_triples_1[0] + assert res_doc_1.page_content == page_contents[3] + assert res_doc_1.metadata == {"m": 7} + assert res_id_1 == "ft7" + # routing of 'add_texts' keyword arguments + v_store_2 = AstraDBVectorStore.from_texts( + texts=page_contents[4:6], + metadatas=[{"m": 9}, {"m": 11}], + ids=["ft9", "ft11"], + collection_name=collection.name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + batch_size=19, + batch_concurrency=23, + overwrite_concurrency=29, + **init_kwargs, + ) + assert v_store_2.batch_size != 19 + assert v_store_2.bulk_insert_batch_concurrency != 23 + assert v_store_2.bulk_insert_overwrite_concurrency != 29 + search_results_triples_2 = 
v_store_2.similarity_search_with_score_id( + page_contents[5], + k=1, + ) + assert len(search_results_triples_2) == 1 + res_doc_2, _, res_id_2 = search_results_triples_2[0] + assert res_doc_2.page_content == page_contents[5] + assert res_doc_2.metadata == {"m": 11} + assert res_id_2 == "ft11" + + @pytest.mark.parametrize( + ("is_vectorize", "page_contents", "collection_fixture_name"), + [ + (False, ["[1,2]", "[3,4]"], "empty_collection_d2"), + (True, ["Whales 1", "Tomatoes 3"], "empty_collection_vz"), + ], + ids=["nonvectorize_store", "vectorize_store"], + ) + def test_astradb_vectorstore_from_documents_sync( + self, + *, + astra_db_credentials: AstraDBCredentials, + openai_api_key: str, + embedding_d2: Embeddings, + is_vectorize: bool, + page_contents: list[str], + collection_fixture_name: str, + request: pytest.FixtureRequest, + ) -> None: + """from_documents, esp. the various handling of ID-in-doc vs external.""" + collection: Collection = request.getfixturevalue(collection_fixture_name) + pc1, pc2 = page_contents + init_kwargs: dict[str, Any] + if is_vectorize: + init_kwargs = { + "collection_vector_service_options": OPENAI_VECTORIZE_OPTIONS_HEADER, + "collection_embedding_api_key": openai_api_key, + } + else: + init_kwargs = {"embedding": embedding_d2} + # no IDs. + v_store = AstraDBVectorStore.from_documents( + [ + Document(page_content=pc1, metadata={"m": 1}), + Document(page_content=pc2, metadata={"m": 3}), + ], + collection_name=collection.name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + **init_kwargs, + ) + hits = v_store.similarity_search(pc2, k=1) + assert len(hits) == 1 + assert hits[0].page_content == pc2 + assert hits[0].metadata == {"m": 3} + v_store.clear() + + # IDs passed separately. + with pytest.warns(DeprecationWarning) as rec_warnings: + v_store_2 = AstraDBVectorStore.from_documents( + [ + Document(page_content=pc1, metadata={"m": 1}), + Document(page_content=pc2, metadata={"m": 3}), + ], + ids=["idx1", "idx3"], + collection_name=collection.name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + **init_kwargs, + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + hits = v_store_2.similarity_search(pc2, k=1) + assert len(hits) == 1 + assert hits[0].page_content == pc2 + assert hits[0].metadata == {"m": 3} + assert hits[0].id == "idx3" + v_store_2.clear() + + # IDs in documents. + v_store_3 = AstraDBVectorStore.from_documents( + [ + Document(page_content=pc1, metadata={"m": 1}, id="idx1"), + Document(page_content=pc2, metadata={"m": 3}, id="idx3"), + ], + collection_name=collection.name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + **init_kwargs, + ) + hits = v_store_3.similarity_search(pc2, k=1) + assert len(hits) == 1 + assert hits[0].page_content == pc2 + assert hits[0].metadata == {"m": 3} + assert hits[0].id == "idx3" + v_store_3.clear() + + # IDs both in documents and aside. 
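+        # (the explicit `ids` argument is expected to take precedence over the
+        # in-document ids: the "idy3" document below must end up stored as "idx3")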
+ with pytest.warns(DeprecationWarning) as rec_warnings: + v_store_4 = AstraDBVectorStore.from_documents( + [ + Document(page_content=pc1, metadata={"m": 1}), + Document(page_content=pc2, metadata={"m": 3}, id="idy3"), + ], + ids=["idx1", "idx3"], + collection_name=collection.name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + **init_kwargs, + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + hits = v_store_4.similarity_search(pc2, k=1) + assert len(hits) == 1 + assert hits[0].page_content == pc2 + assert hits[0].metadata == {"m": 3} + assert hits[0].id == "idx3" + + @pytest.mark.parametrize( + ("is_vectorize", "page_contents", "collection_fixture_name"), + [ + ( + False, + [ + "[1,2]", + "[3,4]", + "[5,6]", + "[7,8]", + "[9,10]", + "[11,12]", + ], + "empty_collection_d2", + ), + ( + True, + [ + "Dogs 1", + "Cats 3", + "Giraffes 5", + "Spiders 7", + "Pycnogonids 9", + "Rabbits 11", + ], + "empty_collection_vz", + ), + ], + ids=["nonvectorize_store", "vectorize_store"], + ) + async def test_astradb_vectorstore_from_texts_async( + self, + *, + astra_db_credentials: AstraDBCredentials, + openai_api_key: str, + embedding_d2: Embeddings, + is_vectorize: bool, + page_contents: list[str], + collection_fixture_name: str, + request: pytest.FixtureRequest, + ) -> None: + """from_texts methods and the associated warnings, async version.""" + collection: Collection = request.getfixturevalue(collection_fixture_name) + init_kwargs: dict[str, Any] + if is_vectorize: + init_kwargs = { + "collection_vector_service_options": OPENAI_VECTORIZE_OPTIONS_HEADER, + "collection_embedding_api_key": openai_api_key, + } + else: + init_kwargs = {"embedding": embedding_d2} + + v_store = await AstraDBVectorStore.afrom_texts( + texts=page_contents[0:2], + metadatas=[{"m": 1}, {"m": 3}], + ids=["ft1", "ft3"], + collection_name=collection.name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + **init_kwargs, + ) + search_results_triples_0 = await v_store.asimilarity_search_with_score_id( + page_contents[1], + k=1, + ) + assert len(search_results_triples_0) == 1 + res_doc_0, _, res_id_0 = search_results_triples_0[0] + assert res_doc_0.page_content == page_contents[1] + assert res_doc_0.metadata == {"m": 3} + assert res_id_0 == "ft3" + + # testing additional kwargs & from_text-specific kwargs + with pytest.warns(UserWarning): + # unknown kwargs going to the constructor through _from_kwargs + await AstraDBVectorStore.afrom_texts( + texts=page_contents[2:4], + metadatas=[{"m": 5}, {"m": 7}], + ids=["ft5", "ft7"], + collection_name=collection.name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + number_of_wizards=123, + name_of_river="Thames", + **init_kwargs, + ) + search_results_triples_1 = await v_store.asimilarity_search_with_score_id( + page_contents[3], + k=1, + ) + assert len(search_results_triples_1) == 1 + res_doc_1, _, res_id_1 = search_results_triples_1[0] + 
assert res_doc_1.page_content == page_contents[3] + assert res_doc_1.metadata == {"m": 7} + assert res_id_1 == "ft7" + # routing of 'add_texts' keyword arguments + v_store_2 = await AstraDBVectorStore.afrom_texts( + texts=page_contents[4:6], + metadatas=[{"m": 9}, {"m": 11}], + ids=["ft9", "ft11"], + collection_name=collection.name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + batch_size=19, + batch_concurrency=23, + overwrite_concurrency=29, + **init_kwargs, + ) + assert v_store_2.batch_size != 19 + assert v_store_2.bulk_insert_batch_concurrency != 23 + assert v_store_2.bulk_insert_overwrite_concurrency != 29 + search_results_triples_2 = await v_store_2.asimilarity_search_with_score_id( + page_contents[5], + k=1, + ) + assert len(search_results_triples_2) == 1 + res_doc_2, _, res_id_2 = search_results_triples_2[0] + assert res_doc_2.page_content == page_contents[5] + assert res_doc_2.metadata == {"m": 11} + assert res_id_2 == "ft11" + + @pytest.mark.parametrize( + ("is_vectorize", "page_contents", "collection_fixture_name"), + [ + (False, ["[1,2]", "[3,4]"], "empty_collection_d2"), + (True, ["Whales 1", "Tomatoes 3"], "empty_collection_vz"), + ], + ids=["nonvectorize_store", "vectorize_store"], + ) + async def test_astradb_vectorstore_from_documents_async( + self, + *, + astra_db_credentials: AstraDBCredentials, + openai_api_key: str, + embedding_d2: Embeddings, + is_vectorize: bool, + page_contents: list[str], + collection_fixture_name: str, + request: pytest.FixtureRequest, + ) -> None: + """ + from_documents, esp. the various handling of ID-in-doc vs external. + Async version. + """ + collection: Collection = request.getfixturevalue(collection_fixture_name) + pc1, pc2 = page_contents + init_kwargs: dict[str, Any] + if is_vectorize: + init_kwargs = { + "collection_vector_service_options": OPENAI_VECTORIZE_OPTIONS_HEADER, + "collection_embedding_api_key": openai_api_key, + } + else: + init_kwargs = {"embedding": embedding_d2} + # no IDs. + v_store = await AstraDBVectorStore.afrom_documents( + [ + Document(page_content=pc1, metadata={"m": 1}), + Document(page_content=pc2, metadata={"m": 3}), + ], + collection_name=collection.name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + **init_kwargs, + ) + hits = await v_store.asimilarity_search(pc2, k=1) + assert len(hits) == 1 + assert hits[0].page_content == pc2 + assert hits[0].metadata == {"m": 3} + v_store.clear() + + # IDs passed separately. 
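+        # (supplying `ids=` separately still works but is deprecated; carrying
+        # the ids in the documents themselves, tested below, is the preferred route)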
+ with pytest.warns(DeprecationWarning) as rec_warnings: + v_store_2 = await AstraDBVectorStore.afrom_documents( + [ + Document(page_content=pc1, metadata={"m": 1}), + Document(page_content=pc2, metadata={"m": 3}), + ], + ids=["idx1", "idx3"], + collection_name=collection.name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + **init_kwargs, + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + hits = await v_store_2.asimilarity_search(pc2, k=1) + assert len(hits) == 1 + assert hits[0].page_content == pc2 + assert hits[0].metadata == {"m": 3} + assert hits[0].id == "idx3" + v_store_2.clear() + + # IDs in documents. + v_store_3 = await AstraDBVectorStore.afrom_documents( + [ + Document(page_content=pc1, metadata={"m": 1}, id="idx1"), + Document(page_content=pc2, metadata={"m": 3}, id="idx3"), + ], + collection_name=collection.name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + **init_kwargs, + ) + hits = await v_store_3.asimilarity_search(pc2, k=1) + assert len(hits) == 1 + assert hits[0].page_content == pc2 + assert hits[0].metadata == {"m": 3} + assert hits[0].id == "idx3" + v_store_3.clear() + + # IDs both in documents and aside. + with pytest.warns(DeprecationWarning) as rec_warnings: + v_store_4 = await AstraDBVectorStore.afrom_documents( + [ + Document(page_content=pc1, metadata={"m": 1}), + Document(page_content=pc2, metadata={"m": 3}, id="idy3"), + ], + ids=["idx1", "idx3"], + collection_name=collection.name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + **init_kwargs, + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + hits = await v_store_4.asimilarity_search(pc2, k=1) + assert len(hits) == 1 + assert hits[0].page_content == pc2 + assert hits[0].metadata == {"m": 3} + assert hits[0].id == "idx3" + + @pytest.mark.parametrize( + "vector_store", + [ + "vector_store_d2", + "vector_store_d2_stringtoken", + "vector_store_vz", + ], + ) + def test_astradb_vectorstore_crud_sync( + self, + vector_store: str, + request: pytest.FixtureRequest, + ) -> None: + """Add/delete/update behaviour.""" + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + + res0 = vstore.similarity_search("[-1,-1]", k=2) + assert res0 == [] + # write and check again + added_ids = vstore.add_texts( + texts=["[1,2]", "[3,4]", "[5,6]"], + metadatas=[ + {"k": "a", "ord": 0}, + {"k": "b", "ord": 1}, + {"k": "c", "ord": 2}, + ], + ids=["a", "b", "c"], + ) + # not requiring ordered match (elsewhere it may be overwriting some) + assert set(added_ids) == {"a", "b", "c"} + res1 = vstore.similarity_search("[-1,-1]", k=5) + assert {doc.page_content for doc in res1} == {"[1,2]", "[3,4]", "[5,6]"} + res2 = vstore.similarity_search("[3,4]", k=1) + assert len(res2) == 1 + assert res2[0].page_content == "[3,4]" + assert res2[0].metadata == {"k": "b", "ord": 1} + assert 
res2[0].id == "b" + # partial overwrite and count total entries + added_ids_1 = vstore.add_texts( + texts=["[5,6]", "[7,8]"], + metadatas=[ + {"k": "c_new", "ord": 102}, + {"k": "d_new", "ord": 103}, + ], + ids=["c", "d"], + ) + # not requiring ordered match (elsewhere it may be overwriting some) + assert set(added_ids_1) == {"c", "d"} + res2 = vstore.similarity_search("[-1,-1]", k=10) + assert len(res2) == 4 + # pick one that was just updated and check its metadata + res3 = vstore.similarity_search_with_score_id( + query="[5,6]", k=1, filter={"k": "c_new"} + ) + doc3, _, id3 = res3[0] + assert doc3.page_content == "[5,6]" + assert doc3.metadata == {"k": "c_new", "ord": 102} + assert id3 == "c" + # delete and count again + del1_res = vstore.delete(["b"]) + assert del1_res is True + del2_res = vstore.delete(["a", "c", "Z!"]) + assert del2_res is True # a non-existing ID was supplied + assert len(vstore.similarity_search("[-1,-1]", k=10)) == 1 + # clear store + vstore.clear() + assert vstore.similarity_search("[-1,-1]", k=2) == [] + # add_documents with "ids" arg passthrough + vstore.add_documents( + [ + Document(page_content="[9,10]", metadata={"k": "v", "ord": 204}), + Document(page_content="[11,12]", metadata={"k": "w", "ord": 205}), + ], + ids=["v", "w"], + ) + assert len(vstore.similarity_search("[-1,-1]", k=10)) == 2 + res4 = vstore.similarity_search("[11,12]", k=1, filter={"k": "w"}) + assert res4[0].metadata["ord"] == 205 + assert res4[0].id == "w" + # add_texts with "ids" arg passthrough + vstore.add_texts( + texts=["[13,14]", "[15,16]"], + metadatas=[{"k": "r", "ord": 306}, {"k": "s", "ord": 307}], + ids=["r", "s"], + ) + assert len(vstore.similarity_search("[-1,-1]", k=10)) == 4 + res4 = vstore.similarity_search("[-1,-1]", k=1, filter={"k": "s"}) + assert res4[0].metadata["ord"] == 307 + assert res4[0].id == "s" + # delete_by_document_id + vstore.delete_by_document_id("s") + assert len(vstore.similarity_search("[-1,-1]", k=10)) == 3 + + @pytest.mark.parametrize( + "vector_store", + [ + "vector_store_d2", + "vector_store_d2_stringtoken", + "vector_store_vz", + ], + ) + async def test_astradb_vectorstore_crud_async( + self, + vector_store: str, + request: pytest.FixtureRequest, + ) -> None: + """Add/delete/update behaviour, async version.""" + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + + res0 = await vstore.asimilarity_search("[-1,-1]", k=2) + assert res0 == [] + # write and check again + added_ids = await vstore.aadd_texts( + texts=["[1,2]", "[3,4]", "[5,6]"], + metadatas=[ + {"k": "a", "ord": 0}, + {"k": "b", "ord": 1}, + {"k": "c", "ord": 2}, + ], + ids=["a", "b", "c"], + ) + # not requiring ordered match (elsewhere it may be overwriting some) + assert set(added_ids) == {"a", "b", "c"} + res1 = await vstore.asimilarity_search("[-1,-1]", k=5) + assert {doc.page_content for doc in res1} == {"[1,2]", "[3,4]", "[5,6]"} + res2 = await vstore.asimilarity_search("[3,4]", k=1) + assert len(res2) == 1 + assert res2[0].page_content == "[3,4]" + assert res2[0].metadata == {"k": "b", "ord": 1} + assert res2[0].id == "b" + # partial overwrite and count total entries + added_ids_1 = await vstore.aadd_texts( + texts=["[5,6]", "[7,8]"], + metadatas=[ + {"k": "c_new", "ord": 102}, + {"k": "d_new", "ord": 103}, + ], + ids=["c", "d"], + ) + # not requiring ordered match (elsewhere it may be overwriting some) + assert set(added_ids_1) == {"c", "d"} + res2 = await vstore.asimilarity_search("[-1,-1]", k=10) + assert len(res2) == 4 + # pick one that was just updated and check 
its metadata + res3 = await vstore.asimilarity_search_with_score_id( + query="[5,6]", k=1, filter={"k": "c_new"} + ) + doc3, _, id3 = res3[0] + assert doc3.page_content == "[5,6]" + assert doc3.metadata == {"k": "c_new", "ord": 102} + assert id3 == "c" + # delete and count again + del1_res = await vstore.adelete(["b"]) + assert del1_res is True + del2_res = await vstore.adelete(["a", "c", "Z!"]) + assert del2_res is True # a non-existing ID was supplied + assert len(await vstore.asimilarity_search("[-1,-1]", k=10)) == 1 + # clear store + await vstore.aclear() + assert await vstore.asimilarity_search("[-1,-1]", k=2) == [] + # add_documents with "ids" arg passthrough + await vstore.aadd_documents( + [ + Document(page_content="[9,10]", metadata={"k": "v", "ord": 204}), + Document(page_content="[11,12]", metadata={"k": "w", "ord": 205}), + ], + ids=["v", "w"], + ) + assert len(await vstore.asimilarity_search("[-1,-1]", k=10)) == 2 + res4 = await vstore.asimilarity_search("[11,12]", k=1, filter={"k": "w"}) + assert res4[0].metadata["ord"] == 205 + assert res4[0].id == "w" + # add_texts with "ids" arg passthrough + await vstore.aadd_texts( + texts=["[13,14]", "[15,16]"], + metadatas=[{"k": "r", "ord": 306}, {"k": "s", "ord": 307}], + ids=["r", "s"], + ) + assert len(await vstore.asimilarity_search("[-1,-1]", k=10)) == 4 + res4 = await vstore.asimilarity_search("[-1,-1]", k=1, filter={"k": "s"}) + assert res4[0].metadata["ord"] == 307 + assert res4[0].id == "s" + # delete_by_document_id + await vstore.adelete_by_document_id("s") + assert len(await vstore.asimilarity_search("[-1,-1]", k=10)) == 3 + + def test_astradb_vectorstore_massive_insert_replace_sync( + self, + vector_store_d2: AstraDBVectorStore, + ) -> None: + """Testing the insert-many-and-replace-some patterns thoroughly.""" + full_size = 300 + first_group_size = 150 + second_group_slicer = [30, 100, 2] + + all_ids = [f"doc_{idx}" for idx in range(full_size)] + all_texts = [f"[0,{idx+1}]" for idx in range(full_size)] + + # massive insertion on empty + group0_ids = all_ids[0:first_group_size] + group0_texts = all_texts[0:first_group_size] + inserted_ids0 = vector_store_d2.add_texts( + texts=group0_texts, + ids=group0_ids, + ) + assert set(inserted_ids0) == set(group0_ids) + # massive insertion with many overwrites scattered through + # (we change the text to later check on DB for successful update) + _s, _e, _st = second_group_slicer + group1_ids = all_ids[_s:_e:_st] + all_ids[first_group_size:full_size] + group1_texts = [ + txt.upper() + for txt in (all_texts[_s:_e:_st] + all_texts[first_group_size:full_size]) + ] + inserted_ids1 = vector_store_d2.add_texts( + texts=group1_texts, + ids=group1_ids, + ) + assert set(inserted_ids1) == set(group1_ids) + # final read (we want the IDs to do a full check) + expected_text_by_id = { + **dict(zip(group0_ids, group0_texts)), + **dict(zip(group1_ids, group1_texts)), + } + full_results = vector_store_d2.similarity_search_with_score_id_by_vector( + embedding=[1.0, 1.0], + k=full_size, + ) + for doc, _, doc_id in full_results: + assert doc.page_content == expected_text_by_id[doc_id] + + async def test_astradb_vectorstore_massive_insert_replace_async( + self, + vector_store_d2: AstraDBVectorStore, + ) -> None: + """ + Testing the insert-many-and-replace-some patterns thoroughly. + Async version. 
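+        (The first batch fills half of an empty store; the second mixes
+        brand-new IDs with a scattered slice of existing ones, so both the
+        insert and the overwrite code paths are exercised.)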
+ """ + full_size = 300 + first_group_size = 150 + second_group_slicer = [30, 100, 2] + + all_ids = [f"doc_{idx}" for idx in range(full_size)] + all_texts = [f"[0,{idx+1}]" for idx in range(full_size)] + + # massive insertion on empty + group0_ids = all_ids[0:first_group_size] + group0_texts = all_texts[0:first_group_size] + + inserted_ids0 = await vector_store_d2.aadd_texts( + texts=group0_texts, + ids=group0_ids, + ) + assert set(inserted_ids0) == set(group0_ids) + # massive insertion with many overwrites scattered through + # (we change the text to later check on DB for successful update) + _s, _e, _st = second_group_slicer + group1_ids = all_ids[_s:_e:_st] + all_ids[first_group_size:full_size] + group1_texts = [ + txt.upper() + for txt in (all_texts[_s:_e:_st] + all_texts[first_group_size:full_size]) + ] + inserted_ids1 = await vector_store_d2.aadd_texts( + texts=group1_texts, + ids=group1_ids, + ) + assert set(inserted_ids1) == set(group1_ids) + # final read (we want the IDs to do a full check) + expected_text_by_id = dict(zip(all_ids, all_texts)) + full_results = await vector_store_d2.asimilarity_search_with_score_id_by_vector( + embedding=[1.0, 1.0], + k=full_size, + ) + for doc, _, doc_id in full_results: + assert doc.page_content == expected_text_by_id[doc_id] + + def test_astradb_vectorstore_mmr_sync( + self, + vector_store_d2: AstraDBVectorStore, + ) -> None: + """MMR testing. We work on the unit circle with angle multiples + of 2*pi/20 and prepare a store with known vectors for a controlled + MMR outcome. + """ + + def _v_from_i(i: int, n: int) -> str: + angle = 2 * math.pi * i / n + vector = [math.cos(angle), math.sin(angle)] + return json.dumps(vector) + + i_vals = [0, 4, 5, 13] + n_val = 20 + vector_store_d2.add_texts( + [_v_from_i(i, n_val) for i in i_vals], metadatas=[{"i": i} for i in i_vals] + ) + res1 = vector_store_d2.max_marginal_relevance_search( + _v_from_i(3, n_val), + k=2, + fetch_k=3, + ) + res_i_vals = {doc.metadata["i"] for doc in res1} + assert res_i_vals == {0, 4} + + async def test_astradb_vectorstore_mmr_async( + self, + vector_store_d2: AstraDBVectorStore, + ) -> None: + """MMR testing. We work on the unit circle with angle multiples + of 2*pi/20 and prepare a store with known vectors for a controlled + MMR outcome. + Async version. 
+ """ + + def _v_from_i(i: int, n: int) -> str: + angle = 2 * math.pi * i / n + vector = [math.cos(angle), math.sin(angle)] + return json.dumps(vector) + + i_vals = [0, 4, 5, 13] + n_val = 20 + await vector_store_d2.aadd_texts( + [_v_from_i(i, n_val) for i in i_vals], + metadatas=[{"i": i} for i in i_vals], + ) + res1 = await vector_store_d2.amax_marginal_relevance_search( + _v_from_i(3, n_val), + k=2, + fetch_k=3, + ) + res_i_vals = {doc.metadata["i"] for doc in res1} + assert res_i_vals == {0, 4} + + def test_astradb_vectorstore_mmr_vectorize_sync( + self, + vector_store_vz: AstraDBVectorStore, + ) -> None: + """MMR testing with vectorize, sync.""" + vector_store_vz.add_texts( + [ + "Some dogs bark", + "My dog growls", + "Your cat meows", + "Please do the dishes after you're done", + ], + ids=["db", "dg", "c", "z"], + ) + + hits = vector_store_vz.max_marginal_relevance_search( + "The dogs say woof", + k=2, + fetch_k=3, + ) + assert {doc.id for doc in hits} == {"db", "c"} + + async def test_astradb_vectorstore_mmr_vectorize_async( + self, + vector_store_vz: AstraDBVectorStore, + ) -> None: + """MMR async testing with vectorize, async.""" + await vector_store_vz.aadd_texts( + [ + "Some dogs bark", + "My dog growls", + "Your cat meows", + "Please do the dishes after you're done", + ], + ids=["db", "dg", "c", "z"], + ) + + hits = await vector_store_vz.amax_marginal_relevance_search( + "The dogs say woof", + k=2, + fetch_k=3, + ) + assert {doc.id for doc in hits} == {"db", "c"} + + @pytest.mark.parametrize( + "vector_store", + [ + "vector_store_d2", + "vector_store_vz", + ], + ) + def test_astradb_vectorstore_metadata( + self, + vector_store: str, + request: pytest.FixtureRequest, + ) -> None: + """Metadata filtering.""" + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + vstore.add_documents( + [ + Document( + page_content="[1,2]", + metadata={"ord": ord("q"), "group": "consonant", "letter": "q"}, + ), + Document( + page_content="[3,4]", + metadata={"ord": ord("w"), "group": "consonant", "letter": "w"}, + ), + Document( + page_content="[5,6]", + metadata={"ord": ord("r"), "group": "consonant", "letter": "r"}, + ), + Document( + page_content="[-1,2]", + metadata={"ord": ord("e"), "group": "vowel", "letter": "e"}, + ), + Document( + page_content="[-3,4]", + metadata={"ord": ord("i"), "group": "vowel", "letter": "i"}, + ), + Document( + page_content="[-5,6]", + metadata={"ord": ord("o"), "group": "vowel", "letter": "o"}, + ), + ] + ) + # no filters + res0 = vstore.similarity_search("[-1,-1]", k=10) + assert {doc.metadata["letter"] for doc in res0} == set("qwreio") + # single filter + res1 = vstore.similarity_search( + "[-1,-1]", + k=10, + filter={"group": "vowel"}, + ) + assert {doc.metadata["letter"] for doc in res1} == set("eio") + # multiple filters + res2 = vstore.similarity_search( + "[-1,-1]", + k=10, + filter={"group": "consonant", "ord": ord("q")}, + ) + assert {doc.metadata["letter"] for doc in res2} == set("q") + # excessive filters + res3 = vstore.similarity_search( + "[-1,-1]", + k=10, + filter={"group": "consonant", "ord": ord("q"), "case": "upper"}, + ) + assert res3 == [] + # filter with logical operator + res4 = vstore.similarity_search( + "[-1,-1]", + k=10, + filter={"$or": [{"ord": ord("q")}, {"ord": ord("r")}]}, + ) + assert {doc.metadata["letter"] for doc in res4} == {"q", "r"} + + @pytest.mark.parametrize( + ("is_vectorize", "vector_store", "texts", "query"), + [ + ( + False, + "vector_store_d2", + ["[1,1]", "[-1,-1]"], + "[0.99999,1.00001]", + ), + ( + 
True, + "vector_store_vz", + ["the boat is in the sea", "perhaps triangles are blue"], + "there's a ship in the ocean", + ), + ], + ids=["nonvectorize_store", "vectorize_store"], + ) + def test_astradb_vectorstore_similarity_scale_sync( + self, + *, + is_vectorize: bool, + vector_store: str, + texts: list[str], + query: str, + request: pytest.FixtureRequest, + ) -> None: + """Scale of the similarity scores.""" + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + vstore.add_texts( + texts=texts, + ids=["near", "far"], + ) + res1 = vstore.similarity_search_with_score( + query, + k=2, + ) + scores = [sco for _, sco in res1] + sco_near, sco_far = scores + assert sco_far >= 0 + if not is_vectorize: + assert abs(1 - sco_near) < MATCH_EPSILON + assert sco_far < EUCLIDEAN_MIN_SIM_UNIT_VECTORS + MATCH_EPSILON + + @pytest.mark.parametrize( + ("is_vectorize", "vector_store", "texts", "query"), + [ + ( + False, + "vector_store_d2", + ["[1,1]", "[-1,-1]"], + "[0.99999,1.00001]", + ), + ( + True, + "vector_store_vz", + ["the boat is in the sea", "perhaps triangles are blue"], + "there's a ship in the ocean", + ), + ], + ids=["nonvectorize_store", "vectorize_store"], + ) + async def test_astradb_vectorstore_similarity_scale_async( + self, + *, + is_vectorize: bool, + vector_store: str, + texts: list[str], + query: str, + request: pytest.FixtureRequest, + ) -> None: + """Scale of the similarity scores, async version.""" + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + await vstore.aadd_texts( + texts=texts, + ids=["near", "far"], + ) + res1 = await vstore.asimilarity_search_with_score( + query, + k=2, + ) + scores = [sco for _, sco in res1] + sco_near, sco_far = scores + assert sco_far >= 0 + if not is_vectorize: + assert abs(1 - sco_near) < MATCH_EPSILON + assert sco_far < EUCLIDEAN_MIN_SIM_UNIT_VECTORS + MATCH_EPSILON + + @pytest.mark.parametrize( + "vector_store", + [ + "vector_store_d2", + "vector_store_vz", + ], + ) + def test_astradb_vectorstore_massive_delete( + self, + vector_store: str, + request: pytest.FixtureRequest, + ) -> None: + """Larger-scale bulk deletes.""" + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + m = 150 + texts = [f"[0,{i + 1 / 7.0}]" for i in range(2 * m)] + ids0 = ["doc_%i" % i for i in range(m)] + ids1 = ["doc_%i" % (i + m) for i in range(m)] + ids = ids0 + ids1 + vstore.add_texts(texts=texts, ids=ids) + # deleting a bunch of these + del_res0 = vstore.delete(ids0) + assert del_res0 is True + # deleting the rest plus a fake one + del_res1 = vstore.delete([*ids1, "ghost!"]) + assert del_res1 is True # ensure no error + # nothing left + assert vstore.similarity_search("[-1,-1]", k=2 * m) == [] + + def test_astradb_vectorstore_custom_params_sync( + self, + astra_db_credentials: AstraDBCredentials, + empty_collection_d2: Collection, + embedding_d2: Embeddings, + ) -> None: + """Custom batch size and concurrency params.""" + v_store = AstraDBVectorStore( + embedding=embedding_d2, + collection_name=empty_collection_d2.name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + batch_size=17, + bulk_insert_batch_concurrency=13, + bulk_insert_overwrite_concurrency=7, + bulk_delete_concurrency=19, + ) + # add_texts and delete some + n = 120 + texts = [f"[0,{i + 1 / 7.0}]" for i in range(n)] + ids = ["doc_%i" % i for i in range(n)] + 
v_store.add_texts(texts=texts, ids=ids)
+        v_store.add_texts(
+            texts=texts,
+            ids=ids,
+            batch_size=19,
+            batch_concurrency=7,
+            overwrite_concurrency=13,
+        )
+        v_store.delete(ids[: n // 2])
+        v_store.delete(ids[n // 2 :], concurrency=23)
+
+    async def test_astradb_vectorstore_custom_params_async(
+        self,
+        astra_db_credentials: AstraDBCredentials,
+        empty_collection_d2: Collection,
+        embedding_d2: Embeddings,
+    ) -> None:
+        """Custom batch size and concurrency params, async version."""
+        v_store = AstraDBVectorStore(
+            embedding=embedding_d2,
+            collection_name=empty_collection_d2.name,
+            token=StaticTokenProvider(astra_db_credentials["token"]),
+            api_endpoint=astra_db_credentials["api_endpoint"],
+            namespace=astra_db_credentials["namespace"],
+            environment=astra_db_credentials["environment"],
+            setup_mode=SetupMode.OFF,
+            batch_size=17,
+            bulk_insert_batch_concurrency=13,
+            bulk_insert_overwrite_concurrency=7,
+            bulk_delete_concurrency=19,
+        )
+        # add_texts and delete some
+        n = 120
+        texts = [f"[0,{i + 1 / 7.0}]" for i in range(n)]
+        ids = ["doc_%i" % i for i in range(n)]
+        await v_store.aadd_texts(texts=texts, ids=ids)
+        await v_store.aadd_texts(
+            texts=texts,
+            ids=ids,
+            batch_size=19,
+            batch_concurrency=7,
+            overwrite_concurrency=13,
+        )
+        await v_store.adelete(ids[: n // 2])
+        await v_store.adelete(ids[n // 2 :], concurrency=23)
+
+    def test_astradb_vectorstore_metrics(
+        self,
+        astra_db_credentials: AstraDBCredentials,
+        embedding_d2: Embeddings,
+        vector_store_d2: AstraDBVectorStore,
+        ephemeral_collection_cleaner_d2: str,  # noqa: ARG002
+    ) -> None:
+        """Different choices of similarity metric.
+        Both stores (with "cosine" and "euclidean" metrics) contain these two:
+        - a vector slightly rotated w.r.t. the query vector
+        - a vector which is a long multiple of the query vector
+        so which one is "the closest one" depends on the metric.
+        """
+        euclidean_store = vector_store_d2
+
+        isq2 = 0.5**0.5
+        isa = 0.7
+        isb = (1.0 - isa * isa) ** 0.5
+        texts = [
+            json.dumps([isa, isb]),
+            json.dumps([10 * isq2, 10 * isq2]),
+        ]
+        ids = ["rotated", "scaled"]
+        query_text = json.dumps([isq2, isq2])
+
+        # prepare the cosine store (the euclidean one comes from the fixture)
+        cosine_store = AstraDBVectorStore(
+            embedding=embedding_d2,
+            collection_name=EPHEMERAL_COLLECTION_NAME_D2,
+            metric="cosine",
+            token=StaticTokenProvider(astra_db_credentials["token"]),
+            api_endpoint=astra_db_credentials["api_endpoint"],
+            namespace=astra_db_credentials["namespace"],
+            environment=astra_db_credentials["environment"],
+        )
+
+        cosine_store.add_texts(texts=texts, ids=ids)
+        euclidean_store.add_texts(texts=texts, ids=ids)
+
+        cosine_triples = cosine_store.similarity_search_with_score_id(
+            query_text,
+            k=1,
+        )
+        euclidean_triples = euclidean_store.similarity_search_with_score_id(
+            query_text,
+            k=1,
+        )
+        assert len(cosine_triples) == 1
+        assert len(euclidean_triples) == 1
+        assert cosine_triples[0][2] == "scaled"
+        assert euclidean_triples[0][2] == "rotated"
+
+    @pytest.mark.skipif(
+        os.environ.get("ASTRA_DB_ENVIRONMENT", "prod").upper() != "PROD",
+        reason="Can run on Astra DB production environment only",
+    )
+    def test_astradb_vectorstore_coreclients_init_sync(
+        self,
+        core_astra_db: AstraDB,
+        embedding_d2: Embeddings,
+        vector_store_d2: AstraDBVectorStore,
+    ) -> None:
+        """
+        Expect a deprecation warning from passing a (core) AstraDB class,
+        but it must work.
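+        (A document written through the regular store must be retrievable
+        through the core-client-initialized one, proving that both point to
+        the same collection.)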
+ """ + vector_store_d2.add_texts(["[1,2]"]) + + with pytest.warns(DeprecationWarning) as rec_warnings: + v_store_init_core = AstraDBVectorStore( + embedding=embedding_d2, + collection_name=COLLECTION_NAME_D2, + astra_db_client=core_astra_db, + metric="euclidean", + ) + + results = v_store_init_core.similarity_search("[-1,-1]", k=1) + # cleaning out 'spurious' "unclosed socket/transport..." warnings + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + assert len(results) == 1 + assert results[0].page_content == "[1,2]" + + @pytest.mark.skipif( + os.environ.get("ASTRA_DB_ENVIRONMENT", "prod").upper() != "PROD", + reason="Can run on Astra DB production environment only", + ) + async def test_astradb_vectorstore_coreclients_init_async( + self, + core_astra_db: AstraDB, + embedding_d2: Embeddings, + vector_store_d2: AstraDBVectorStore, + ) -> None: + """ + Expect a deprecation warning from passing a (core) AstraDB class, + but it must work. Async version. + """ + vector_store_d2.add_texts(["[1,2]"]) + + with pytest.warns(DeprecationWarning) as rec_warnings: + v_store_init_core = AstraDBVectorStore( + embedding=embedding_d2, + collection_name=COLLECTION_NAME_D2, + astra_db_client=core_astra_db, + metric="euclidean", + setup_mode=SetupMode.ASYNC, + ) + + results = await v_store_init_core.asimilarity_search("[-1,-1]", k=1) + # cleaning out 'spurious' "unclosed socket/transport..." warnings + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + assert len(results) == 1 + assert results[0].page_content == "[1,2]" diff --git a/libs/astradb/tests/integration_tests/test_vectorstore_autodetect.py b/libs/astradb/tests/integration_tests/test_vectorstore_autodetect.py index 1c1fab4..a7cf999 100644 --- a/libs/astradb/tests/integration_tests/test_vectorstore_autodetect.py +++ b/libs/astradb/tests/integration_tests/test_vectorstore_autodetect.py @@ -5,175 +5,67 @@ from __future__ import annotations -import os -from typing import TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING import pytest -from astrapy import DataAPIClient from astrapy.authentication import StaticTokenProvider from langchain_core.documents import Document -from langchain_astradb.utils.astradb import SetupMode from langchain_astradb.vectorstores import AstraDBVectorStore -from tests.conftest import SomeEmbeddings -from .conftest import OPENAI_VECTORIZE_OPTIONS, AstraDBCredentials, _has_env_vars +from .conftest import ( + COLLECTION_NAME_IDXALL_D2, + COLLECTION_NAME_IDXALL_VZ, + CUSTOM_CONTENT_KEY, + astra_db_env_vars_available, +) if TYPE_CHECKING: from astrapy import Collection from langchain_core.embeddings import Embeddings -# Faster testing (no actual collection deletions). 
Off by default (=full tests) -SKIP_COLLECTION_DELETE = ( - int(os.environ.get("ASTRA_DB_SKIP_COLLECTION_DELETIONS", "0")) != 0 -) - -AD_NOVECTORIZE_COLLECTION = "lc_ad_novectorize" -AD_VECTORIZE_COLLECTION = "lc_ad_vectorize" - - -@pytest.fixture(scope="session") -def provisioned_novectorize_collection( - astra_db_credentials: AstraDBCredentials, -) -> Iterable[Collection]: - """Provision a general-purpose collection for the no-vectorize tests.""" - client = DataAPIClient(environment=astra_db_credentials["environment"]) - database = client.get_database( - astra_db_credentials["api_endpoint"], - token=StaticTokenProvider(astra_db_credentials["token"]), - namespace=astra_db_credentials["namespace"], - ) - collection = database.create_collection( - AD_NOVECTORIZE_COLLECTION, - dimension=2, - check_exists=False, - metric="cosine", - ) - yield collection - - if not SKIP_COLLECTION_DELETE: - collection.drop() - - -@pytest.fixture(scope="session") -def provisioned_vectorize_collection( - astra_db_credentials: AstraDBCredentials, -) -> Iterable[Collection]: - """Provision a general-purpose collection for the vectorize tests.""" - client = DataAPIClient(environment=astra_db_credentials["environment"]) - database = client.get_database( - astra_db_credentials["api_endpoint"], - token=StaticTokenProvider(astra_db_credentials["token"]), - namespace=astra_db_credentials["namespace"], - ) - collection = database.create_collection( - AD_VECTORIZE_COLLECTION, - dimension=2, - check_exists=False, - metric="cosine", - service=OPENAI_VECTORIZE_OPTIONS, - ) - yield collection - - if not SKIP_COLLECTION_DELETE: - collection.drop() - - -@pytest.fixture -def novectorize_collection( - provisioned_novectorize_collection: Collection, -) -> Iterable[Collection]: - provisioned_novectorize_collection.delete_many({}) - yield provisioned_novectorize_collection - - provisioned_novectorize_collection.delete_many({}) - - -@pytest.fixture -def vectorize_collection( - provisioned_vectorize_collection: Collection, -) -> Iterable[Collection]: - provisioned_vectorize_collection.delete_many({}) - yield provisioned_vectorize_collection - - provisioned_vectorize_collection.delete_many({}) - - -@pytest.fixture(scope="session") -def embedding() -> Embeddings: - return SomeEmbeddings(dimension=2) + from .conftest import AstraDBCredentials -@pytest.fixture -def novectorize_store( - novectorize_collection: Collection, # noqa: ARG001 - astra_db_credentials: AstraDBCredentials, - embedding: Embeddings, -) -> AstraDBVectorStore: - return AstraDBVectorStore( - embedding=embedding, - collection_name=AD_NOVECTORIZE_COLLECTION, - token=StaticTokenProvider(astra_db_credentials["token"]), - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - setup_mode=SetupMode.OFF, - ) - - -@pytest.fixture -def vectorize_store( - vectorize_collection: Collection, # noqa: ARG001 - astra_db_credentials: AstraDBCredentials, -) -> AstraDBVectorStore: - return AstraDBVectorStore( - collection_name=AD_VECTORIZE_COLLECTION, - token=StaticTokenProvider(astra_db_credentials["token"]), - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - setup_mode=SetupMode.OFF, - collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS, - ) - - -@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. 
vars") -class TestVectorStoreAutodetect: +@pytest.mark.skipif( + not astra_db_env_vars_available(), reason="Missing Astra DB env. vars" +) +class TestAstraDBVectorStoreAutodetect: def test_autodetect_flat_novectorize_crud( self, - novectorize_collection: Collection, astra_db_credentials: AstraDBCredentials, - embedding: Embeddings, + empty_collection_idxall_d2: Collection, + embedding_d2: Embeddings, ) -> None: """Test autodetect on a populated flat collection, checking all codecs.""" - novectorize_collection.insert_many( + empty_collection_idxall_d2.insert_many( [ { "_id": "1", - "$vector": [0.1, 0.2], - "cont": "Cont1", + "$vector": [1, 2], + CUSTOM_CONTENT_KEY: "[1,2]", "m1": "a", "m2": "x", }, { "_id": "2", - "$vector": [0.3, 0.4], - "cont": "Cont2", + "$vector": [3, 4], + CUSTOM_CONTENT_KEY: "[3,4]", "m1": "b", "m2": "y", }, { "_id": "3", - "$vector": [0.5, 0.6], - "cont": "Cont3", + "$vector": [5, 6], + CUSTOM_CONTENT_KEY: "[5,6]", "m1": "c", "m2": "z", }, ] ) ad_store = AstraDBVectorStore( - embedding=embedding, - collection_name=AD_NOVECTORIZE_COLLECTION, + embedding=embedding_d2, + collection_name=COLLECTION_NAME_IDXALL_D2, token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], @@ -182,14 +74,14 @@ def test_autodetect_flat_novectorize_crud( ) # ANN and the metadata - results = ad_store.similarity_search("query", k=3) - assert {res.page_content for res in results} == {"Cont1", "Cont2", "Cont3"} + results = ad_store.similarity_search("[-1,-1]", k=3) + assert {res.page_content for res in results} == {"[1,2]", "[3,4]", "[5,6]"} assert "m1" in results[0].metadata assert "m2" in results[0].metadata # inserting id4 = "4" - pc4 = "Cont4" + pc4 = "[7,8]" md4 = {"q1": "Q1", "q2": "Q2"} inserted_ids = ad_store.add_texts( texts=[pc4], @@ -199,22 +91,22 @@ def test_autodetect_flat_novectorize_crud( assert inserted_ids == [id4] # reading with filtering - results2 = ad_store.similarity_search("query", k=3, filter={"q2": "Q2"}) + results2 = ad_store.similarity_search("[-1,-1]", k=3, filter={"q2": "Q2"}) assert results2 == [Document(id=id4, page_content=pc4, metadata=md4)] def test_autodetect_default_novectorize_crud( self, - novectorize_collection: Collection, # noqa: ARG002 astra_db_credentials: AstraDBCredentials, - embedding: Embeddings, - novectorize_store: AstraDBVectorStore, + empty_collection_idxall_d2: Collection, # noqa: ARG002 + embedding_d2: Embeddings, + vector_store_idxall_d2: AstraDBVectorStore, ) -> None: """Test autodetect on a VS-made collection, checking all codecs.""" - novectorize_store.add_texts( + vector_store_idxall_d2.add_texts( texts=[ - "Cont1", - "Cont2", - "Cont3", + "[1,2]", + "[3,4]", + "[5,6]", ], metadatas=[ {"m1": "a", "m2": "x"}, @@ -229,8 +121,8 @@ def test_autodetect_default_novectorize_crud( ) # now with the autodetect ad_store = AstraDBVectorStore( - embedding=embedding, - collection_name=AD_NOVECTORIZE_COLLECTION, + embedding=embedding_d2, + collection_name=COLLECTION_NAME_IDXALL_D2, token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], @@ -239,14 +131,14 @@ def test_autodetect_default_novectorize_crud( ) # ANN and the metadata - results = ad_store.similarity_search("query", k=3) - assert {res.page_content for res in results} == {"Cont1", "Cont2", "Cont3"} + results = ad_store.similarity_search("[-1,-1]", k=3) + assert {res.page_content for res in results} == 
{"[1,2]", "[3,4]", "[5,6]"} assert "m1" in results[0].metadata assert "m2" in results[0].metadata # inserting id4 = "4" - pc4 = "Cont4" + pc4 = "[7,8]" md4 = {"q1": "Q1", "q2": "Q2"} inserted_ids = ad_store.add_texts( texts=[pc4], @@ -256,16 +148,17 @@ def test_autodetect_default_novectorize_crud( assert inserted_ids == [id4] # reading with filtering - results2 = ad_store.similarity_search("query", k=3, filter={"q2": "Q2"}) + results2 = ad_store.similarity_search("[9,10]", k=3, filter={"q2": "Q2"}) assert results2 == [Document(id=id4, page_content=pc4, metadata=md4)] def test_autodetect_flat_vectorize_crud( self, - vectorize_collection: Collection, astra_db_credentials: AstraDBCredentials, + openai_api_key: str, + empty_collection_idxall_vz: Collection, ) -> None: """Test autodetect on a populated flat collection, checking all codecs.""" - vectorize_collection.insert_many( + empty_collection_idxall_vz.insert_many( [ { "_id": "1", @@ -288,12 +181,13 @@ def test_autodetect_flat_vectorize_crud( ] ) ad_store = AstraDBVectorStore( - collection_name=AD_VECTORIZE_COLLECTION, + collection_name=COLLECTION_NAME_IDXALL_VZ, token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], autodetect_collection=True, + collection_embedding_api_key=openai_api_key, ) # ANN and the metadata @@ -319,12 +213,14 @@ def test_autodetect_flat_vectorize_crud( def test_autodetect_default_vectorize_crud( self, - vectorize_collection: Collection, # noqa: ARG002 + *, astra_db_credentials: AstraDBCredentials, - vectorize_store: AstraDBVectorStore, + openai_api_key: str, + empty_collection_idxall_vz: Collection, # noqa: ARG002 + vector_store_idxall_vz: AstraDBVectorStore, ) -> None: """Test autodetect on a VS-made collection, checking all codecs.""" - vectorize_store.add_texts( + vector_store_idxall_vz.add_texts( texts=[ "Cont1", "Cont2", @@ -343,12 +239,13 @@ def test_autodetect_default_vectorize_crud( ) # now with the autodetect ad_store = AstraDBVectorStore( - collection_name=AD_VECTORIZE_COLLECTION, + collection_name=COLLECTION_NAME_IDXALL_VZ, token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], autodetect_collection=True, + collection_embedding_api_key=openai_api_key, ) # ANN and the metadata @@ -374,25 +271,25 @@ def test_autodetect_default_vectorize_crud( def test_failed_docs_autodetect_flat_novectorize_crud( self, - novectorize_collection: Collection, astra_db_credentials: AstraDBCredentials, - embedding: Embeddings, + empty_collection_idxall_d2: Collection, + embedding_d2: Embeddings, ) -> None: """Test autodetect + skipping failing documents.""" - novectorize_collection.insert_many( + empty_collection_idxall_d2.insert_many( [ { "_id": "1", - "$vector": [0.1, 0.2], - "cont": "Cont1", + "$vector": [1, 2], + CUSTOM_CONTENT_KEY: "[1,2]", "m1": "a", "m2": "x", }, ] ) ad_store_e = AstraDBVectorStore( - collection_name=AD_NOVECTORIZE_COLLECTION, - embedding=embedding, + collection_name=COLLECTION_NAME_IDXALL_D2, + embedding=embedding_d2, token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], @@ -401,8 +298,8 @@ def test_failed_docs_autodetect_flat_novectorize_crud( ignore_invalid_documents=False, ) ad_store_w = AstraDBVectorStore( - 
collection_name=AD_NOVECTORIZE_COLLECTION,
-            embedding=embedding,
+            collection_name=COLLECTION_NAME_IDXALL_D2,
+            embedding=embedding_d2,
             token=StaticTokenProvider(astra_db_credentials["token"]),
             api_endpoint=astra_db_credentials["api_endpoint"],
             namespace=astra_db_credentials["namespace"],
@@ -411,23 +308,29 @@
-        results_e = ad_store_e.similarity_search("query", k=3)
+        results_e = ad_store_e.similarity_search("[-1,-1]", k=3)
         assert len(results_e) == 1
-        results_w = ad_store_w.similarity_search("query", k=3)
+        results_w = ad_store_w.similarity_search("[-1,-1]", k=3)
         assert len(results_w) == 1

-        novectorize_collection.insert_one(
+        empty_collection_idxall_d2.insert_one(
             {
                 "_id": "2",
-                "$vector": [0.1, 0.2],
+                "$vector": [3, 4],
                 "m1": "invalid:",
-                "m2": "no $vector!",
+                "m2": "no 'cont'",
             }
         )
         with pytest.raises(KeyError):
-            ad_store_e.similarity_search("query", k=3)
+            ad_store_e.similarity_search("[7,8]", k=3)

-        results_w_post = ad_store_w.similarity_search("query", k=3)
+        # one case should result in just a warning:
+        with pytest.warns(UserWarning) as rec_warnings:
+            results_w_post = ad_store_w.similarity_search("[7,8]", k=3)
+        f_rec_warnings = [
+            wrn for wrn in rec_warnings if issubclass(wrn.category, UserWarning)
+        ]
+        assert len(f_rec_warnings) == 1
         assert len(results_w_post) == 1
diff --git a/libs/astradb/tests/integration_tests/test_vectorstore_ddl_tests.py b/libs/astradb/tests/integration_tests/test_vectorstore_ddl_tests.py
new file mode 100644
index 0000000..a8a8b79
--- /dev/null
+++ b/libs/astradb/tests/integration_tests/test_vectorstore_ddl_tests.py
@@ -0,0 +1,538 @@
+"""DDL-heavy parts of the tests for the Astra DB vector store class `AstraDBVectorStore`.
+
+Refer to `test_vectorstore.py` for the requirements to run.
+"""
+
+from __future__ import annotations
+
+import warnings
+from typing import TYPE_CHECKING
+
+import pytest
+from astrapy.authentication import EmbeddingAPIKeyHeaderProvider, StaticTokenProvider
+from astrapy.exceptions import InsertManyException
+
+from langchain_astradb.utils.astradb import SetupMode
+from langchain_astradb.vectorstores import AstraDBVectorStore
+
+from .conftest import (
+    EPHEMERAL_COLLECTION_NAME_D2,
+    EPHEMERAL_COLLECTION_NAME_VZ,
+    EPHEMERAL_COLLECTION_NAME_VZ_KMS,
+    EPHEMERAL_CUSTOM_IDX_NAME_D2,
+    EPHEMERAL_DEFAULT_IDX_NAME_D2,
+    EPHEMERAL_LEGACY_IDX_NAME_D2,
+    INCOMPATIBLE_INDEXING_MSG,
+    LEGACY_INDEXING_MSG,
+    OPENAI_VECTORIZE_OPTIONS_HEADER,
+    OPENAI_VECTORIZE_OPTIONS_KMS,
+    astra_db_env_vars_available,
+)
+
+if TYPE_CHECKING:
+    from astrapy import Collection, Database
+    from langchain_core.embeddings import Embeddings
+
+    from .conftest import AstraDBCredentials
+
+
+@pytest.mark.skipif(
+    not astra_db_env_vars_available(), reason="Missing Astra DB env.
vars" +) +class TestAstraDBVectorStoreDDLs: + def test_astradb_vectorstore_create_delete_sync( + self, + astra_db_credentials: AstraDBCredentials, + database: Database, + embedding_d2: Embeddings, + ephemeral_collection_cleaner_d2: str, # noqa: ARG002 + ) -> None: + """Create and delete.""" + v_store = AstraDBVectorStore( + embedding=embedding_d2, + collection_name=EPHEMERAL_COLLECTION_NAME_D2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + metric="cosine", + ) + v_store.add_texts(["[1,2]"]) + v_store.delete_collection() + assert EPHEMERAL_COLLECTION_NAME_D2 not in database.list_collection_names() + + def test_astradb_vectorstore_create_delete_vectorize_sync( + self, + astra_db_credentials: AstraDBCredentials, + openai_api_key: str, + database: Database, + ephemeral_collection_cleaner_vz: str, # noqa: ARG002 + ) -> None: + """Create and delete with vectorize option.""" + v_store = AstraDBVectorStore( + collection_name=EPHEMERAL_COLLECTION_NAME_VZ, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + metric="cosine", + collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_HEADER, + collection_embedding_api_key=openai_api_key, + ) + v_store.add_texts(["This is text"]) + v_store.delete_collection() + assert EPHEMERAL_COLLECTION_NAME_VZ not in database.list_collection_names() + + async def test_astradb_vectorstore_create_delete_async( + self, + astra_db_credentials: AstraDBCredentials, + database: Database, + embedding_d2: Embeddings, + ephemeral_collection_cleaner_d2: str, # noqa: ARG002 + ) -> None: + """Create and delete, async.""" + v_store = AstraDBVectorStore( + embedding=embedding_d2, + collection_name=EPHEMERAL_COLLECTION_NAME_D2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + metric="cosine", + ) + await v_store.aadd_texts(["[1,2]"]) + await v_store.adelete_collection() + assert EPHEMERAL_COLLECTION_NAME_D2 not in database.list_collection_names() + + async def test_astradb_vectorstore_create_delete_vectorize_async( + self, + astra_db_credentials: AstraDBCredentials, + openai_api_key: str, + database: Database, + ephemeral_collection_cleaner_vz: str, # noqa: ARG002 + ) -> None: + """Create and delete with vectorize option, async.""" + v_store = AstraDBVectorStore( + collection_name=EPHEMERAL_COLLECTION_NAME_VZ, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + metric="cosine", + collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_HEADER, + collection_embedding_api_key=openai_api_key, + ) + await v_store.aadd_texts(["[1,2]"]) + await v_store.adelete_collection() + assert EPHEMERAL_COLLECTION_NAME_VZ not in database.list_collection_names() + + def test_astradb_vectorstore_pre_delete_collection_sync( + self, + astra_db_credentials: AstraDBCredentials, + embedding_d2: Embeddings, + ephemeral_collection_cleaner_d2: str, # noqa: ARG002 + ) -> None: + """Use of the pre_delete_collection flag.""" + v_store = AstraDBVectorStore( + 
embedding=embedding_d2, + collection_name=EPHEMERAL_COLLECTION_NAME_D2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + metric="cosine", + ) + v_store.add_texts(texts=["[1,2]"]) + res1 = v_store.similarity_search("[-1,-1]", k=5) + assert len(res1) == 1 + v_store = AstraDBVectorStore( + embedding=embedding_d2, + collection_name=EPHEMERAL_COLLECTION_NAME_D2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + metric="cosine", + pre_delete_collection=True, + ) + res1 = v_store.similarity_search("[-1,-1]", k=5) + assert len(res1) == 0 + + async def test_astradb_vectorstore_pre_delete_collection_async( + self, + astra_db_credentials: AstraDBCredentials, + embedding_d2: Embeddings, + ephemeral_collection_cleaner_d2: str, # noqa: ARG002 + ) -> None: + v_store = AstraDBVectorStore( + embedding=embedding_d2, + collection_name=EPHEMERAL_COLLECTION_NAME_D2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.ASYNC, + metric="cosine", + ) + await v_store.aadd_texts(texts=["[1,2]"]) + res1 = await v_store.asimilarity_search("[-1,-1]", k=5) + assert len(res1) == 1 + v_store = AstraDBVectorStore( + embedding=embedding_d2, + collection_name=EPHEMERAL_COLLECTION_NAME_D2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.ASYNC, + metric="cosine", + pre_delete_collection=True, + ) + res1 = await v_store.asimilarity_search("[-1,-1]", k=5) + assert len(res1) == 0 + + def test_astradb_vectorstore_indexing_legacy_sync( + self, + astra_db_credentials: AstraDBCredentials, + database: Database, + embedding_d2: Embeddings, + ephemeral_indexing_collections_cleaner: list[str], # noqa: ARG002 + ) -> None: + """ + Test of the vector store behaviour for various indexing settings, + with an existing 'legacy' collection (i.e. unspecified indexing policy). 
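+
+        As a rough sketch (shapes are illustrative, not the exact Data API
+        payloads), the three indexing situations these tests exercise are:
+
+            legacy:  a collection created with no "indexing" option at all;
+            default: the store's own policy, an allow-list restricted to
+                     metadata (roughly {"indexing": {"allow": ["metadata"]}});
+            custom:  a policy derived from the metadata_indexing_include /
+                     metadata_indexing_exclude constructor arguments.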
+ """ + database.create_collection( + EPHEMERAL_LEGACY_IDX_NAME_D2, + dimension=2, + check_exists=False, + ) + + with pytest.raises( + ValueError, + match=LEGACY_INDEXING_MSG, + ): + AstraDBVectorStore( + collection_name=EPHEMERAL_LEGACY_IDX_NAME_D2, + embedding=embedding_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + metadata_indexing_exclude={"long_summary", "the_divine_comedy"}, + ) + + # one case should result in just a warning: + with pytest.warns(UserWarning) as rec_warnings: + AstraDBVectorStore( + collection_name=EPHEMERAL_LEGACY_IDX_NAME_D2, + embedding=embedding_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, UserWarning) + ] + assert len(f_rec_warnings) == 1 + + def test_astradb_vectorstore_indexing_default_sync( + self, + astra_db_credentials: AstraDBCredentials, + embedding_d2: Embeddings, + ephemeral_indexing_collections_cleaner: str, # noqa: ARG002 + ) -> None: + """ + Test of the vector store behaviour for various indexing settings, + with an existing 'default' collection. + """ + AstraDBVectorStore( + collection_name=EPHEMERAL_DEFAULT_IDX_NAME_D2, + embedding=embedding_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + ) + + with warnings.catch_warnings(): + warnings.simplefilter("error") + AstraDBVectorStore( + collection_name=EPHEMERAL_DEFAULT_IDX_NAME_D2, + embedding=embedding_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + ) + + # unacceptable for a pre-existing (default-indexing) collection: + with pytest.raises(ValueError, match=INCOMPATIBLE_INDEXING_MSG): + AstraDBVectorStore( + collection_name=EPHEMERAL_DEFAULT_IDX_NAME_D2, + embedding=embedding_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + metadata_indexing_exclude={"long_summary", "the_divine_comedy"}, + ) + + def test_astradb_vectorstore_indexing_custom_sync( + self, + astra_db_credentials: AstraDBCredentials, + embedding_d2: Embeddings, + ephemeral_indexing_collections_cleaner: str, # noqa: ARG002 + ) -> None: + """ + Test of the vector store behaviour for various indexing settings, + with an existing custom-indexing collection. 
+ """ + AstraDBVectorStore( + collection_name=EPHEMERAL_CUSTOM_IDX_NAME_D2, + embedding=embedding_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + metadata_indexing_exclude={"long_summary", "the_divine_comedy"}, + ) + + with warnings.catch_warnings(): + warnings.simplefilter("error") + AstraDBVectorStore( + collection_name=EPHEMERAL_CUSTOM_IDX_NAME_D2, + embedding=embedding_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + metadata_indexing_exclude={"long_summary", "the_divine_comedy"}, + ) + + with pytest.raises(ValueError, match=INCOMPATIBLE_INDEXING_MSG): + AstraDBVectorStore( + collection_name=EPHEMERAL_CUSTOM_IDX_NAME_D2, + embedding=embedding_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + metadata_indexing_exclude={"changed_fields"}, + ) + + with pytest.raises(ValueError, match=INCOMPATIBLE_INDEXING_MSG): + AstraDBVectorStore( + collection_name=EPHEMERAL_CUSTOM_IDX_NAME_D2, + embedding=embedding_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + ) + + async def test_astradb_vectorstore_indexing_legacy_async( + self, + astra_db_credentials: AstraDBCredentials, + database: Database, + embedding_d2: Embeddings, + ephemeral_indexing_collections_cleaner: list[str], # noqa: ARG002 + ) -> None: + """ + Test of the vector store behaviour for various indexing settings, + with an existing 'legacy' collection (i.e. unspecified indexing policy). 
+ """ + database.create_collection( + EPHEMERAL_LEGACY_IDX_NAME_D2, + dimension=2, + check_exists=False, + ) + + with pytest.raises( + ValueError, + match=LEGACY_INDEXING_MSG, + ): + await AstraDBVectorStore( + collection_name=EPHEMERAL_LEGACY_IDX_NAME_D2, + embedding=embedding_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.ASYNC, + metadata_indexing_exclude={"long_summary", "the_divine_comedy"}, + ).aadd_texts(["[4,13]"]) + + # one case should result in just a warning: + with pytest.warns(UserWarning) as rec_warnings: + await AstraDBVectorStore( + collection_name=EPHEMERAL_LEGACY_IDX_NAME_D2, + embedding=embedding_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.ASYNC, + ).aadd_texts(["[4,13]"]) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, UserWarning) + ] + assert len(f_rec_warnings) == 1 + + async def test_astradb_vectorstore_indexing_default_async( + self, + astra_db_credentials: AstraDBCredentials, + embedding_d2: Embeddings, + ephemeral_indexing_collections_cleaner: str, # noqa: ARG002 + ) -> None: + """ + Test of the vector store behaviour for various indexing settings, + with an existing 'default' collection. + """ + await AstraDBVectorStore( + collection_name=EPHEMERAL_DEFAULT_IDX_NAME_D2, + embedding=embedding_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.ASYNC, + ).aadd_texts(["[4,13]"]) + + with warnings.catch_warnings(): + warnings.simplefilter("error") + await AstraDBVectorStore( + collection_name=EPHEMERAL_DEFAULT_IDX_NAME_D2, + embedding=embedding_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.ASYNC, + ).aadd_texts(["[4,13]"]) + + # unacceptable for a pre-existing (default-indexing) collection: + with pytest.raises(ValueError, match=INCOMPATIBLE_INDEXING_MSG): + await AstraDBVectorStore( + collection_name=EPHEMERAL_DEFAULT_IDX_NAME_D2, + embedding=embedding_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.ASYNC, + metadata_indexing_exclude={"long_summary", "the_divine_comedy"}, + ).aadd_texts(["[4,13]"]) + + async def test_astradb_vectorstore_indexing_custom_async( + self, + astra_db_credentials: AstraDBCredentials, + embedding_d2: Embeddings, + ephemeral_indexing_collections_cleaner: str, # noqa: ARG002 + ) -> None: + """ + Test of the vector store behaviour for various indexing settings, + with an existing custom-indexing collection. 
+ """ + await AstraDBVectorStore( + collection_name=EPHEMERAL_CUSTOM_IDX_NAME_D2, + embedding=embedding_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.ASYNC, + metadata_indexing_exclude={"long_summary", "the_divine_comedy"}, + ).aadd_texts(["[4,13]"]) + + with warnings.catch_warnings(): + warnings.simplefilter("error") + await AstraDBVectorStore( + collection_name=EPHEMERAL_CUSTOM_IDX_NAME_D2, + embedding=embedding_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.ASYNC, + metadata_indexing_exclude={"long_summary", "the_divine_comedy"}, + ).aadd_texts(["[4,13]"]) + + with pytest.raises(ValueError, match=INCOMPATIBLE_INDEXING_MSG): + await AstraDBVectorStore( + collection_name=EPHEMERAL_CUSTOM_IDX_NAME_D2, + embedding=embedding_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.ASYNC, + metadata_indexing_exclude={"changed_fields"}, + ).aadd_texts(["[4,13]"]) + + with pytest.raises(ValueError, match=INCOMPATIBLE_INDEXING_MSG): + await AstraDBVectorStore( + collection_name=EPHEMERAL_CUSTOM_IDX_NAME_D2, + embedding=embedding_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.ASYNC, + ).aadd_texts(["[4,13]"]) + + @pytest.mark.skipif( + OPENAI_VECTORIZE_OPTIONS_KMS is None, + reason="A KMS ('shared secret') API Key name is required", + ) + def test_astradb_vectorstore_vectorize_headers_precedence_stringheader( + self, + astra_db_credentials: AstraDBCredentials, + collection_vz: Collection, # noqa: ARG002 + ephemeral_collection_cleaner_vz_kms: str, # noqa: ARG002 + ) -> None: + """ + Test that header, if passed, takes precedence over vectorize setting. + To do so, a faulty header is passed, expecting the call to fail. + """ + v_store = AstraDBVectorStore( + collection_name=EPHEMERAL_COLLECTION_NAME_VZ_KMS, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_KMS, + collection_embedding_api_key="verywrong", + ) + with pytest.raises(InsertManyException): + v_store.add_texts(["Failing"]) + + @pytest.mark.skipif( + OPENAI_VECTORIZE_OPTIONS_KMS is None, + reason="A KMS ('shared secret') API Key name is required", + ) + def test_astradb_vectorstore_vectorize_headers_precedence_headerprovider( + self, + astra_db_credentials: AstraDBCredentials, + collection_vz: Collection, # noqa: ARG002 + ephemeral_collection_cleaner_vz_kms: str, # noqa: ARG002 + ) -> None: + """ + Test that header, if passed, takes precedence over vectorize setting. + To do so, a faulty header is passed, expecting the call to fail. 
+        This version passes the header through an EmbeddingAPIKeyHeaderProvider.
+        """
+        v_store = AstraDBVectorStore(
+            collection_name=EPHEMERAL_COLLECTION_NAME_VZ_KMS,
+            token=StaticTokenProvider(astra_db_credentials["token"]),
+            api_endpoint=astra_db_credentials["api_endpoint"],
+            namespace=astra_db_credentials["namespace"],
+            environment=astra_db_credentials["environment"],
+            setup_mode=SetupMode.OFF,
+            collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_KMS,
+            collection_embedding_api_key=EmbeddingAPIKeyHeaderProvider("verywrong"),
+        )
+        with pytest.raises(InsertManyException):
+            v_store.add_texts(["Failing"])
diff --git a/libs/astradb/tests/integration_tests/test_vectorstores.py b/libs/astradb/tests/integration_tests/test_vectorstores.py
deleted file mode 100644
index 3690105..0000000
--- a/libs/astradb/tests/integration_tests/test_vectorstores.py
+++ /dev/null
@@ -1,1951 +0,0 @@
-"""Test of Astra DB vector store class `AstraDBVectorStore`
-
-Required to run this test:
-    - a recent `astrapy` Python package available
-    - an Astra DB instance;
-    - the two environment variables set:
-        export ASTRA_DB_API_ENDPOINT="https://-us-east1.apps.astra.datastax.com"
-        export ASTRA_DB_APPLICATION_TOKEN="AstraCS:........."
-    - optionally this as well (otherwise defaults are used):
-        export ASTRA_DB_KEYSPACE="my_keyspace"
-    - an openai secret for SHARED_SECRET mode, associated to the DB, with name on KMS:
-        export SHARED_SECRET_NAME_OPENAI="the_api_key_name_in_Astra_KMS"
-    - an OpenAI key for the vectorize test (in HEADER mode):
-        export OPENAI_API_KEY="..."
-    - optionally to test vectorize with nvidia as well (besides openai):
-        export NVIDIA_VECTORIZE_AVAILABLE="1"
-    - optionally:
-        export ASTRA_DB_SKIP_COLLECTION_DELETIONS="0" ("1" = no deletions, default)
-"""
-
-from __future__ import annotations
-
-import json
-import math
-import os
-import warnings
-from typing import TYPE_CHECKING, Iterable
-
-import pytest
-from astrapy.authentication import EmbeddingAPIKeyHeaderProvider, StaticTokenProvider
-from langchain_core.documents import Document
-
-from langchain_astradb.utils.astradb import SetupMode
-from langchain_astradb.vectorstores import AstraDBVectorStore
-from tests.conftest import ParserEmbeddings, SomeEmbeddings
-
-from .conftest import (
-    NVIDIA_VECTORIZE_OPTIONS,
-    OPENAI_VECTORIZE_OPTIONS,
-    OPENAI_VECTORIZE_OPTIONS_HEADER,
-    AstraDBCredentials,
-    _has_env_vars,
-)
-
-if TYPE_CHECKING:
-    from astrapy import Database
-    from astrapy.db import AstraDB
-
-# Faster testing (no actual collection deletions).
Off by default (=full tests) -SKIP_COLLECTION_DELETE = ( - int(os.environ.get("ASTRA_DB_SKIP_COLLECTION_DELETIONS", "0")) != 0 -) - -COLLECTION_NAME_DIM2 = "lc_test_d2" -COLLECTION_NAME_DIM2_EUCLIDEAN = "lc_test_d2_eucl" -COLLECTION_NAME_VECTORIZE_OPENAI = "lc_test_vec_openai" -COLLECTION_NAME_VECTORIZE_OPENAI_HEADER = "lc_test_vec_openai_h" -COLLECTION_NAME_VECTORIZE_NVIDIA = "lc_test_nvidia" - -MATCH_EPSILON = 0.0001 - -INCOMPATIBLE_INDEXING_MSG = "is detected as having the following indexing policy" - - -def is_nvidia_vector_service_available() -> bool: - # For the time being, this is manually controlled - if os.environ.get("NVIDIA_VECTORIZE_AVAILABLE"): - try: - # any non-zero counts as true: - return int(os.environ["NVIDIA_VECTORIZE_AVAILABLE"]) != 0 - except (TypeError, ValueError): - # the env var has unparsable contents: - return False - else: - return False - - -@pytest.fixture -def store_someemb( - astra_db_credentials: AstraDBCredentials, -) -> Iterable[AstraDBVectorStore]: - emb = SomeEmbeddings(dimension=2) - v_store = AstraDBVectorStore( - embedding=emb, - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - v_store.clear() - - yield v_store - - if not SKIP_COLLECTION_DELETE: - v_store.delete_collection() - else: - v_store.clear() - - -@pytest.fixture -def store_someemb_tokenprovider( - astra_db_credentials: AstraDBCredentials, -) -> Iterable[AstraDBVectorStore]: - """A variant of the store_someemb using a TokenProvider for DB auth.""" - emb = SomeEmbeddings(dimension=2) - v_store = AstraDBVectorStore( - embedding=emb, - collection_name=COLLECTION_NAME_DIM2, - token=StaticTokenProvider(astra_db_credentials["token"]), - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - v_store.clear() - - yield v_store - - if not SKIP_COLLECTION_DELETE: - v_store.delete_collection() - else: - v_store.clear() - - -@pytest.fixture -def store_parseremb( - astra_db_credentials: AstraDBCredentials, -) -> Iterable[AstraDBVectorStore]: - emb = ParserEmbeddings(dimension=2) - v_store = AstraDBVectorStore( - embedding=emb, - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - v_store.clear() - - yield v_store - - if not SKIP_COLLECTION_DELETE: - v_store.delete_collection() - else: - v_store.clear() - - -@pytest.fixture -def vectorize_store( - astra_db_credentials: AstraDBCredentials, -) -> Iterable[AstraDBVectorStore]: - """Astra db vector store with server-side embeddings using openai + shared_secret""" - if "SHARED_SECRET_NAME_OPENAI" not in os.environ: - pytest.skip("OpenAI SHARED_SECRET key not set for KMS vectorize") - - v_store = AstraDBVectorStore( - collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS, - collection_name=COLLECTION_NAME_VECTORIZE_OPENAI, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - v_store.clear() - - yield v_store - - # explicitly delete the collection to avoid max collection limit - v_store.delete_collection() - - -@pytest.fixture -def vectorize_store_w_header( - 
astra_db_credentials: AstraDBCredentials, -) -> Iterable[AstraDBVectorStore]: - """Astra db vector store with server-side embeddings using openai + header""" - if not os.environ.get("OPENAI_API_KEY"): - pytest.skip("OpenAI key not available") - - v_store = AstraDBVectorStore( - collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_HEADER, - collection_name=COLLECTION_NAME_VECTORIZE_OPENAI_HEADER, - collection_embedding_api_key=os.environ["OPENAI_API_KEY"], - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - v_store.clear() - - yield v_store - - # explicitly delete the collection to avoid max collection limit - v_store.delete_collection() - - -@pytest.fixture -def vectorize_store_w_header_and_provider( - astra_db_credentials: AstraDBCredentials, -) -> Iterable[AstraDBVectorStore]: - """Astra db vector store with server-side embeddings using openai + header - Variant initialized with a EmbeddingHeadersProvider instance for the header - """ - if not os.environ.get("OPENAI_API_KEY"): - pytest.skip("OpenAI key not available") - - v_store = AstraDBVectorStore( - collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_HEADER, - collection_name=COLLECTION_NAME_VECTORIZE_OPENAI_HEADER, - collection_embedding_api_key=EmbeddingAPIKeyHeaderProvider( - os.environ["OPENAI_API_KEY"], - ), - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - v_store.clear() - - yield v_store - - # explicitly delete the collection to avoid max collection limit - v_store.delete_collection() - - -@pytest.fixture -def vectorize_store_nvidia( - astra_db_credentials: AstraDBCredentials, -) -> Iterable[AstraDBVectorStore]: - """Astra db vector store with server-side embeddings using the nvidia model""" - if not is_nvidia_vector_service_available(): - pytest.skip("vectorize/nvidia unavailable") - - v_store = AstraDBVectorStore( - collection_vector_service_options=NVIDIA_VECTORIZE_OPTIONS, - collection_name=COLLECTION_NAME_VECTORIZE_NVIDIA, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - v_store.clear() - - yield v_store - - # explicitly delete the collection to avoid max collection limit - v_store.delete_collection() - - -@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. 
vars") -class TestAstraDBVectorStore: - def test_astradb_vectorstore_create_delete_sync( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """Create and delete.""" - emb = SomeEmbeddings(dimension=2) - - v_store = AstraDBVectorStore( - embedding=emb, - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - v_store.add_texts("Sample 1") - if not SKIP_COLLECTION_DELETE: - v_store.delete_collection() - else: - v_store.clear() - - def test_astradb_vectorstore_create_delete_vectorize_sync( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """Create and delete with vectorize option.""" - v_store = AstraDBVectorStore( - collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_HEADER, - collection_name=COLLECTION_NAME_VECTORIZE_OPENAI_HEADER, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - collection_embedding_api_key=os.environ["OPENAI_API_KEY"], - ) - v_store.add_texts(["Sample 1"]) - v_store.delete_collection() - - async def test_astradb_vectorstore_create_delete_async( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """Create and delete.""" - emb = SomeEmbeddings(dimension=2) - # Creation by passing the connection secrets - v_store = AstraDBVectorStore( - embedding=emb, - collection_name=COLLECTION_NAME_DIM2, - **astra_db_credentials, - ) - await v_store.adelete_collection() - - async def test_astradb_vectorstore_create_delete_vectorize_async( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """Create and delete with vectorize option.""" - v_store = AstraDBVectorStore( - collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS, - collection_name=COLLECTION_NAME_VECTORIZE_OPENAI, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - await v_store.adelete_collection() - - @pytest.mark.skipif( - SKIP_COLLECTION_DELETE, - reason="Collection-deletion tests are suppressed", - ) - def test_astradb_vectorstore_pre_delete_collection_sync( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """Use of the pre_delete_collection flag.""" - emb = SomeEmbeddings(dimension=2) - v_store = AstraDBVectorStore( - embedding=emb, - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - v_store.clear() - try: - v_store.add_texts( - texts=["aa"], - metadatas=[ - {"k": "a", "ord": 0}, - ], - ids=["a"], - ) - res1 = v_store.similarity_search("aa", k=5) - assert len(res1) == 1 - v_store = AstraDBVectorStore( - embedding=emb, - pre_delete_collection=True, - collection_name=COLLECTION_NAME_DIM2, - **astra_db_credentials, - ) - res1 = v_store.similarity_search("aa", k=5) - assert len(res1) == 0 - finally: - v_store.delete_collection() - - @pytest.mark.skipif( - SKIP_COLLECTION_DELETE, - reason="Collection-deletion tests are suppressed", - ) - async def test_astradb_vectorstore_pre_delete_collection_async( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """Use of the pre_delete_collection flag.""" - emb 
= SomeEmbeddings(dimension=2) - # creation by passing the connection secrets - - v_store = AstraDBVectorStore( - embedding=emb, - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - try: - await v_store.aadd_texts( - texts=["aa"], - metadatas=[ - {"k": "a", "ord": 0}, - ], - ids=["a"], - ) - res1 = await v_store.asimilarity_search("aa", k=5) - assert len(res1) == 1 - v_store = AstraDBVectorStore( - embedding=emb, - pre_delete_collection=True, - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - res1 = await v_store.asimilarity_search("aa", k=5) - assert len(res1) == 0 - finally: - await v_store.adelete_collection() - - def test_astradb_vectorstore_from_texts_sync( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """from_texts methods.""" - emb = SomeEmbeddings(dimension=2) - # prepare empty collection - AstraDBVectorStore( - embedding=emb, - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ).clear() - # from_texts - v_store = AstraDBVectorStore.from_texts( - texts=["Hi", "Ho"], - embedding=emb, - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - try: - assert v_store.similarity_search("Ho", k=1)[0].page_content == "Ho" - # testing additional kwargs & from_text-specific kwargs - # baseline - AstraDBVectorStore.from_texts( - texts=["F T one", "F T two"], - embedding=emb, - metadatas=[{"m": 1}, {"m": 2}], - ids=["ft1", "ft2"], - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - setup_mode=SetupMode.OFF, - ) - with pytest.warns(UserWarning): - # unknown kwargs going to the constructor through _from_kwargs - AstraDBVectorStore.from_texts( - texts=["F T one", "F T two"], - embedding=emb, - metadatas=[{"m": 1}, {"m": 2}], - ids=["ft1", "ft2"], - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - setup_mode=SetupMode.OFF, - number_of_wizards=123, - name_of_river="Thames", - ) - # routing of 'add_texts' keyword arguments - AstraDBVectorStore.from_texts( - texts=["F T one", "F T two"], - embedding=emb, - metadatas=[{"m": 1}, {"m": 2}], - ids=["ft1", "ft2"], - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - setup_mode=SetupMode.OFF, - batch_size=12, - batch_concurrency=23, - overwrite_concurrency=34, - ) - finally: - if not SKIP_COLLECTION_DELETE: - v_store.delete_collection() - else: - v_store.clear() - - def 
test_astradb_vectorstore_from_documents_without_ids_sync( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """from_documents methods.""" - emb = SomeEmbeddings(dimension=2) - v_store = AstraDBVectorStore.from_documents( - [ - Document(page_content="Hee"), - Document(page_content="Hoi"), - ], - embedding=emb, - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - - try: - hits = v_store.similarity_search("Hoi", k=1) - assert len(hits) == 1 - assert hits[0].page_content == "Hoi" - finally: - if not SKIP_COLLECTION_DELETE: - v_store.delete_collection() - else: - v_store.clear() - - def test_astradb_vectorstore_from_documents_separate_ids_sync( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """from_documents methods.""" - emb = SomeEmbeddings(dimension=2) - with pytest.warns(DeprecationWarning) as rec_warnings: - v_store = AstraDBVectorStore.from_documents( - [ - Document(page_content="Hee"), - Document(page_content="Hoi"), - ], - embedding=emb, - ids=["idx0", "idx1"], - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - f_rec_warnings = [ - wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) - ] - assert len(f_rec_warnings) == 1 - - try: - hits = v_store.similarity_search("Hoi", k=1) - assert len(hits) == 1 - assert hits[0].page_content == "Hoi" - assert hits[0].id == "idx1" - finally: - if not SKIP_COLLECTION_DELETE: - v_store.delete_collection() - else: - v_store.clear() - - def test_astradb_vectorstore_from_documents_containing_ids_sync( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """from_documents methods.""" - emb = SomeEmbeddings(dimension=2) - v_store = AstraDBVectorStore.from_documents( - [ - Document(page_content="Hee", id="idx0"), - Document(page_content="Hoi", id="idx1"), - ], - embedding=emb, - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - try: - hits = v_store.similarity_search("Hoi", k=1) - assert len(hits) == 1 - assert hits[0].page_content == "Hoi" - assert hits[0].id == "idx1" - finally: - if not SKIP_COLLECTION_DELETE: - v_store.delete_collection() - else: - v_store.clear() - - def test_astradb_vectorstore_from_documents_pass_ids_twice_sync( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """from_documents methods.""" - emb = SomeEmbeddings(dimension=2) - with pytest.warns(DeprecationWarning) as rec_warnings: - v_store = AstraDBVectorStore.from_documents( - [ - Document(page_content="Hee"), - Document(page_content="Hoi", id="idy1"), - ], - ids=["idx0", "idx1"], - embedding=emb, - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - f_rec_warnings = [ - wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) - ] - assert len(f_rec_warnings) == 1 - - try: - hits = v_store.similarity_search("Hoi", k=1) - assert len(hits) == 1 - assert 
hits[0].page_content == "Hoi" - assert hits[0].id == "idx1" - finally: - if not SKIP_COLLECTION_DELETE: - v_store.delete_collection() - else: - v_store.clear() - - def test_astradb_vectorstore_from_texts_vectorize_sync( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """from_texts methods with vectorize.""" - AstraDBVectorStore( - collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_HEADER, - collection_embedding_api_key=os.environ["OPENAI_API_KEY"], - collection_name=COLLECTION_NAME_VECTORIZE_OPENAI_HEADER, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ).clear() - - # from_texts - v_store = AstraDBVectorStore.from_texts( - texts=["Hi", "Ho"], - collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_HEADER, - collection_embedding_api_key=os.environ["OPENAI_API_KEY"], - collection_name=COLLECTION_NAME_VECTORIZE_OPENAI_HEADER, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - try: - assert v_store.similarity_search("Ho", k=1)[0].page_content == "Ho" - finally: - v_store.delete_collection() - - def test_astradb_vectorstore_from_documents_separate_ids_vectorize_sync( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """from_documents methods with vectorize.""" - with pytest.warns(DeprecationWarning) as rec_warnings: - v_store = AstraDBVectorStore.from_documents( - [ - Document(page_content="Hee"), - Document(page_content="Hoi"), - ], - ids=["idx0", "idx1"], - collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_HEADER, - collection_embedding_api_key=os.environ["OPENAI_API_KEY"], - collection_name=COLLECTION_NAME_VECTORIZE_OPENAI_HEADER, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - f_rec_warnings = [ - wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) - ] - assert len(f_rec_warnings) == 1 - - try: - hits = v_store.similarity_search("Hoi", k=1) - assert len(hits) == 1 - assert hits[0].page_content == "Hoi" - assert hits[0].id == "idx1" - finally: - v_store.delete_collection() - - async def test_astradb_vectorstore_from_texts_async( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """from_texts methods.""" - emb = SomeEmbeddings(dimension=2) - # prepare empty collection - await AstraDBVectorStore( - embedding=emb, - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ).aclear() - # from_texts - v_store = await AstraDBVectorStore.afrom_texts( - texts=["Hi", "Ho"], - embedding=emb, - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - try: - assert (await v_store.asimilarity_search("Ho", k=1))[0].page_content == "Ho" - # testing additional kwargs & from_text-specific kwargs - # baseline - await AstraDBVectorStore.afrom_texts( - texts=["F T one", "F T two"], - embedding=emb, - metadatas=[{"m": 1}, {"m": 2}], - ids=["ft1", 
"ft2"], - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - setup_mode=SetupMode.OFF, - ) - with pytest.warns(UserWarning): - # unknown kwargs going to the constructor through _from_kwargs - await AstraDBVectorStore.afrom_texts( - texts=["F T one", "F T two"], - embedding=emb, - metadatas=[{"m": 1}, {"m": 2}], - ids=["ft1", "ft2"], - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - setup_mode=SetupMode.OFF, - number_of_wizards=123, - name_of_river="Thames", - ) - # routing of 'add_texts' keyword arguments - await AstraDBVectorStore.afrom_texts( - texts=["F T one", "F T two"], - embedding=emb, - metadatas=[{"m": 1}, {"m": 2}], - ids=["ft1", "ft2"], - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - setup_mode=SetupMode.OFF, - batch_size=12, - batch_concurrency=23, - overwrite_concurrency=34, - ) - finally: - if not SKIP_COLLECTION_DELETE: - await v_store.adelete_collection() - else: - await v_store.aclear() - - async def test_astradb_vectorstore_from_documents_without_ids_async( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """afrom_documents methods.""" - emb = SomeEmbeddings(dimension=2) - v_store = await AstraDBVectorStore.afrom_documents( - [ - Document(page_content="Hee"), - Document(page_content="Hoi"), - ], - embedding=emb, - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - - try: - hits = await v_store.asimilarity_search("Hoi", k=1) - assert len(hits) == 1 - assert hits[0].page_content == "Hoi" - finally: - if not SKIP_COLLECTION_DELETE: - await v_store.adelete_collection() - else: - await v_store.aclear() - - async def test_astradb_vectorstore_from_documents_separate_ids_async( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """afrom_documents methods.""" - emb = SomeEmbeddings(dimension=2) - with pytest.warns(DeprecationWarning) as rec_warnings: - v_store = await AstraDBVectorStore.afrom_documents( - [ - Document(page_content="Hee"), - Document(page_content="Hoi"), - ], - embedding=emb, - ids=["idx0", "idx1"], - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - f_rec_warnings = [ - wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) - ] - assert len(f_rec_warnings) == 1 - - try: - hits = await v_store.asimilarity_search("Hoi", k=1) - assert len(hits) == 1 - assert hits[0].page_content == "Hoi" - assert hits[0].id == "idx1" - finally: - if not SKIP_COLLECTION_DELETE: - await v_store.adelete_collection() - else: - await v_store.aclear() - - async def test_astradb_vectorstore_from_documents_containing_ids_async( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """from_documents methods.""" - emb = 
SomeEmbeddings(dimension=2) - v_store = await AstraDBVectorStore.afrom_documents( - [ - Document(page_content="Hee", id="idx0"), - Document(page_content="Hoi", id="idx1"), - ], - embedding=emb, - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - try: - hits = v_store.similarity_search("Hoi", k=1) - assert len(hits) == 1 - assert hits[0].page_content == "Hoi" - assert hits[0].id == "idx1" - finally: - if not SKIP_COLLECTION_DELETE: - v_store.delete_collection() - else: - v_store.clear() - - async def test_astradb_vectorstore_from_documents_pass_ids_twice_async( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """from_documents methods.""" - emb = SomeEmbeddings(dimension=2) - with pytest.warns(DeprecationWarning) as rec_warnings: - v_store = await AstraDBVectorStore.afrom_documents( - [ - Document(page_content="Hee"), - Document(page_content="Hoi", id="idy0"), - ], - ids=["idx0", "idx1"], - embedding=emb, - collection_name=COLLECTION_NAME_DIM2, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - f_rec_warnings = [ - wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) - ] - assert len(f_rec_warnings) == 1 - - try: - hits = await v_store.asimilarity_search("Hoi", k=1) - assert len(hits) == 1 - assert hits[0].page_content == "Hoi" - assert hits[0].id == "idx1" - finally: - if not SKIP_COLLECTION_DELETE: - v_store.delete_collection() - else: - v_store.clear() - - async def test_astradb_vectorstore_from_texts_vectorize_async( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """from_texts methods with vectorize.""" - # from_text with vectorize - v_store = await AstraDBVectorStore.afrom_texts( - texts=["Haa", "Huu"], - collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_HEADER, - collection_embedding_api_key=os.environ["OPENAI_API_KEY"], - collection_name=COLLECTION_NAME_VECTORIZE_OPENAI_HEADER, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - try: - assert (await v_store.asimilarity_search("Haa", k=1))[ - 0 - ].page_content == "Haa" - finally: - await v_store.adelete_collection() - - async def test_astradb_vectorstore_from_documents_separate_ids_vectorize_async( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """afrom_documents methods with vectorize.""" - with pytest.warns(DeprecationWarning) as rec_warnings: - v_store = await AstraDBVectorStore.afrom_documents( - [ - Document(page_content="HeeH"), - Document(page_content="HooH"), - ], - ids=["idx0", "idx1"], - collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_HEADER, - collection_embedding_api_key=os.environ["OPENAI_API_KEY"], - collection_name=COLLECTION_NAME_VECTORIZE_OPENAI_HEADER, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - f_rec_warnings = [ - wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) - ] - assert len(f_rec_warnings) == 1 - - try: - hits = await v_store.asimilarity_search("HeeH", k=1) - assert len(hits) == 
1 - assert hits[0].page_content == "HeeH" - assert hits[0].id == "idx0" - finally: - await v_store.adelete_collection() - - @pytest.mark.parametrize( - "vector_store", - [ - "store_someemb", - "store_someemb_tokenprovider", - "vectorize_store", - "vectorize_store_w_header", - "vectorize_store_w_header_and_provider", - "vectorize_store_nvidia", - ], - ) - def test_astradb_vectorstore_crud_sync( - self, vector_store: str, request: pytest.FixtureRequest - ) -> None: - """Basic add/delete/update behaviour.""" - vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) - - res0 = vstore.similarity_search("Abc", k=2) - assert res0 == [] - # write and check again - vstore.add_texts( - texts=["aa", "bb", "cc"], - metadatas=[ - {"k": "a", "ord": 0}, - {"k": "b", "ord": 1}, - {"k": "c", "ord": 2}, - ], - ids=["a", "b", "c"], - ) - res1 = vstore.similarity_search("Abc", k=5) - assert {doc.page_content for doc in res1} == {"aa", "bb", "cc"} - # partial overwrite and count total entries - vstore.add_texts( - texts=["cc", "dd"], - metadatas=[ - {"k": "c_new", "ord": 102}, - {"k": "d_new", "ord": 103}, - ], - ids=["c", "d"], - ) - res2 = vstore.similarity_search("Abc", k=10) - assert len(res2) == 4 - # pick one that was just updated and check its metadata - res3 = vstore.similarity_search_with_score_id( - query="cc", k=1, filter={"k": "c_new"} - ) - doc3, _, id3 = res3[0] - assert doc3.page_content == "cc" - assert doc3.metadata == {"k": "c_new", "ord": 102} - assert id3 == "c" - # delete and count again - del1_res = vstore.delete(["b"]) - assert del1_res is True - del2_res = vstore.delete(["a", "c", "Z!"]) - assert del2_res is True # a non-existing ID was supplied - assert len(vstore.similarity_search("xy", k=10)) == 1 - # clear store - vstore.clear() - assert vstore.similarity_search("Abc", k=2) == [] - # add_documents with "ids" arg passthrough - vstore.add_documents( - [ - Document(page_content="vv", metadata={"k": "v", "ord": 204}), - Document(page_content="ww", metadata={"k": "w", "ord": 205}), - ], - ids=["v", "w"], - ) - assert len(vstore.similarity_search("xy", k=10)) == 2 - res4 = vstore.similarity_search("ww", k=1, filter={"k": "w"}) - assert res4[0].metadata["ord"] == 205 - - @pytest.mark.parametrize( - "vector_store", - [ - "store_someemb", - "vectorize_store", - "vectorize_store_w_header", - "vectorize_store_nvidia", - ], - ) - async def test_astradb_vectorstore_crud_async( - self, vector_store: str, request: pytest.FixtureRequest - ) -> None: - """Basic add/delete/update behaviour.""" - vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) - - res0 = await vstore.asimilarity_search("Abc", k=2) - assert res0 == [] - # write and check again - await vstore.aadd_texts( - texts=["aa", "bb", "cc"], - metadatas=[ - {"k": "a", "ord": 0}, - {"k": "b", "ord": 1}, - {"k": "c", "ord": 2}, - ], - ids=["a", "b", "c"], - ) - res1 = await vstore.asimilarity_search("Abc", k=5) - assert {doc.page_content for doc in res1} == {"aa", "bb", "cc"} - # partial overwrite and count total entries - await vstore.aadd_texts( - texts=["cc", "dd"], - metadatas=[ - {"k": "c_new", "ord": 102}, - {"k": "d_new", "ord": 103}, - ], - ids=["c", "d"], - ) - res2 = await vstore.asimilarity_search("Abc", k=10) - assert len(res2) == 4 - # pick one that was just updated and check its metadata - res3 = await vstore.asimilarity_search_with_score_id( - query="cc", k=1, filter={"k": "c_new"} - ) - doc3, _, id3 = res3[0] - assert doc3.page_content == "cc" - assert doc3.metadata == {"k": "c_new", "ord": 102} - 
assert id3 == "c" - # delete and count again - del1_res = await vstore.adelete(["b"]) - assert del1_res is True - del2_res = await vstore.adelete(["a", "c", "Z!"]) - assert del2_res is False # a non-existing ID was supplied - assert len(await vstore.asimilarity_search("xy", k=10)) == 1 - # clear store - await vstore.aclear() - assert await vstore.asimilarity_search("Abc", k=2) == [] - # add_documents with "ids" arg passthrough - await vstore.aadd_documents( - [ - Document(page_content="vv", metadata={"k": "v", "ord": 204}), - Document(page_content="ww", metadata={"k": "w", "ord": 205}), - ], - ids=["v", "w"], - ) - assert len(await vstore.asimilarity_search("xy", k=10)) == 2 - res4 = await vstore.asimilarity_search("ww", k=1, filter={"k": "w"}) - assert res4[0].metadata["ord"] == 205 - - def test_astradb_vectorstore_massive_insert_replace_sync( - self, - store_someemb: AstraDBVectorStore, - ) -> None: - """Testing the insert-many-and-replace-some patterns thoroughly.""" - full_size = 300 - first_group_size = 150 - second_group_slicer = [30, 100, 2] - - all_ids = [f"doc_{idx}" for idx in range(full_size)] - all_texts = [f"document number {idx}" for idx in range(full_size)] - - # massive insertion on empty - group0_ids = all_ids[0:first_group_size] - group0_texts = all_texts[0:first_group_size] - inserted_ids0 = store_someemb.add_texts( - texts=group0_texts, - ids=group0_ids, - ) - assert set(inserted_ids0) == set(group0_ids) - # massive insertion with many overwrites scattered through - # (we change the text to later check on DB for successful update) - _s, _e, _st = second_group_slicer - group1_ids = all_ids[_s:_e:_st] + all_ids[first_group_size:full_size] - group1_texts = [ - txt.upper() - for txt in (all_texts[_s:_e:_st] + all_texts[first_group_size:full_size]) - ] - inserted_ids1 = store_someemb.add_texts( - texts=group1_texts, - ids=group1_ids, - ) - assert set(inserted_ids1) == set(group1_ids) - # final read (we want the IDs to do a full check) - expected_text_by_id = { - **dict(zip(group0_ids, group0_texts)), - **dict(zip(group1_ids, group1_texts)), - } - full_results = store_someemb.similarity_search_with_score_id_by_vector( - embedding=[1.0, 1.0], - k=full_size, - ) - for doc, _, doc_id in full_results: - assert doc.page_content == expected_text_by_id[doc_id] - - async def test_astradb_vectorstore_massive_insert_replace_async( - self, - store_someemb: AstraDBVectorStore, - ) -> None: - """Testing the insert-many-and-replace-some patterns thoroughly.""" - full_size = 300 - first_group_size = 150 - second_group_slicer = [30, 100, 2] - - all_ids = [f"doc_{idx}" for idx in range(full_size)] - all_texts = [f"document number {idx}" for idx in range(full_size)] - - # massive insertion on empty - group0_ids = all_ids[0:first_group_size] - group0_texts = all_texts[0:first_group_size] - - inserted_ids0 = await store_someemb.aadd_texts( - texts=group0_texts, - ids=group0_ids, - ) - assert set(inserted_ids0) == set(group0_ids) - # massive insertion with many overwrites scattered through - # (we change the text to later check on DB for successful update) - _s, _e, _st = second_group_slicer - group1_ids = all_ids[_s:_e:_st] + all_ids[first_group_size:full_size] - group1_texts = [ - txt.upper() - for txt in (all_texts[_s:_e:_st] + all_texts[first_group_size:full_size]) - ] - inserted_ids1 = await store_someemb.aadd_texts( - texts=group1_texts, - ids=group1_ids, - ) - assert set(inserted_ids1) == set(group1_ids) - # final read (we want the IDs to do a full check) - expected_text_by_id = { - 
**dict(zip(group0_ids, group0_texts)), - **dict(zip(group1_ids, group1_texts)), - } - full_results = await store_someemb.asimilarity_search_with_score_id_by_vector( - embedding=[1.0, 1.0], - k=full_size, - ) - for doc, _, doc_id in full_results: - assert doc.page_content == expected_text_by_id[doc_id] - - def test_astradb_vectorstore_mmr_sync( - self, store_parseremb: AstraDBVectorStore - ) -> None: - """MMR testing. We work on the unit circle with angle multiples - of 2*pi/20 and prepare a store with known vectors for a controlled - MMR outcome. - """ - - def _v_from_i(i: int, n: int) -> str: - angle = 2 * math.pi * i / n - vector = [math.cos(angle), math.sin(angle)] - return json.dumps(vector) - - i_vals = [0, 4, 5, 13] - n_val = 20 - store_parseremb.add_texts( - [_v_from_i(i, n_val) for i in i_vals], metadatas=[{"i": i} for i in i_vals] - ) - res1 = store_parseremb.max_marginal_relevance_search( - _v_from_i(3, n_val), - k=2, - fetch_k=3, - ) - res_i_vals = {doc.metadata["i"] for doc in res1} - assert res_i_vals == {0, 4} - - async def test_astradb_vectorstore_mmr_async( - self, store_parseremb: AstraDBVectorStore - ) -> None: - """MMR testing. We work on the unit circle with angle multiples - of 2*pi/20 and prepare a store with known vectors for a controlled - MMR outcome. - """ - - def _v_from_i(i: int, n: int) -> str: - angle = 2 * math.pi * i / n - vector = [math.cos(angle), math.sin(angle)] - return json.dumps(vector) - - i_vals = [0, 4, 5, 13] - n_val = 20 - await store_parseremb.aadd_texts( - [_v_from_i(i, n_val) for i in i_vals], - metadatas=[{"i": i} for i in i_vals], - ) - res1 = await store_parseremb.amax_marginal_relevance_search( - _v_from_i(3, n_val), - k=2, - fetch_k=3, - ) - res_i_vals = {doc.metadata["i"] for doc in res1} - assert res_i_vals == {0, 4} - - def test_astradb_vectorstore_mmr_vectorize_sync( - self, vectorize_store: AstraDBVectorStore - ) -> None: - """MMR testing with vectorize, sync.""" - vectorize_store.add_texts( - [ - "Dog", - "Wolf", - "Ant", - "Sunshine and piadina", - ], - ids=["d", "w", "a", "s"], - ) - - hits = vectorize_store.max_marginal_relevance_search("Dingo", k=2, fetch_k=3) - assert {doc.page_content for doc in hits} == {"Dog", "Ant"} - - async def test_astradb_vectorstore_mmr_vectorize_async( - self, vectorize_store: AstraDBVectorStore - ) -> None: - """MMR async testing with vectorize, async.""" - await vectorize_store.aadd_texts( - [ - "Dog", - "Wolf", - "Ant", - "Sunshine and piadina", - ], - ids=["d", "w", "a", "s"], - ) - - hits = await vectorize_store.amax_marginal_relevance_search( - "Dingo", - k=2, - fetch_k=3, - ) - assert {doc.page_content for doc in hits} == {"Dog", "Ant"} - - @pytest.mark.parametrize( - "vector_store", - [ - "store_someemb", - "vectorize_store", - "vectorize_store_w_header", - "vectorize_store_nvidia", - ], - ) - def test_astradb_vectorstore_metadata( - self, vector_store: str, request: pytest.FixtureRequest - ) -> None: - """Metadata filtering.""" - vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) - vstore.add_documents( - [ - Document( - page_content="q", - metadata={"ord": ord("q"), "group": "consonant"}, - ), - Document( - page_content="w", - metadata={"ord": ord("w"), "group": "consonant"}, - ), - Document( - page_content="r", - metadata={"ord": ord("r"), "group": "consonant"}, - ), - Document( - page_content="e", - metadata={"ord": ord("e"), "group": "vowel"}, - ), - Document( - page_content="i", - metadata={"ord": ord("i"), "group": "vowel"}, - ), - Document( - page_content="o", - 
metadata={"ord": ord("o"), "group": "vowel"}, - ), - ] - ) - # no filters - res0 = vstore.similarity_search("x", k=10) - assert {doc.page_content for doc in res0} == set("qwreio") - # single filter - res1 = vstore.similarity_search( - "x", - k=10, - filter={"group": "vowel"}, - ) - assert {doc.page_content for doc in res1} == set("eio") - # multiple filters - res2 = vstore.similarity_search( - "x", - k=10, - filter={"group": "consonant", "ord": ord("q")}, - ) - assert {doc.page_content for doc in res2} == set("q") - # excessive filters - res3 = vstore.similarity_search( - "x", - k=10, - filter={"group": "consonant", "ord": ord("q"), "case": "upper"}, - ) - assert res3 == [] - # filter with logical operator - res4 = vstore.similarity_search( - "x", - k=10, - filter={"$or": [{"ord": ord("q")}, {"ord": ord("r")}]}, - ) - assert {doc.page_content for doc in res4} == {"q", "r"} - - @pytest.mark.parametrize("vector_store", ["store_parseremb"]) - def test_astradb_vectorstore_similarity_scale_sync( - self, vector_store: str, request: pytest.FixtureRequest - ) -> None: - """Scale of the similarity scores.""" - vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) - vstore.add_texts( - texts=[ - json.dumps([1, 1]), - json.dumps([-1, -1]), - ], - ids=["near", "far"], - ) - res1 = vstore.similarity_search_with_score( - json.dumps([0.5, 0.5]), - k=2, - ) - scores = [sco for _, sco in res1] - sco_near, sco_far = scores - assert abs(1 - sco_near) < MATCH_EPSILON - assert abs(sco_far) < MATCH_EPSILON - - @pytest.mark.parametrize("vector_store", ["store_parseremb"]) - async def test_astradb_vectorstore_similarity_scale_async( - self, vector_store: str, request: pytest.FixtureRequest - ) -> None: - """Scale of the similarity scores.""" - vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) - await vstore.aadd_texts( - texts=[ - json.dumps([1, 1]), - json.dumps([-1, -1]), - ], - ids=["near", "far"], - ) - res1 = await vstore.asimilarity_search_with_score( - json.dumps([0.5, 0.5]), - k=2, - ) - scores = [sco for _, sco in res1] - sco_near, sco_far = scores - assert abs(1 - sco_near) < MATCH_EPSILON - assert abs(sco_far) < MATCH_EPSILON - - @pytest.mark.parametrize( - "vector_store", - [ - "store_someemb", - "vectorize_store", - "vectorize_store_w_header", - "vectorize_store_nvidia", - ], - ) - def test_astradb_vectorstore_massive_delete( - self, vector_store: str, request: pytest.FixtureRequest - ) -> None: - """Larger-scale bulk deletes.""" - vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) - m = 150 - texts = [str(i + 1 / 7.0) for i in range(2 * m)] - ids0 = ["doc_%i" % i for i in range(m)] - ids1 = ["doc_%i" % (i + m) for i in range(m)] - ids = ids0 + ids1 - vstore.add_texts(texts=texts, ids=ids) - # deleting a bunch of these - del_res0 = vstore.delete(ids0) - assert del_res0 is True - # deleting the rest plus a fake one - del_res1 = vstore.delete([*ids1, "ghost!"]) - assert del_res1 is True # ensure no error - # nothing left - assert vstore.similarity_search("x", k=2 * m) == [] - - @pytest.mark.skipif( - SKIP_COLLECTION_DELETE, - reason="Collection-deletion tests are suppressed", - ) - def test_astradb_vectorstore_delete_collection( - self, astra_db_credentials: AstraDBCredentials - ) -> None: - """Behaviour of 'delete_collection'.""" - collection_name = COLLECTION_NAME_DIM2 - emb = SomeEmbeddings(dimension=2) - v_store = AstraDBVectorStore( - embedding=emb, - collection_name=collection_name, - token=astra_db_credentials["token"], - 
-            api_endpoint=astra_db_credentials["api_endpoint"],
-            namespace=astra_db_credentials["namespace"],
-            environment=astra_db_credentials["environment"],
-        )
-        v_store.add_texts(["huh"])
-        assert len(v_store.similarity_search("hah", k=10)) == 1
-        # another instance pointing to the same collection on DB
-        v_store_kenny = AstraDBVectorStore(
-            embedding=emb,
-            collection_name=collection_name,
-            token=astra_db_credentials["token"],
-            api_endpoint=astra_db_credentials["api_endpoint"],
-            namespace=astra_db_credentials["namespace"],
-            environment=astra_db_credentials["environment"],
-        )
-        v_store_kenny.delete_collection()
-        # dropped on DB, but 'v_store' should have no clue:
-        with pytest.raises(ValueError, match="Collection does not exist"):
-            _ = v_store.similarity_search("hah", k=10)
-
-    def test_astradb_vectorstore_custom_params_sync(
-        self, astra_db_credentials: AstraDBCredentials
-    ) -> None:
-        """Custom batch size and concurrency params."""
-        emb = SomeEmbeddings(dimension=2)
-        # prepare empty collection
-        AstraDBVectorStore(
-            embedding=emb,
-            collection_name=COLLECTION_NAME_DIM2,
-            token=astra_db_credentials["token"],
-            api_endpoint=astra_db_credentials["api_endpoint"],
-            namespace=astra_db_credentials["namespace"],
-            environment=astra_db_credentials["environment"],
-        ).clear()
-        v_store = AstraDBVectorStore(
-            embedding=emb,
-            collection_name=COLLECTION_NAME_DIM2,
-            token=astra_db_credentials["token"],
-            api_endpoint=astra_db_credentials["api_endpoint"],
-            namespace=astra_db_credentials["namespace"],
-            environment=astra_db_credentials["environment"],
-            batch_size=17,
-            bulk_insert_batch_concurrency=13,
-            bulk_insert_overwrite_concurrency=7,
-            bulk_delete_concurrency=19,
-        )
-        try:
-            # add_texts
-            n = 50
-            texts = [str(i + 1 / 7.0) for i in range(n)]
-            ids = ["doc_%i" % i for i in range(n)]
-            v_store.add_texts(texts=texts, ids=ids)
-            v_store.add_texts(
-                texts=texts,
-                ids=ids,
-                batch_size=19,
-                batch_concurrency=7,
-                overwrite_concurrency=13,
-            )
-            _ = v_store.delete(ids[: n // 2])
-            _ = v_store.delete(ids[n // 2 :], concurrency=23)
-        finally:
-            if not SKIP_COLLECTION_DELETE:
-                v_store.delete_collection()
-            else:
-                v_store.clear()
-
-    async def test_astradb_vectorstore_custom_params_async(
-        self, astra_db_credentials: AstraDBCredentials
-    ) -> None:
-        """Custom batch size and concurrency params."""
-        emb = SomeEmbeddings(dimension=2)
-        v_store = AstraDBVectorStore(
-            embedding=emb,
-            collection_name="lc_test_c_async",
-            batch_size=17,
-            bulk_insert_batch_concurrency=13,
-            bulk_insert_overwrite_concurrency=7,
-            bulk_delete_concurrency=19,
-            token=astra_db_credentials["token"],
-            api_endpoint=astra_db_credentials["api_endpoint"],
-            namespace=astra_db_credentials["namespace"],
-            environment=astra_db_credentials["environment"],
-        )
-        try:
-            # add_texts
-            n = 50
-            texts = [str(i + 1 / 7.0) for i in range(n)]
-            ids = ["doc_%i" % i for i in range(n)]
-            await v_store.aadd_texts(texts=texts, ids=ids)
-            await v_store.aadd_texts(
-                texts=texts,
-                ids=ids,
-                batch_size=19,
-                batch_concurrency=7,
-                overwrite_concurrency=13,
-            )
-            await v_store.adelete(ids[: n // 2])
-            await v_store.adelete(ids[n // 2 :], concurrency=23)
-        finally:
-            if not SKIP_COLLECTION_DELETE:
-                await v_store.adelete_collection()
-            else:
-                await v_store.aclear()
-
-    def test_astradb_vectorstore_metrics(
-        self, astra_db_credentials: AstraDBCredentials
-    ) -> None:
-        """Different choices of similarity metric.
-        Both stores (with "cosine" and "euclidea" metrics) contain these two:
-            - a vector slightly rotated w.r.t query vector
-            - a vector which is a long multiple of query vector
-        so, which one is "the closest one" depends on the metric.
-        """
-        emb = ParserEmbeddings(dimension=2)
-        isq2 = 0.5**0.5
-        isa = 0.7
-        isb = (1.0 - isa * isa) ** 0.5
-        texts = [
-            json.dumps([isa, isb]),
-            json.dumps([10 * isq2, 10 * isq2]),
-        ]
-        ids = [
-            "rotated",
-            "scaled",
-        ]
-        query_text = json.dumps([isq2, isq2])
-
-        # prepare empty collections
-        AstraDBVectorStore(
-            embedding=emb,
-            collection_name=COLLECTION_NAME_DIM2,
-            token=astra_db_credentials["token"],
-            api_endpoint=astra_db_credentials["api_endpoint"],
-            namespace=astra_db_credentials["namespace"],
-            environment=astra_db_credentials["environment"],
-        ).clear()
-        AstraDBVectorStore(
-            embedding=emb,
-            collection_name=COLLECTION_NAME_DIM2_EUCLIDEAN,
-            metric="euclidean",
-            token=astra_db_credentials["token"],
-            api_endpoint=astra_db_credentials["api_endpoint"],
-            namespace=astra_db_credentials["namespace"],
-            environment=astra_db_credentials["environment"],
-        ).clear()
-
-        # creation, population, query - cosine
-        vstore_cos = AstraDBVectorStore(
-            embedding=emb,
-            collection_name=COLLECTION_NAME_DIM2,
-            metric="cosine",
-            token=astra_db_credentials["token"],
-            api_endpoint=astra_db_credentials["api_endpoint"],
-            namespace=astra_db_credentials["namespace"],
-            environment=astra_db_credentials["environment"],
-        )
-        try:
-            vstore_cos.add_texts(
-                texts=texts,
-                ids=ids,
-            )
-            _, _, id_from_cos = vstore_cos.similarity_search_with_score_id(
-                query_text,
-                k=1,
-            )[0]
-            assert id_from_cos == "scaled"
-        finally:
-            if not SKIP_COLLECTION_DELETE:
-                vstore_cos.delete_collection()
-            else:
-                vstore_cos.clear()
-        # creation, population, query - euclidean
-
-        vstore_euc = AstraDBVectorStore(
-            embedding=emb,
-            collection_name=COLLECTION_NAME_DIM2_EUCLIDEAN,
-            metric="euclidean",
-            token=astra_db_credentials["token"],
-            api_endpoint=astra_db_credentials["api_endpoint"],
-            namespace=astra_db_credentials["namespace"],
-            environment=astra_db_credentials["environment"],
-        )
-        try:
-            vstore_euc.add_texts(
-                texts=texts,
-                ids=ids,
-            )
-            _, _, id_from_euc = vstore_euc.similarity_search_with_score_id(
-                query_text,
-                k=1,
-            )[0]
-            assert id_from_euc == "rotated"
-        finally:
-            if not SKIP_COLLECTION_DELETE:
-                vstore_euc.delete_collection()
-            else:
-                vstore_euc.clear()
-
-    def test_astradb_vectorstore_indexing_sync(
-        self,
-        astra_db_credentials: dict[str, str | None],
-        database: Database,
-    ) -> None:
-        """Test that the right errors/warnings are issued depending
-        on the compatibility of on-DB indexing settings and the requested ones.
-
-        We do NOT check for substrings in the warning messages: that would
-        be too brittle a test.
- """ - embe = SomeEmbeddings(dimension=2) - - # creation of three collections to test warnings against - database.create_collection("lc_legacy_coll", dimension=2, metric=None) - AstraDBVectorStore( - collection_name="lc_default_idx", - embedding=embe, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - AstraDBVectorStore( - collection_name="lc_custom_idx", - embedding=embe, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - metadata_indexing_exclude={"long_summary", "the_divine_comedy"}, - ) - - # these invocations should just work without warnings - with warnings.catch_warnings(): - warnings.simplefilter("error") - AstraDBVectorStore( - collection_name="lc_default_idx", - embedding=embe, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - AstraDBVectorStore( - collection_name="lc_custom_idx", - embedding=embe, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - metadata_indexing_exclude={"long_summary", "the_divine_comedy"}, - ) - - # some are to throw an error: - with pytest.raises(ValueError, match=INCOMPATIBLE_INDEXING_MSG): - AstraDBVectorStore( - collection_name="lc_default_idx", - embedding=embe, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - metadata_indexing_exclude={"long_summary", "the_divine_comedy"}, - ) - - with pytest.raises(ValueError, match=INCOMPATIBLE_INDEXING_MSG): - AstraDBVectorStore( - collection_name="lc_custom_idx", - embedding=embe, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - metadata_indexing_exclude={"changed_fields"}, - ) - - with pytest.raises(ValueError, match=INCOMPATIBLE_INDEXING_MSG): - AstraDBVectorStore( - collection_name="lc_custom_idx", - embedding=embe, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - - with pytest.raises( - ValueError, - match="Astra DB collection 'lc_legacy_coll' is detected as having " - "indexing turned on for all fields", - ): - AstraDBVectorStore( - collection_name="lc_legacy_coll", - embedding=embe, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - metadata_indexing_exclude={"long_summary", "the_divine_comedy"}, - ) - - # one case should result in just a warning: - with pytest.warns(UserWarning) as rec_warnings: - AstraDBVectorStore( - collection_name="lc_legacy_coll", - embedding=embe, - token=astra_db_credentials["token"], - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) - f_rec_warnings = [ 
-            wrn for wrn in rec_warnings if issubclass(wrn.category, UserWarning)
-        ]
-        assert len(f_rec_warnings) == 1
-
-        # cleanup
-        database.drop_collection("lc_legacy_coll")
-        database.drop_collection("lc_default_idx")
-        database.drop_collection("lc_custom_idx")
-
-    async def test_astradb_vectorstore_indexing_async(
-        self,
-        astra_db_credentials: dict[str, str | None],
-        database: Database,
-    ) -> None:
-        """Async version of the same test on warnings/errors related
-        to incompatible indexing choices.
-        """
-        embe = SomeEmbeddings(dimension=2)
-
-        # creation of three collections to test warnings against
-        database.create_collection("lc_legacy_coll", dimension=2, metric=None)
-        await AstraDBVectorStore(
-            collection_name="lc_default_idx",
-            embedding=embe,
-            token=astra_db_credentials["token"],
-            api_endpoint=astra_db_credentials["api_endpoint"],
-            namespace=astra_db_credentials["namespace"],
-            environment=astra_db_credentials["environment"],
-            setup_mode=SetupMode.ASYNC,
-        ).asimilarity_search("boo")
-        await AstraDBVectorStore(
-            collection_name="lc_custom_idx",
-            embedding=embe,
-            token=astra_db_credentials["token"],
-            api_endpoint=astra_db_credentials["api_endpoint"],
-            namespace=astra_db_credentials["namespace"],
-            environment=astra_db_credentials["environment"],
-            metadata_indexing_exclude={"long_summary", "the_divine_comedy"},
-            setup_mode=SetupMode.ASYNC,
-        ).asimilarity_search("boo")
-
-        # these invocations should just work without warnings
-        with warnings.catch_warnings():
-            warnings.simplefilter("error")
-            def_store = AstraDBVectorStore(
-                collection_name="lc_default_idx",
-                embedding=embe,
-                token=astra_db_credentials["token"],
-                api_endpoint=astra_db_credentials["api_endpoint"],
-                namespace=astra_db_credentials["namespace"],
-                environment=astra_db_credentials["environment"],
-                setup_mode=SetupMode.ASYNC,
-            )
-            await def_store.aadd_texts(["All good."])
-            cus_store = AstraDBVectorStore(
-                collection_name="lc_custom_idx",
-                embedding=embe,
-                token=astra_db_credentials["token"],
-                api_endpoint=astra_db_credentials["api_endpoint"],
-                namespace=astra_db_credentials["namespace"],
-                environment=astra_db_credentials["environment"],
-                metadata_indexing_exclude={"long_summary", "the_divine_comedy"},
-                setup_mode=SetupMode.ASYNC,
-            )
-            await cus_store.aadd_texts(["All good."])
-
-        # some are to throw an error:
-        def_store = AstraDBVectorStore(
-            collection_name="lc_default_idx",
-            embedding=embe,
-            token=astra_db_credentials["token"],
-            api_endpoint=astra_db_credentials["api_endpoint"],
-            namespace=astra_db_credentials["namespace"],
-            environment=astra_db_credentials["environment"],
-            metadata_indexing_exclude={"long_summary", "the_divine_comedy"},
-            setup_mode=SetupMode.ASYNC,
-        )
-        with pytest.raises(ValueError, match=INCOMPATIBLE_INDEXING_MSG):
-            await def_store.aadd_texts(["Not working."])
-
-        cus_store = AstraDBVectorStore(
-            collection_name="lc_custom_idx",
-            embedding=embe,
-            token=astra_db_credentials["token"],
-            api_endpoint=astra_db_credentials["api_endpoint"],
-            namespace=astra_db_credentials["namespace"],
-            environment=astra_db_credentials["environment"],
-            metadata_indexing_exclude={"changed_fields"},
-            setup_mode=SetupMode.ASYNC,
-        )
-        with pytest.raises(ValueError, match=INCOMPATIBLE_INDEXING_MSG):
-            await cus_store.aadd_texts(["Not working."])
-
-        cus_store = AstraDBVectorStore(
-            collection_name="lc_custom_idx",
-            embedding=embe,
-            token=astra_db_credentials["token"],
-            api_endpoint=astra_db_credentials["api_endpoint"],
-            namespace=astra_db_credentials["namespace"],
-            environment=astra_db_credentials["environment"],
-            setup_mode=SetupMode.ASYNC,
-        )
-        with pytest.raises(ValueError, match=INCOMPATIBLE_INDEXING_MSG):
-            await cus_store.aadd_texts(["Not working."])
-
-        leg_store = AstraDBVectorStore(
-            collection_name="lc_legacy_coll",
-            embedding=embe,
-            token=astra_db_credentials["token"],
-            api_endpoint=astra_db_credentials["api_endpoint"],
-            namespace=astra_db_credentials["namespace"],
-            environment=astra_db_credentials["environment"],
-            metadata_indexing_exclude={"long_summary", "the_divine_comedy"},
-            setup_mode=SetupMode.ASYNC,
-        )
-        with pytest.raises(
-            ValueError,
-            match="Astra DB collection 'lc_legacy_coll' is detected as having "
-            "indexing turned on for all fields",
-        ):
-            await leg_store.aadd_texts(["Not working."])
-
-        # one case should result in just a warning:
-        with pytest.warns(UserWarning) as rec_warnings:
-            leg_store = AstraDBVectorStore(
-                collection_name="lc_legacy_coll",
-                embedding=embe,
-                token=astra_db_credentials["token"],
-                api_endpoint=astra_db_credentials["api_endpoint"],
-                namespace=astra_db_credentials["namespace"],
-                environment=astra_db_credentials["environment"],
-                setup_mode=SetupMode.ASYNC,
-            )
-            await leg_store.aadd_texts(["Triggering warning."])
-        # cleaning out 'spurious' "unclosed socket/transport..." warnings
-        f_rec_warnings = [
-            wrn for wrn in rec_warnings if issubclass(wrn.category, UserWarning)
-        ]
-        assert len(f_rec_warnings) == 1
-
-        await database.to_async().drop_collection("lc_legacy_coll")
-        await database.to_async().drop_collection("lc_default_idx")
-        await database.to_async().drop_collection("lc_custom_idx")
-
-    @pytest.mark.skipif(
-        os.environ.get("ASTRA_DB_ENVIRONMENT", "prod").upper() != "PROD",
-        reason="Can run on Astra DB prod only",
-    )
-    def test_astradb_vectorstore_coreclients_init_sync(
-        self,
-        astra_db_credentials: dict[str, str | None],
-        core_astra_db: AstraDB,
-    ) -> None:
-        """A deprecation warning from passing a (core) AstraDB, but it works."""
-        collection_name = "lc_test_vstore_coreclsync"
-        emb = SomeEmbeddings(dimension=2)
-
-        try:
-            v_store_init_ok = AstraDBVectorStore(
-                embedding=emb,
-                collection_name=collection_name,
-                token=astra_db_credentials["token"],
-                api_endpoint=astra_db_credentials["api_endpoint"],
-                namespace=astra_db_credentials["namespace"],
-            )
-            v_store_init_ok.add_texts(["One text"])
-
-            with pytest.warns(DeprecationWarning) as rec_warnings:
-                v_store_init_core = AstraDBVectorStore(
-                    embedding=emb,
-                    collection_name=collection_name,
-                    astra_db_client=core_astra_db,
-                )
-
-            results = v_store_init_core.similarity_search("another", k=1)
-            # cleaning out 'spurious' "unclosed socket/transport..." warnings
-            f_rec_warnings = [
-                wrn
-                for wrn in rec_warnings
-                if issubclass(wrn.category, DeprecationWarning)
-            ]
-            assert len(f_rec_warnings) == 1
-            assert len(results) == 1
-            assert results[0].page_content == "One text"
-        finally:
-            if not SKIP_COLLECTION_DELETE:
-                v_store_init_ok.delete_collection()
-            else:
-                v_store_init_ok.clear()
-
-    @pytest.mark.skipif(
-        os.environ.get("ASTRA_DB_ENVIRONMENT", "prod").upper() != "PROD",
-        reason="Can run on Astra DB prod only",
-    )
-    async def test_astradb_vectorstore_coreclients_init_async(
-        self,
-        astra_db_credentials: dict[str, str | None],
-        core_astra_db: AstraDB,
-    ) -> None:
-        """A deprecation warning from passing a (core) AstraDB, but it works."""
-        collection_name = "lc_test_vstore_coreclasync"
-        emb = SomeEmbeddings(dimension=2)
-
-        try:
-            v_store_init_ok = AstraDBVectorStore(
-                embedding=emb,
-                collection_name=collection_name,
-                token=astra_db_credentials["token"],
-                api_endpoint=astra_db_credentials["api_endpoint"],
-                namespace=astra_db_credentials["namespace"],
-                setup_mode=SetupMode.ASYNC,
-            )
-            await v_store_init_ok.aadd_texts(["One text"])
-
-            with pytest.warns(DeprecationWarning) as rec_warnings:
-                v_store_init_core = AstraDBVectorStore(
-                    embedding=emb,
-                    collection_name=collection_name,
-                    astra_db_client=core_astra_db,
-                    setup_mode=SetupMode.ASYNC,
-                )
-
-            results = await v_store_init_core.asimilarity_search("another", k=1)
-            f_rec_warnings = [
-                wrn
-                for wrn in rec_warnings
-                if issubclass(wrn.category, DeprecationWarning)
-            ]
-            assert len(f_rec_warnings) == 1
-            assert len(results) == 1
-            assert results[0].page_content == "One text"
-        finally:
-            if not SKIP_COLLECTION_DELETE:
-                await v_store_init_ok.adelete_collection()
-            else:
-                await v_store_init_ok.aclear()
diff --git a/libs/astradb/tests/unit_tests/test_vectorstores.py b/libs/astradb/tests/unit_tests/test_vectorstores.py
index f0a2fab..51a5ac2 100644
--- a/libs/astradb/tests/unit_tests/test_vectorstores.py
+++ b/libs/astradb/tests/unit_tests/test_vectorstores.py
@@ -7,7 +7,7 @@
     DEFAULT_INDEXING_OPTIONS,
     AstraDBVectorStore,
 )
-from tests.conftest import SomeEmbeddings
+from tests.conftest import ParserEmbeddings

 FAKE_TOKEN = "t"  # noqa: S105

@@ -25,7 +25,7 @@ def test_initialization(self) -> None:
             api_endpoint=a_e_string,
             namespace="n",
         )
-        embedding = SomeEmbeddings(dimension=2)
+        embedding = ParserEmbeddings(dimension=2)
         with pytest.warns(DeprecationWarning):
             AstraDBVectorStore(
                 embedding=embedding,