Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

partner: Astra DB, add indexing support for Vector Store class #17767

15 changes: 14 additions & 1 deletion libs/partners/astradb/langchain_astradb/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,20 @@ class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
"""Base class for the DataStax AstraDB data store."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
self.astra_env = _AstraDBCollectionEnvironment(*args, **kwargs)
if "requested_indexing_policy" in kwargs:
raise ValueError(
"Do not pass 'requested_indexing_policy' to AstraDBBaseStore init"
)
if "default_indexing_policy" in kwargs:
raise ValueError(
"Do not pass 'default_indexing_policy' to AstraDBBaseStore init"
)
kwargs["requested_indexing_policy"] = {"allow": ["_id"]}
kwargs["default_indexing_policy"] = {"allow": ["_id"]}
self.astra_env = _AstraDBCollectionEnvironment(
*args,
**kwargs,
)
self.collection = self.astra_env.collection
self.async_collection = self.astra_env.async_collection

Expand Down
174 changes: 166 additions & 8 deletions libs/partners/astradb/langchain_astradb/utils/astradb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

import asyncio
import inspect
import json
import warnings
from asyncio import InvalidStateError, Task
from enum import Enum
from typing import Awaitable, Optional, Union
from typing import Any, Awaitable, Dict, List, Optional, Union

import langchain_core
from astrapy.api import APIRequestError
from astrapy.db import AstraDB, AsyncAstraDB


Expand Down Expand Up @@ -89,6 +92,8 @@ def __init__(
pre_delete_collection: bool = False,
embedding_dimension: Union[int, Awaitable[int], None] = None,
metric: Optional[str] = None,
requested_indexing_policy: Optional[Dict[str, Any]] = None,
default_indexing_policy: Optional[Dict[str, Any]] = None,
) -> None:
from astrapy.db import AstraDBCollection, AsyncAstraDBCollection

Expand All @@ -106,6 +111,11 @@ def __init__(
astra_db=self.async_astra_db,
)

if requested_indexing_policy is not None:
_options = {"indexing": requested_indexing_policy}
else:
_options = None

self.async_setup_db_task: Optional[Task] = None
if setup_mode == SetupMode.ASYNC:
async_astra_db = self.async_astra_db
Expand All @@ -117,9 +127,31 @@ async def _setup_db() -> None:
dimension = await embedding_dimension
else:
dimension = embedding_dimension
await async_astra_db.create_collection(
collection_name, dimension=dimension, metric=metric
)

try:
hemidactylus marked this conversation as resolved.
Show resolved Hide resolved
await async_astra_db.create_collection(
collection_name,
dimension=dimension,
metric=metric,
options=_options,
)
except (APIRequestError, ValueError):
# possibly the collection is preexisting and may have legacy,
# or custom, indexing settings: verify
get_coll_response = await async_astra_db.get_collections(
options={"explain": True}
)
collections = (get_coll_response["status"] or {}).get(
"collections"
) or []
if not self._validate_indexing_policy(
detected_collections=collections,
collection_name=self.collection_name,
requested_indexing_policy=requested_indexing_policy,
default_indexing_policy=default_indexing_policy,
):
# other reasons for the exception
raise

self.async_setup_db_task = asyncio.create_task(_setup_db())
elif setup_mode == SetupMode.SYNC:
Expand All @@ -130,12 +162,138 @@ async def _setup_db() -> None:
"Cannot use an awaitable embedding_dimension with async_setup "
"set to False"
)
self.astra_db.create_collection(
collection_name,
dimension=embedding_dimension, # type: ignore[arg-type]
metric=metric,
else:
try:
self.astra_db.create_collection(
collection_name,
dimension=embedding_dimension, # type: ignore[arg-type]
metric=metric,
options=_options,
)
except (APIRequestError, ValueError):
# possibly the collection is preexisting and may have legacy,
# or custom, indexing settings: verify
get_coll_response = self.astra_db.get_collections( # type: ignore[union-attr]
options={"explain": True}
)
collections = (get_coll_response["status"] or {}).get(
"collections"
) or []
if not self._validate_indexing_policy(
detected_collections=collections,
collection_name=self.collection_name,
requested_indexing_policy=requested_indexing_policy,
default_indexing_policy=default_indexing_policy,
):
# other reasons for the exception
raise

@staticmethod
def _validate_indexing_policy(
detected_collections: List[Dict[str, Any]],
collection_name: str,
requested_indexing_policy: Optional[Dict[str, Any]],
default_indexing_policy: Optional[Dict[str, Any]],
) -> bool:
"""
This is a validation helper, to be called when the collection-creation
call has failed.

