From bebe401b1aff59d7bce22c54bc25bfdc6811a515 Mon Sep 17 00:00:00 2001
From: Christophe Bornet
Date: Wed, 21 Feb 2024 01:54:35 +0100
Subject: [PATCH] astradb[patch]: Add AstraDBStore to langchain-astradb package (#17789)

Co-authored-by: Erick Friis
---
 .../langchain_community/storage/astradb.py    |  11 +
 .../astradb/langchain_astradb/__init__.py     |   3 +
 .../astradb/langchain_astradb/storage.py      | 217 ++++++++++++++++++
 .../langchain_astradb/utils/astradb.py        | 142 ++++++++++++
 .../tests/integration_tests/test_storage.py   | 176 ++++++++++++++
 .../astradb/tests/unit_tests/test_imports.py  |   2 +
 6 files changed, 551 insertions(+)
 create mode 100644 libs/partners/astradb/langchain_astradb/storage.py
 create mode 100644 libs/partners/astradb/langchain_astradb/utils/astradb.py
 create mode 100644 libs/partners/astradb/tests/integration_tests/test_storage.py

diff --git a/libs/community/langchain_community/storage/astradb.py b/libs/community/langchain_community/storage/astradb.py
index 959ef374124c7..a486f8851b544 100644
--- a/libs/community/langchain_community/storage/astradb.py
+++ b/libs/community/langchain_community/storage/astradb.py
@@ -15,6 +15,7 @@
     TypeVar,
 )
 
