Skip to content

Commit

Permalink
Fix some blockings
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Nov 27, 2024
1 parent cc8a201 commit 40abd35
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 36 deletions.
6 changes: 3 additions & 3 deletions libs/astradb/langchain_astradb/utils/astradb.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,9 +425,9 @@ async def _asetup_db(
except DataAPIException as data_api_exception:
# possibly the collection is preexisting and may have legacy,
# or custom, indexing settings: verify
collection_descriptors = [
coll_desc async for coll_desc in self.async_database.list_collections()
]
collection_descriptors = list(
await asyncio.to_thread(self.database.list_collections)
)
try:
if not self._validate_indexing_policy(
collection_descriptors=collection_descriptors,
Expand Down
51 changes: 37 additions & 14 deletions libs/astradb/tests/integration_tests/test_graphvectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Any

import pytest
Expand Down Expand Up @@ -400,7 +401,9 @@ async def test_gvs_similarity_search_async(
request: pytest.FixtureRequest,
) -> None:
"""Simple (non-graph) similarity search on a graph vector store."""
g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name)
g_store: AstraDBGraphVectorStore = await asyncio.to_thread(
request.getfixturevalue, store_name
)
query = "universe" if is_vectorize else "[2, 10]"
embedding = [2.0, 10.0]

Expand All @@ -421,8 +424,10 @@ async def test_gvs_similarity_search_async(
assert ss_by_v_labels == ["AR", "A0"]

if is_autodetected:
assert_all_flat_docs(
g_store.vector_store.astra_env.collection, is_vectorize=is_vectorize
await asyncio.to_thread(
assert_all_flat_docs,
g_store.vector_store.astra_env.collection,
is_vectorize=is_vectorize,
)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -488,7 +493,9 @@ async def test_gvs_traversal_search_async(
request: pytest.FixtureRequest,
) -> None:
"""Graph traversal search on a graph vector store."""
g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name)
g_store: AstraDBGraphVectorStore = await asyncio.to_thread(
request.getfixturevalue, store_name
)
query = "universe" if is_vectorize else "[2, 10]"

# this is a set, as some of the internals of trav.search are set-driven
Expand All @@ -499,8 +506,10 @@ async def test_gvs_traversal_search_async(
}
assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"}
if is_autodetected:
assert_all_flat_docs(
g_store.vector_store.astra_env.collection, is_vectorize=is_vectorize
await asyncio.to_thread(
assert_all_flat_docs,
g_store.vector_store.astra_env.collection,
is_vectorize=is_vectorize,
)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -572,7 +581,9 @@ async def test_gvs_mmr_traversal_search_async(
request: pytest.FixtureRequest,
) -> None:
"""MMR Graph traversal search on a graph vector store."""
g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name)
g_store: AstraDBGraphVectorStore = await asyncio.to_thread(
request.getfixturevalue, store_name
)
query = "universe" if is_vectorize else "[2, 10]"

mt_labels = [
Expand All @@ -589,8 +600,10 @@ async def test_gvs_mmr_traversal_search_async(

assert mt_labels == ["AR", "BR"]
if is_autodetected:
assert_all_flat_docs(
g_store.vector_store.astra_env.collection, is_vectorize=is_vectorize
await asyncio.to_thread(
assert_all_flat_docs,
g_store.vector_store.astra_env.collection,
is_vectorize=is_vectorize,
)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -652,7 +665,9 @@ async def test_gvs_metadata_search_async(
request: pytest.FixtureRequest,
) -> None:
"""Metadata search on a graph vector store."""
g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name)
g_store: AstraDBGraphVectorStore = await asyncio.to_thread(
request.getfixturevalue, store_name
)
mt_response = await g_store.ametadata_search(
filter={"label": "T0"},
n=2,
Expand Down Expand Up @@ -726,7 +741,9 @@ async def test_gvs_get_by_document_id_async(
request: pytest.FixtureRequest,
) -> None:
"""Get by document_id on a graph vector store."""
g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name)
g_store: AstraDBGraphVectorStore = await asyncio.to_thread(
request.getfixturevalue, store_name
)
doc = await g_store.aget_by_document_id(document_id="FL")
assert doc is not None
assert doc.metadata["label"] == "FL"
Expand Down Expand Up @@ -816,7 +833,9 @@ async def test_gvs_from_texts_async(
collection_fixture_name: str,
request: pytest.FixtureRequest,
) -> None:
collection_name: str = request.getfixturevalue(collection_fixture_name)
collection_name: str = await asyncio.to_thread(
request.getfixturevalue, collection_fixture_name
)
init_kwargs: dict[str, Any]
if is_vectorize:
init_kwargs = {
Expand Down Expand Up @@ -915,7 +934,9 @@ async def test_gvs_from_documents_containing_ids_async(
collection_fixture_name: str,
request: pytest.FixtureRequest,
) -> None:
collection_name: str = request.getfixturevalue(collection_fixture_name)
collection_name: str = await asyncio.to_thread(
request.getfixturevalue, collection_fixture_name
)
init_kwargs: dict[str, Any]
if is_vectorize:
init_kwargs = {
Expand Down Expand Up @@ -1005,7 +1026,9 @@ async def test_gvs_add_nodes_async(
store_name: str,
request: pytest.FixtureRequest,
) -> None:
g_store: AstraDBGraphVectorStore = request.getfixturevalue(store_name)
g_store: AstraDBGraphVectorStore = await asyncio.to_thread(
request.getfixturevalue, store_name
)
links0 = [
Link(kind="kA", direction="out", tag="tA"),
Link(kind="kB", direction="bidir", tag="tB"),
Expand Down
6 changes: 3 additions & 3 deletions libs/astradb/tests/integration_tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def astra_db_empty_store_async(
astra_db_credentials: AstraDBCredentials,
collection_idxid: Collection,
) -> AstraDBStore:
collection_idxid.delete_many({})
await collection_idxid.to_async().delete_many({})
return AstraDBStore(
collection_name=collection_idxid.name,
token=StaticTokenProvider(astra_db_credentials["token"]),
Expand Down Expand Up @@ -364,7 +364,7 @@ async def test_store_indexing_on_legacy_async(
database: Database,
) -> None:
"""Test of instantiation against a legacy collection, async version."""
database.create_collection(
await database.to_async().create_collection(
EPHEMERAL_LEGACY_IDX_NAME,
indexing=None,
check_exists=False,
Expand Down Expand Up @@ -413,7 +413,7 @@ async def test_store_indexing_on_custom_async(
database: Database,
) -> None:
"""Test of instantiation against a legacy collection, async version."""
database.create_collection(
await database.to_async().create_collection(
EPHEMERAL_CUSTOM_IDX_NAME,
indexing={"deny": ["useless", "forgettable"]},
check_exists=False,
Expand Down
43 changes: 30 additions & 13 deletions libs/astradb/tests/integration_tests/test_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import asyncio
import json
import math
import os
Expand Down Expand Up @@ -371,7 +372,9 @@ async def test_astradb_vectorstore_from_texts_async(
request: pytest.FixtureRequest,
) -> None:
"""from_texts methods and the associated warnings, async version."""
collection: Collection = request.getfixturevalue(collection_fixture_name)
collection: Collection = await asyncio.to_thread(
request.getfixturevalue, collection_fixture_name
)
init_kwargs: dict[str, Any]
if is_vectorize:
init_kwargs = {
Expand Down Expand Up @@ -481,7 +484,9 @@ async def test_astradb_vectorstore_from_documents_async(
from_documents, esp. the various handling of ID-in-doc vs external.
Async version.
"""
collection: Collection = request.getfixturevalue(collection_fixture_name)
collection: Collection = await asyncio.to_thread(
request.getfixturevalue, collection_fixture_name
)
pc1, pc2 = page_contents
init_kwargs: dict[str, Any]
if is_vectorize:
Expand Down Expand Up @@ -509,7 +514,7 @@ async def test_astradb_vectorstore_from_documents_async(
assert len(hits) == 1
assert hits[0].page_content == pc2
assert hits[0].metadata == {"m": 3}
v_store.clear()
await v_store.aclear()

# IDs passed separately.
with pytest.warns(DeprecationWarning) as rec_warnings:
Expand All @@ -536,7 +541,7 @@ async def test_astradb_vectorstore_from_documents_async(
assert hits[0].page_content == pc2
assert hits[0].metadata == {"m": 3}
assert hits[0].id == "idx3"
v_store_2.clear()
await v_store_2.aclear()

# IDs in documents.
v_store_3 = await AstraDBVectorStore.afrom_documents(
Expand All @@ -557,7 +562,7 @@ async def test_astradb_vectorstore_from_documents_async(
assert hits[0].page_content == pc2
assert hits[0].metadata == {"m": 3}
assert hits[0].id == "idx3"
v_store_3.clear()
await v_store_3.aclear()

# IDs both in documents and aside.
with pytest.warns(DeprecationWarning) as rec_warnings:
Expand Down Expand Up @@ -692,7 +697,9 @@ async def test_astradb_vectorstore_crud_async(
request: pytest.FixtureRequest,
) -> None:
"""Add/delete/update behaviour, async version."""
vstore: AstraDBVectorStore = request.getfixturevalue(vector_store)
vstore: AstraDBVectorStore = await asyncio.to_thread(
request.getfixturevalue, vector_store
)

res0 = await vstore.asimilarity_search("[-1,-1]", k=2)
assert res0 == []
Expand Down Expand Up @@ -1242,13 +1249,15 @@ async def test_astradb_vectorstore_metadata_search_async(
metadata_documents: list[Document],
) -> None:
"""Metadata Search"""
vstore: AstraDBVectorStore = request.getfixturevalue(vector_store)
vstore: AstraDBVectorStore = await asyncio.to_thread(
request.getfixturevalue, vector_store
)
await vstore.aadd_documents(metadata_documents)
# no filters
res0 = await vstore.ametadata_search(filter={}, n=10)
assert {doc.metadata["letter"] for doc in res0} == set("qwreio")
# single filter
res1 = vstore.metadata_search(
res1 = await vstore.ametadata_search(
n=10,
filter={"group": "vowel"},
)
Expand Down Expand Up @@ -1313,7 +1322,9 @@ async def test_astradb_vectorstore_get_by_document_id_async(
metadata_documents: list[Document],
) -> None:
"""Get by document_id"""
vstore: AstraDBVectorStore = request.getfixturevalue(vector_store)
vstore: AstraDBVectorStore = await asyncio.to_thread(
request.getfixturevalue, vector_store
)
await vstore.aadd_documents(metadata_documents)
# invalid id
invalid = await vstore.aget_by_document_id(document_id="z")
Expand Down Expand Up @@ -1398,7 +1409,9 @@ async def test_astradb_vectorstore_similarity_scale_async(
request: pytest.FixtureRequest,
) -> None:
"""Scale of the similarity scores, async version."""
vstore: AstraDBVectorStore = request.getfixturevalue(vector_store)
vstore: AstraDBVectorStore = await asyncio.to_thread(
request.getfixturevalue, vector_store
)
await vstore.aadd_texts(
texts=texts,
ids=["near", "far"],
Expand Down Expand Up @@ -1432,7 +1445,9 @@ async def test_astradb_vectorstore_asimilarity_search_with_embedding(
"""asimilarity_search_with_embedding is used as the building
block for other components (like AstraDBGraphVectorStore).
"""
vstore: AstraDBVectorStore = request.getfixturevalue(vector_store)
vstore: AstraDBVectorStore = await asyncio.to_thread(
request.getfixturevalue, vector_store
)
await vstore.aadd_documents(metadata_documents)

query_embedding, results = await vstore.asimilarity_search_with_embedding(
Expand Down Expand Up @@ -1465,7 +1480,9 @@ async def test_astradb_vectorstore_asimilarity_search_with_embedding_by_vector(
"""asimilarity_search_with_embedding_by_vector is used as the building
block for other components (like AstraDBGraphVectorStore).
"""
vstore: AstraDBVectorStore = request.getfixturevalue(vector_store)
vstore: AstraDBVectorStore = await asyncio.to_thread(
request.getfixturevalue, vector_store
)
await vstore.aadd_documents(metadata_documents)

vector_dimensions = 1536 if is_vectorize else 2
Expand Down Expand Up @@ -1748,7 +1765,7 @@ async def test_astradb_vectorstore_coreclients_init_async(
Expect a deprecation warning from passing a (core) AstraDB class,
but it must work. Async version.
"""
vector_store_d2.add_texts(["[1,2]"])
await vector_store_d2.aadd_texts(["[1,2]"])

with pytest.warns(DeprecationWarning) as rec_warnings:
v_store_init_core = AstraDBVectorStore(
Expand Down
13 changes: 10 additions & 3 deletions libs/astradb/tests/integration_tests/test_vectorstore_ddl_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import asyncio
import warnings
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -96,10 +97,13 @@ async def test_astradb_vectorstore_create_delete_async(
namespace=astra_db_credentials["namespace"],
environment=astra_db_credentials["environment"],
metric="cosine",
setup_mode=SetupMode.ASYNC,
)
await v_store.aadd_texts(["[1,2]"])
await v_store.adelete_collection()
assert ephemeral_collection_cleaner_d2 not in database.list_collection_names()
assert ephemeral_collection_cleaner_d2 not in await asyncio.to_thread(
database.list_collection_names
)

async def test_astradb_vectorstore_create_delete_vectorize_async(
self,
Expand All @@ -118,10 +122,13 @@ async def test_astradb_vectorstore_create_delete_vectorize_async(
metric="cosine",
collection_vector_service_options=OPENAI_VECTORIZE_OPTIONS_HEADER,
collection_embedding_api_key=openai_api_key,
setup_mode=SetupMode.ASYNC,
)
await v_store.aadd_texts(["[1,2]"])
await v_store.adelete_collection()
assert ephemeral_collection_cleaner_vz not in database.list_collection_names()
assert ephemeral_collection_cleaner_vz not in await asyncio.to_thread(
database.list_collection_names
)

def test_astradb_vectorstore_pre_delete_collection_sync(
self,
Expand Down Expand Up @@ -340,7 +347,7 @@ async def test_astradb_vectorstore_indexing_legacy_async(
Test of the vector store behaviour for various indexing settings,
with an existing 'legacy' collection (i.e. unspecified indexing policy).
"""
database.create_collection(
await database.to_async().create_collection(
EPHEMERAL_LEGACY_IDX_NAME_D2,
dimension=2,
check_exists=False,
Expand Down

0 comments on commit 40abd35

Please sign in to comment.