Args:
detected_collection (List[Dict[str, Any]]):
the list of collection items returned by astrapy
collection_name (str): the name of the collection whose attempted
creation failed
requested_indexing_policy: the 'indexing' part of the collection
options, e.g. `{"deny": ["field1", "field2"]}`.
Leave to its default of None if no options required.
default_indexing_policy: an optional 'default value' for the
above, used to issue just a gentle warning in the special
case that no policy is detected on a preexisting collection
on DB and the default is requested. This is to enable
a warning-only transition to new code using indexing without
disrupting usage of a legacy collection, i.e. one created
before adopting the usage of indexing policies altogether.
You cannot pass this one without requested_indexing_policy.

This function may raise an error (indexing mismatches), issue a warning
(about legacy collections), or do nothing.
In any case, when the function returns, it returns either
- True: the exception was handled here as part of the indexing
management
- False: the exception is unrelated to indexing and the caller
has to reraise it.
"""
if requested_indexing_policy is None and default_indexing_policy is not None:
raise ValueError(
"Cannot specify a default indexing policy "
"when no indexing policy is requested for this collection "
"(requested_indexing_policy is None, "
"default_indexing_policy is not None)."
)

preexisting = [
collection
for collection in detected_collections
if collection["name"] == collection_name
]
if preexisting:
pre_collection = preexisting[0]
# if it has no "indexing", it is a legacy collection
pre_col_options = pre_collection.get("options") or {}
if "indexing" not in pre_col_options:
# legacy collection on DB
if requested_indexing_policy == default_indexing_policy:
warnings.warn(
(
f"Astra DB collection '{collection_name}' is "
"detected as legacy and has indexing turned "
"on for all fields. This implies stricter "
"limitations on the amount of text each string in a "
"document can store. Consider reindexing anew on a "
"fresh collection to be able to store longer texts."
),
UserWarning,
stacklevel=2,
)
else:
raise ValueError(
f"Astra DB collection '{collection_name}' is "
"detected as legacy and has indexing turned "
"on for all fields. This is incompatible with "
"the requested indexing policy for this object. "
"Consider reindexing anew on a fresh "
"collection with the requested indexing "
"policy, or alternatively leave the indexing "
"settings for this object to their defaults "
"to keep using this collection."
)
elif pre_col_options["indexing"] != requested_indexing_policy:
# collection on DB has indexing settings, but different
options_json = json.dumps(pre_col_options["indexing"])
if pre_col_options["indexing"] == default_indexing_policy:
default_desc = " (default setting)"
else:
default_desc = ""
raise ValueError(
f"Astra DB collection '{collection_name}' is "
"detected as having the following indexing policy: "
f"{options_json}{default_desc}. This is incompatible "
"with the requested indexing policy for this object. "
"Consider reindexing anew on a fresh "
"collection with the requested indexing "
"policy, or alternatively align the requested "
"indexing settings to the collection to keep using it."
)
else:
# the discrepancies have to do with options other than indexing
return False
# the original exception, related to indexing, was handled here
return True
else:
# foreign-origin for the original exception
return False

def ensure_db_setup(self) -> None:
if self.async_setup_db_task:
try:
Expand Down
Loading
Loading