+from langchain_core._api.deprecation import deprecated
 from langchain_core.stores import BaseStore, ByteStore
 
 from langchain_community.utilities.astradb import (
@@ -124,6 +125,11 @@ async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[st
                 yield key
 
 
+@deprecated(
+    since="0.0.22",
+    removal="0.2.0",
+    alternative_import="langchain_astradb.AstraDBStore",
+)
 class AstraDBStore(AstraDBBaseStore[Any]):
     """BaseStore implementation using DataStax AstraDB as the underlying store.
 
@@ -143,6 +149,11 @@ def encode_value(self, value: Any) -> Any:
         return value
 
 
+@deprecated(
+    since="0.0.22",
+    removal="0.2.0",
+    alternative_import="langchain_astradb.AstraDBByteStore",
+)
 class AstraDBByteStore(AstraDBBaseStore[bytes], ByteStore):
     """ByteStore implementation using DataStax AstraDB as the underlying store.
 
diff --git a/libs/partners/astradb/langchain_astradb/__init__.py b/libs/partners/astradb/langchain_astradb/__init__.py
index fc86dd73bcf46..85ce86d32abb8 100644
--- a/libs/partners/astradb/langchain_astradb/__init__.py
+++ b/libs/partners/astradb/langchain_astradb/__init__.py
@@ -1,5 +1,8 @@
+from langchain_astradb.storage import AstraDBByteStore, AstraDBStore
 from langchain_astradb.vectorstores import AstraDBVectorStore
 
 __all__ = [
+    "AstraDBByteStore",
+    "AstraDBStore",
     "AstraDBVectorStore",
 ]
diff --git a/libs/partners/astradb/langchain_astradb/storage.py b/libs/partners/astradb/langchain_astradb/storage.py
new file mode 100644
index 0000000000000..119e82c2cf496
--- /dev/null
+++ b/libs/partners/astradb/langchain_astradb/storage.py
@@ -0,0 +1,217 @@
+from __future__ import annotations
+
+import base64
+from abc import ABC, abstractmethod
+from typing import (
+    Any,
+    AsyncIterator,
+    Generic,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+)
+
+from astrapy.db import AstraDB, AsyncAstraDB
+from langchain_core.stores import BaseStore, ByteStore
+
+from langchain_astradb.utils.astradb import (
+    SetupMode,
+    _AstraDBCollectionEnvironment,
+)
+
+V = TypeVar("V")
+
+
+class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
+    """Base class for the DataStax AstraDB data store."""
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        self.astra_env = _AstraDBCollectionEnvironment(*args, **kwargs)
+        self.collection = self.astra_env.collection
+        self.async_collection = self.astra_env.async_collection
+
+    @abstractmethod
+    def decode_value(self, value: Any) -> Optional[V]:
+        """Decode a raw Astra DB value; receives None when the key was not found."""
+
+    @abstractmethod
+    def encode_value(self, value: Optional[V]) -> Any:
+        """Encode a value into the form stored under the document's "value" key."""
+
+    def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
+        self.astra_env.ensure_db_setup()
+        docs_dict = {}  # maps "_id" -> raw stored value of the documents found
+        for doc in self.collection.paginated_find(filter={"_id": {"$in": list(keys)}}):
+            docs_dict[doc["_id"]] = doc.get("value")
+        return [self.decode_value(docs_dict.get(key)) for key in keys]
+
+    async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
+        await self.astra_env.aensure_db_setup()
+        docs_dict = {}  # maps "_id" -> raw stored value of the documents found
+        async for doc in self.async_collection.paginated_find(
+            filter={"_id": {"$in": list(keys)}}
+        ):
+            docs_dict[doc["_id"]] = doc.get("value")
+        return [self.decode_value(docs_dict.get(key)) for key in keys]
+
+    def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
+        self.astra_env.ensure_db_setup()
+        for k, v in key_value_pairs:
+            self.collection.upsert({"_id": k, "value": self.encode_value(v)})
+
+    async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
+        await self.astra_env.aensure_db_setup()
+        for k, v in key_value_pairs:
+            await self.async_collection.upsert(
+                {"_id": k, "value": self.encode_value(v)}
+            )
+
+    def mdelete(self, keys: Sequence[str]) -> None:
+        self.astra_env.ensure_db_setup()
+        self.collection.delete_many(filter={"_id": {"$in": list(keys)}})
+
+    async def amdelete(self, keys: Sequence[str]) -> None:
+        await self.astra_env.aensure_db_setup()
+        await self.async_collection.delete_many(filter={"_id": {"$in": list(keys)}})
+
+    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
+        self.astra_env.ensure_db_setup()
+        docs = self.collection.paginated_find()  # prefix is filtered client-side
+        for doc in docs:
+            key = doc["_id"]
+            if not prefix or key.startswith(prefix):  # no prefix: yield all keys
+                yield key
+
+    async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
+        await self.astra_env.aensure_db_setup()
+        async for doc in self.async_collection.paginated_find():
+            key = doc["_id"]
+            if not prefix or key.startswith(prefix):  # no prefix: yield all keys
+                yield key
+
+
+class AstraDBStore(AstraDBBaseStore[Any]):
+    def __init__(
+        self,
+        collection_name: str,
+        *,
+        token: Optional[str] = None,
+        api_endpoint: Optional[str] = None,
+        astra_db_client: Optional[AstraDB] = None,
+        namespace: Optional[str] = None,
+        async_astra_db_client: Optional[AsyncAstraDB] = None,
+        pre_delete_collection: bool = False,
+        setup_mode: SetupMode = SetupMode.SYNC,
+    ) -> None:
+        """BaseStore implementation using DataStax AstraDB as the underlying store.
+
+        The value type can be any type serializable by json.dumps.
+        Can be used to store embeddings with the CacheBackedEmbeddings.
+
+        Documents in the AstraDB collection will have the format
+
+        .. code-block:: json
+            {
+              "_id": "<key>",
+              "value": <value>
+            }
+
+        Args:
+            collection_name: name of the Astra DB collection to create/use.
+            token: API token for Astra DB usage.
+            api_endpoint: full URL to the API endpoint,
+                such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
+            astra_db_client: *alternative to token+api_endpoint*,
+                you can pass an already-created 'astrapy.db.AstraDB' instance.
+            async_astra_db_client: *alternative to token+api_endpoint*,
+                you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
+            namespace: namespace (aka keyspace) where the
+                collection is created. Defaults to the database's "default namespace".
+            setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
+                OFF).
+            pre_delete_collection: whether to delete the collection
+                before creating it. If False and the collection already exists,
+                the collection will be used as is.
+        """
+        super().__init__(
+            collection_name=collection_name,
+            token=token,
+            api_endpoint=api_endpoint,
+            astra_db_client=astra_db_client,
+            async_astra_db_client=async_astra_db_client,
+            namespace=namespace,
+            setup_mode=setup_mode,
+            pre_delete_collection=pre_delete_collection,
+        )
+
+    def decode_value(self, value: Any) -> Any:
+        return value
+
+    def encode_value(self, value: Any) -> Any:
+        return value
+
+
+class AstraDBByteStore(AstraDBBaseStore[bytes], ByteStore):
+    def __init__(
+        self,
+        collection_name: str,
+        *,
+        token: Optional[str] = None,
+        api_endpoint: Optional[str] = None,
+        astra_db_client: Optional[AstraDB] = None,
+        namespace: Optional[str] = None,
+        async_astra_db_client: Optional[AsyncAstraDB] = None,
+        pre_delete_collection: bool = False,
+        setup_mode: SetupMode = SetupMode.SYNC,
+    ) -> None:
+        """ByteStore implementation using DataStax AstraDB as the underlying store.
+
+        The bytes values are converted to base64 encoded strings
+        Documents in the AstraDB collection will have the format
+
+        .. code-block:: json
+            {
+              "_id": "<key>",
+              "value": "<base64-encoded value>"
+            }
+
+        Args:
+            collection_name: name of the Astra DB collection to create/use.
+            token: API token for Astra DB usage.
+            api_endpoint: full URL to the API endpoint,
+                such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
+            astra_db_client: *alternative to token+api_endpoint*,
+                you can pass an already-created 'astrapy.db.AstraDB' instance.
+            async_astra_db_client: *alternative to token+api_endpoint*,
+                you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
+            namespace: namespace (aka keyspace) where the
+                collection is created. Defaults to the database's "default namespace".
+            setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
+                OFF).
+            pre_delete_collection: whether to delete the collection
+                before creating it. If False and the collection already exists,
+                the collection will be used as is.
+        """
+        super().__init__(
+            collection_name=collection_name,
+            token=token,
+            api_endpoint=api_endpoint,
+            astra_db_client=astra_db_client,
+            async_astra_db_client=async_astra_db_client,
+            namespace=namespace,
+            setup_mode=setup_mode,
+            pre_delete_collection=pre_delete_collection,
+        )
+
+    def decode_value(self, value: Any) -> Optional[bytes]:
+        if value is None:
+            return None
+        return base64.b64decode(value)
+
+    def encode_value(self, value: Optional[bytes]) -> Any:
+        if value is None:
+            return None
+        return base64.b64encode(value).decode("ascii")
diff --git a/libs/partners/astradb/langchain_astradb/utils/astradb.py b/libs/partners/astradb/langchain_astradb/utils/astradb.py
new file mode 100644
index 0000000000000..dc2c24c3ff89c
--- /dev/null
+++ b/libs/partners/astradb/langchain_astradb/utils/astradb.py
@@ -0,0 +1,142 @@
+from __future__ import annotations
+
+import asyncio
+import inspect
+from asyncio import InvalidStateError, Task
+from enum import Enum
+from typing import Awaitable, Optional, Union
+
+from astrapy.db import AstraDB, AsyncAstraDB
+
+
+class SetupMode(Enum):
+    SYNC = 1  # create the collection synchronously in the constructor
+    ASYNC = 2  # create it in an asyncio task started by the constructor
+    OFF = 3  # do not create the collection
+
+
+class _AstraDBEnvironment:
+    def __init__(
+        self,
+        token: Optional[str] = None,
+        api_endpoint: Optional[str] = None,
+        astra_db_client: Optional[AstraDB] = None,
+        async_astra_db_client: Optional[AsyncAstraDB] = None,
+        namespace: Optional[str] = None,
+    ) -> None:
+        self.token = token
+        self.api_endpoint = api_endpoint
+        astra_db = astra_db_client
+        async_astra_db = async_astra_db_client
+        self.namespace = namespace
+
+        # Conflicting-arg checks:
+        if astra_db_client is not None or async_astra_db_client is not None:
+            if token is not None or api_endpoint is not None:
+                raise ValueError(
+                    "You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
+                    "AstraDBEnvironment if passing 'token' and 'api_endpoint'."
+                )
+
+        if token and api_endpoint:
+            astra_db = AstraDB(
+                token=token,
+                api_endpoint=api_endpoint,
+                namespace=self.namespace,
+            )
+            async_astra_db = AsyncAstraDB(
+                token=token,
+                api_endpoint=api_endpoint,
+                namespace=self.namespace,
+            )
+
+        if astra_db:
+            self.astra_db = astra_db
+            if async_astra_db:
+                self.async_astra_db = async_astra_db
+            else:
+                self.async_astra_db = self.astra_db.to_async()
+        elif async_astra_db:
+            self.async_astra_db = async_astra_db
+            self.astra_db = self.async_astra_db.to_sync()
+        else:
+            raise ValueError(
+                "Must provide 'astra_db_client' or 'async_astra_db_client' or "
+                "'token' and 'api_endpoint'"
+            )
+
+
+class _AstraDBCollectionEnvironment(_AstraDBEnvironment):
+    def __init__(
+        self,
+        collection_name: str,
+        token: Optional[str] = None,
+        api_endpoint: Optional[str] = None,
+        astra_db_client: Optional[AstraDB] = None,
+        async_astra_db_client: Optional[AsyncAstraDB] = None,
+        namespace: Optional[str] = None,
+        setup_mode: SetupMode = SetupMode.SYNC,
+        pre_delete_collection: bool = False,
+        embedding_dimension: Union[int, Awaitable[int], None] = None,
+        metric: Optional[str] = None,
+    ) -> None:
+        from astrapy.db import AstraDBCollection, AsyncAstraDBCollection
+
+        super().__init__(
+            token, api_endpoint, astra_db_client, async_astra_db_client, namespace
+        )
+        self.collection_name = collection_name
+        self.collection = AstraDBCollection(
+            collection_name=collection_name,
+            astra_db=self.astra_db,
+        )
+
+        self.async_collection = AsyncAstraDBCollection(
+            collection_name=collection_name,
+            astra_db=self.async_astra_db,
+        )
+
+        self.async_setup_db_task: Optional[Task] = None  # only set in ASYNC mode
+        if setup_mode == SetupMode.ASYNC:
+            async_astra_db = self.async_astra_db
+
+            async def _setup_db() -> None:
+                if pre_delete_collection:
+                    await async_astra_db.delete_collection(collection_name)
+                if inspect.isawaitable(embedding_dimension):
+                    dimension = await embedding_dimension
+                else:
+                    dimension = embedding_dimension
+                await async_astra_db.create_collection(
+                    collection_name, dimension=dimension, metric=metric
+                )
+
+            self.async_setup_db_task = asyncio.create_task(_setup_db())
+        elif setup_mode == SetupMode.SYNC:
+            if pre_delete_collection:
+                self.astra_db.delete_collection(collection_name)
+            if inspect.isawaitable(embedding_dimension):
+                raise ValueError(
+                    "Cannot use an awaitable embedding_dimension with setup_mode "
+                    "set to SetupMode.SYNC"
+                )
+            self.astra_db.create_collection(
+                collection_name,
+                dimension=embedding_dimension,  # type: ignore[arg-type]
+                metric=metric,
+            )
+
+    def ensure_db_setup(self) -> None:
+        if self.async_setup_db_task:
+            try:
+                self.async_setup_db_task.result()
+            except InvalidStateError as e:
+                raise ValueError(
+                    "Asynchronous setup of the DB not finished. "
+                    "NB: AstraDB components sync methods shouldn't be called from the "
+                    "event loop. Consider using their async equivalents."
+                ) from e
+
+    async def aensure_db_setup(self) -> None:
+        if self.async_setup_db_task:
+            await self.async_setup_db_task
diff --git a/libs/partners/astradb/tests/integration_tests/test_storage.py b/libs/partners/astradb/tests/integration_tests/test_storage.py
new file mode 100644
index 0000000000000..86e919e4df38f
--- /dev/null
+++ b/libs/partners/astradb/tests/integration_tests/test_storage.py
@@ -0,0 +1,176 @@
+"""Implement integration tests for AstraDB storage."""
+from __future__ import annotations
+
+import os
+
+import pytest
+from astrapy.db import AstraDB, AsyncAstraDB
+
+from langchain_astradb.storage import AstraDBByteStore, AstraDBStore
+from langchain_astradb.utils.astradb import SetupMode
+
+
+def _has_env_vars() -> bool:
+    return all(
+        [
+            "ASTRA_DB_APPLICATION_TOKEN" in os.environ,
+            "ASTRA_DB_API_ENDPOINT" in os.environ,
+        ]
+    )
+
+
+@pytest.fixture
+def astra_db() -> AstraDB:
+    from astrapy.db import AstraDB
+
+    return AstraDB(
+        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
+        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
+        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
+    )
+
+
+@pytest.fixture
+def async_astra_db() -> AsyncAstraDB:
+    from astrapy.db import AsyncAstraDB
+
+    return AsyncAstraDB(
+        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
+        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
+        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
+    )
+
+
+def init_store(astra_db: AstraDB, collection_name: str) -> AstraDBStore:
+    store = AstraDBStore(collection_name=collection_name, astra_db_client=astra_db)
+    store.mset([("key1", [0.1, 0.2]), ("key2", "value2")])
+    return store
+
+
+def init_bytestore(astra_db: AstraDB, collection_name: str) -> AstraDBByteStore:
+    store = AstraDBByteStore(collection_name=collection_name, astra_db_client=astra_db)
+    store.mset([("key1", b"value1"), ("key2", b"value2")])
+    return store
+
+
+async def init_async_store(
+    async_astra_db: AsyncAstraDB, collection_name: str
+) -> AstraDBStore:
+    store = AstraDBStore(
+        collection_name=collection_name,
+        async_astra_db_client=async_astra_db,
+        setup_mode=SetupMode.ASYNC,
+    )
+    await store.amset([("key1", [0.1, 0.2]), ("key2", "value2")])
+    return store
+
+
+@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
+class TestAstraDBStore:
+    def test_mget(self, astra_db: AstraDB) -> None:
+        """Test AstraDBStore mget method."""
+        collection_name = "lc_test_store_mget"
+        try:
+            store = init_store(astra_db, collection_name)
+            assert store.mget(["key1", "key2"]) == [[0.1, 0.2], "value2"]
+        finally:
+            astra_db.delete_collection(collection_name)
+
+    async def test_amget(self, async_astra_db: AsyncAstraDB) -> None:
+        """Test AstraDBStore amget method."""
+        collection_name = "lc_test_store_mget"
+        try:
+            store = await init_async_store(async_astra_db, collection_name)
+            assert await store.amget(["key1", "key2"]) == [[0.1, 0.2], "value2"]
+        finally:
+            await async_astra_db.delete_collection(collection_name)
+
+    def test_mset(self, astra_db: AstraDB) -> None:
+        """Test that multiple keys can be set with AstraDBStore."""
+        collection_name = "lc_test_store_mset"
+        try:
+            store = init_store(astra_db, collection_name)
+            result = store.collection.find_one({"_id": "key1"})
+            assert result["data"]["document"]["value"] == [0.1, 0.2]
+            result = store.collection.find_one({"_id": "key2"})
+            assert result["data"]["document"]["value"] == "value2"
+        finally:
+            astra_db.delete_collection(collection_name)
+
+    async def test_amset(self, async_astra_db: AsyncAstraDB) -> None:
+        """Test that multiple keys can be set with AstraDBStore."""
+        collection_name = "lc_test_store_mset"
+        try:
+            store = await init_async_store(async_astra_db, collection_name)
+            result = await store.async_collection.find_one({"_id": "key1"})
+            assert result["data"]["document"]["value"] == [0.1, 0.2]
+            result = await store.async_collection.find_one({"_id": "key2"})
+            assert result["data"]["document"]["value"] == "value2"
+        finally:
+            await async_astra_db.delete_collection(collection_name)
+
+    def test_mdelete(self, astra_db: AstraDB) -> None:
+        """Test that deletion works as expected."""
+        collection_name = "lc_test_store_mdelete"
+        try:
+            store = init_store(astra_db, collection_name)
+            store.mdelete(["key1", "key2"])
+            result = store.mget(["key1", "key2"])
+            assert result == [None, None]
+        finally:
+            astra_db.delete_collection(collection_name)
+
+    async def test_amdelete(self, async_astra_db: AsyncAstraDB) -> None:
+        """Test that deletion works as expected."""
+        collection_name = "lc_test_store_mdelete"
+        try:
+            store = await init_async_store(async_astra_db, collection_name)
+            await store.amdelete(["key1", "key2"])
+            result = await store.amget(["key1", "key2"])
+            assert result == [None, None]
+        finally:
+            await async_astra_db.delete_collection(collection_name)
+
+    def test_yield_keys(self, astra_db: AstraDB) -> None:
+        collection_name = "lc_test_store_yield_keys"
+        try:
+            store = init_store(astra_db, collection_name)
+            assert set(store.yield_keys()) == {"key1", "key2"}
+            assert set(store.yield_keys(prefix="key")) == {"key1", "key2"}
+            assert set(store.yield_keys(prefix="lang")) == set()
+        finally:
+            astra_db.delete_collection(collection_name)
+
+    async def test_ayield_keys(self, async_astra_db: AsyncAstraDB) -> None:
+        collection_name = "lc_test_store_yield_keys"
+        try:
+            store = await init_async_store(async_astra_db, collection_name)
+            assert {key async for key in store.ayield_keys()} == {"key1", "key2"}
+            assert {key async for key in store.ayield_keys(prefix="key")} == {
+                "key1",
+                "key2",
+            }
+            assert {key async for key in store.ayield_keys(prefix="lang")} == set()
+        finally:
+            await async_astra_db.delete_collection(collection_name)
+
+    def test_bytestore_mget(self, astra_db: AstraDB) -> None:
+        """Test AstraDBByteStore mget method."""
+        collection_name = "lc_test_bytestore_mget"
+        try:
+            store = init_bytestore(astra_db, collection_name)
+            assert store.mget(["key1", "key2"]) == [b"value1", b"value2"]
+        finally:
+            astra_db.delete_collection(collection_name)
+
+    def test_bytestore_mset(self, astra_db: AstraDB) -> None:
+        """Test that multiple keys can be set with AstraDBByteStore."""
+        collection_name = "lc_test_bytestore_mset"
+        try:
+            store = init_bytestore(astra_db, collection_name)
+            result = store.collection.find_one({"_id": "key1"})
+            assert result["data"]["document"]["value"] == "dmFsdWUx"
+            result = store.collection.find_one({"_id": "key2"})
+            assert result["data"]["document"]["value"] == "dmFsdWUy"
+        finally:
+            astra_db.delete_collection(collection_name)
diff --git a/libs/partners/astradb/tests/unit_tests/test_imports.py b/libs/partners/astradb/tests/unit_tests/test_imports.py
index 2240748c70c6e..bd34cbffd2f4c 100644
--- a/libs/partners/astradb/tests/unit_tests/test_imports.py
+++ b/libs/partners/astradb/tests/unit_tests/test_imports.py
@@ -1,6 +1,8 @@
 from langchain_astradb import __all__
 
 EXPECTED_ALL = [
+    "AstraDBByteStore",
+    "AstraDBStore",
     "AstraDBVectorStore",
 ]
 