diff --git a/README.md b/README.md index a77dd7e7..8d04e2d4 100644 --- a/README.md +++ b/README.md @@ -305,6 +305,10 @@ from astrapy.info import ( AdminDatabaseInfo, DatabaseInfo, CollectionInfo, + CollectionDefaultIDOptions, + CollectionVectorOptions, + CollectionOptions, + CollectionDescriptor, ) ``` diff --git a/astrapy/collection.py b/astrapy/collection.py index 5f4c68ae..daa689a2 100644 --- a/astrapy/collection.py +++ b/astrapy/collection.py @@ -57,7 +57,7 @@ BulkWriteResult, ) from astrapy.cursors import AsyncCursor, Cursor -from astrapy.info import CollectionInfo +from astrapy.info import CollectionInfo, CollectionOptions if TYPE_CHECKING: @@ -347,7 +347,7 @@ def set_caller( caller_version=caller_version, ) - def options(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]: + def options(self, *, max_time_ms: Optional[int] = None) -> CollectionOptions: """ Get the collection options, i.e. its configuration as read from the database. @@ -359,22 +359,21 @@ def options(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]: max_time_ms: a timeout, in milliseconds, for the underlying HTTP request. Returns: - a dictionary expressing the collection as a set of key-value pairs - matching the arguments of a `create_collection` call. + a CollectionOptions instance describing the collection. (See also the database `list_collections` method.) Example: >>> my_coll.options() - {'name': 'my_v_collection', 'dimension': 3, 'metric': 'cosine'} + CollectionOptions(vector=CollectionVectorOptions(dimension=3, metric='cosine')) """ - self_dicts = [ - coll_dict - for coll_dict in self.database.list_collections(max_time_ms=max_time_ms) - if coll_dict["name"] == self.name + self_descriptors = [ + coll_desc + for coll_desc in self.database.list_collections(max_time_ms=max_time_ms) + if coll_desc.name == self.name ] - if self_dicts: - return self_dicts[0] # type: ignore[no-any-return] + if self_descriptors: + return self_descriptors[0].options # type: ignore[no-any-return] else: raise CollectionNotFoundException( text=f"Collection {self.namespace}.{self.name} not found.", @@ -2411,7 +2410,7 @@ def set_caller( caller_version=caller_version, ) - async def options(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]: + async def options(self, *, max_time_ms: Optional[int] = None) -> CollectionOptions: """ Get the collection options, i.e. its configuration as read from the database. @@ -2423,24 +2422,23 @@ async def options(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]: max_time_ms: a timeout, in milliseconds, for the underlying HTTP request. Returns: - a dictionary expressing the collection as a set of key-value pairs - matching the arguments of a `create_collection` call. + a CollectionOptions instance describing the collection. (See also the database `list_collections` method.) Example: >>> asyncio.run(my_async_coll.options()) - {'name': 'my_v_collection', 'dimension': 3, 'metric': 'cosine'} + CollectionOptions(vector=CollectionVectorOptions(dimension=3, metric='cosine')) """ - self_dicts = [ - coll_dict - async for coll_dict in self.database.list_collections( + self_descriptors = [ + coll_desc + async for coll_desc in self.database.list_collections( max_time_ms=max_time_ms ) - if coll_dict["name"] == self.name + if coll_desc.name == self.name ] - if self_dicts: - return self_dicts[0] # type: ignore[no-any-return] + if self_descriptors: + return self_descriptors[0].options # type: ignore[no-any-return] else: raise CollectionNotFoundException( text=f"Collection {self.namespace}.{self.name} not found.", diff --git a/astrapy/database.py b/astrapy/database.py index d7d87d14..c079c6f8 100644 --- a/astrapy/database.py +++ b/astrapy/database.py @@ -27,7 +27,7 @@ base_timeout_info, ) from astrapy.cursors import AsyncCommandCursor, CommandCursor -from astrapy.info import DatabaseInfo +from astrapy.info import DatabaseInfo, CollectionDescriptor from astrapy.admin import parse_api_endpoint, fetch_database_info if TYPE_CHECKING: @@ -70,41 +70,6 @@ def _validate_create_collection_options( ) -def _recast_api_collection_dict(api_coll_dict: Dict[str, Any]) -> Dict[str, Any]: - _name = api_coll_dict["name"] - _options = api_coll_dict.get("options") or {} - _v_options0 = _options.get("vector") or {} - _indexing = _options.get("indexing") or {} - _v_dimension = _v_options0.get("dimension") - _v_metric = _v_options0.get("metric") - _default_id = _options.get("defaultId") - # defaultId may potentially in the future have other subfields than 'type': - if _default_id: - _default_id_type = _default_id.get("type") - _rest_default_id = {k: v for k, v in _default_id.items() if k != "type"} - else: - _default_id_type = None - _rest_default_id = None - _additional_options = { - **{ - k: v - for k, v in _options.items() - if k not in {"vector", "indexing", "defaultId"} - }, - **({"defaultId": _rest_default_id} if _rest_default_id else {}), - } - recast_dict0 = { - "name": _name, - "dimension": _v_dimension, - "metric": _v_metric, - "indexing": _indexing, - "default_id_type": _default_id_type, - "additional_options": _additional_options, - } - recast_dict = {k: v for k, v in recast_dict0.items() if v} - return recast_dict - - class Database: """ A Data API database. This is the entry-point object for doing database-level @@ -592,7 +557,7 @@ def list_collections( *, namespace: Optional[str] = None, max_time_ms: Optional[int] = None, - ) -> CommandCursor[Dict[str, Any]]: + ) -> CommandCursor[CollectionDescriptor]: """ List all collections in a given namespace for this database. @@ -602,20 +567,19 @@ def list_collections( max_time_ms: a timeout, in milliseconds, for the underlying HTTP request. Returns: - a `CommandCursor` to iterate over dictionaries, each - expressing a collection as a set of key-value pairs - matching the arguments of a `create_collection` call. + a `CommandCursor` to iterate over CollectionDescriptor instances, + each corresponding to a collection. Example: >>> ccur = my_db.list_collections() >>> ccur >>> list(ccur) - [{'name': 'my_v_col'}] + [CollectionDescriptor(name='my_v_col', options=CollectionOptions())] >>> for coll_dict in my_db.list_collections(): ... print(coll_dict) ... - {'name': 'my_v_col'} + CollectionDescriptor(name='my_v_col', options=CollectionOptions()) """ if namespace: @@ -631,11 +595,11 @@ def list_collections( raw_response=gc_response, ) else: - # we know this is a list of dicts which need a little adjusting + # we know this is a list of dicts, to marshal into "descriptors" return CommandCursor( address=self._astra_db.base_url, items=[ - _recast_api_collection_dict(col_dict) + CollectionDescriptor.from_dict(col_dict) for col_dict in gc_response["status"]["collections"] ], ) @@ -1286,7 +1250,7 @@ def list_collections( *, namespace: Optional[str] = None, max_time_ms: Optional[int] = None, - ) -> AsyncCommandCursor[Dict[str, Any]]: + ) -> AsyncCommandCursor[CollectionDescriptor]: """ List all collections in a given namespace for this database. @@ -1296,9 +1260,8 @@ def list_collections( max_time_ms: a timeout, in milliseconds, for the underlying HTTP request. Returns: - an `AsyncCommandCursor` to iterate over dictionaries, each - expressing a collection as a set of key-value pairs - matching the arguments of a `create_collection` call. + an `AsyncCommandCursor` to iterate over CollectionDescriptor instances, + each corresponding to a collection. Example: >>> async def a_list_colls(adb: AsyncDatabase) -> None: @@ -1310,8 +1273,8 @@ def list_collections( ... >>> asyncio.run(a_list_colls(my_async_db)) * a_ccur: - * list: [{'name': 'my_v_col'}] - * coll: {'name': 'my_v_col'} + * list: [CollectionDescriptor(name='my_v_col', options=CollectionOptions())] + * coll: CollectionDescriptor(name='my_v_col', options=CollectionOptions()) """ _client: AsyncAstraDB @@ -1329,11 +1292,11 @@ def list_collections( raw_response=gc_response, ) else: - # we know this is a list of dicts which need a little adjusting + # we know this is a list of dicts, to marshal into "descriptors" return AsyncCommandCursor( address=self._astra_db.base_url, items=[ - _recast_api_collection_dict(col_dict) + CollectionDescriptor.from_dict(col_dict) for col_dict in gc_response["status"]["collections"] ], ) diff --git a/astrapy/info.py b/astrapy/info.py index 64ac56e2..8f46998e 100644 --- a/astrapy/info.py +++ b/astrapy/info.py @@ -56,26 +56,6 @@ class DatabaseInfo: raw_info: Optional[Dict[str, Any]] -@dataclass -class CollectionInfo: - """ - Represents the identifying information for a collection, - including the information about the database the collection belongs to. - - Attributes: - database_info: a DatabaseInfo instance for the underlying database. - namespace: the namespace where the collection is located. - name: collection name. Unique within a namespace. - full_name: identifier for the collection within the database, - in the form "namespace.collection_name". - """ - - database_info: DatabaseInfo - namespace: str - name: str - full_name: str - - @dataclass class AdminDatabaseInfo: """ @@ -126,3 +106,264 @@ class AdminDatabaseInfo: storage: Dict[str, Any] termination_time: str raw_info: Optional[Dict[str, Any]] + + +@dataclass +class CollectionInfo: + """ + Represents the identifying information for a collection, + including the information about the database the collection belongs to. + + Attributes: + database_info: a DatabaseInfo instance for the underlying database. + namespace: the namespace where the collection is located. + name: collection name. Unique within a namespace. + full_name: identifier for the collection within the database, + in the form "namespace.collection_name". + """ + + database_info: DatabaseInfo + namespace: str + name: str + full_name: str + + +@dataclass +class CollectionDefaultIDOptions: + """ + The "defaultId" component of the collection options. + See the Data API specifications for allowed values. + + Attributes: + default_id_type: string such as `objectId`, `uuid6` and so on. + """ + + default_id_type: str + + def as_dict(self) -> Dict[str, Any]: + """Recast this object into a dictionary.""" + + return {"type": self.default_id_type} + + @staticmethod + def from_dict( + raw_dict: Optional[Dict[str, Any]] + ) -> Optional[CollectionDefaultIDOptions]: + """ + Create an instance of CollectionDefaultIDOptions from a dictionary + such as one from the Data API. + """ + + if raw_dict is not None: + return CollectionDefaultIDOptions(default_id_type=raw_dict["type"]) + else: + return None + + +@dataclass +class CollectionVectorOptions: + """ + The "vector" component of the collection options. + See the Data API specifications for allowed values. + + Attributes: + dimension: an optional positive integer, the dimensionality of the vector space. + metric: an optional metric among `VectorMetric.DOT_PRODUCT`, + `VectorMetric.EUCLIDEAN` and `VectorMetric.COSINE`. + """ + + dimension: Optional[int] + metric: Optional[str] + + def as_dict(self) -> Dict[str, Any]: + """Recast this object into a dictionary.""" + + return { + k: v + for k, v in { + "dimension": self.dimension, + "metric": self.metric, + }.items() + if v is not None + } + + @staticmethod + def from_dict( + raw_dict: Optional[Dict[str, Any]] + ) -> Optional[CollectionVectorOptions]: + """ + Create an instance of CollectionVectorOptions from a dictionary + such as one from the Data API. + """ + + if raw_dict is not None: + return CollectionVectorOptions( + dimension=raw_dict.get("dimension"), + metric=raw_dict.get("metric"), + ) + else: + return None + + +@dataclass +class CollectionOptions: + """ + A structure expressing the options of a collection. + See the Data API specifications for detailed specification and allowed values. + + Attributes: + vector: an optional CollectionVectorOptions object. + indexing: an optional dictionary with the "indexing" collection properties. + default_id: an optional CollectionDefaultIDOptions object. + raw_options: the raw response from the Data API for the collection configuration. + """ + + vector: Optional[CollectionVectorOptions] + indexing: Optional[Dict[str, Any]] + default_id: Optional[CollectionDefaultIDOptions] + raw_options: Optional[Dict[str, Any]] + + def __repr__(self) -> str: + not_null_pieces = [ + pc + for pc in [ + None if self.vector is None else f"vector={self.vector.__repr__()}", + ( + None + if self.indexing is None + else f"indexing={self.indexing.__repr__()}" + ), + ( + None + if self.default_id is None + else f"default_id={self.default_id.__repr__()}" + ), + ] + if pc is not None + ] + return f"{self.__class__.__name__}({', '.join(not_null_pieces)})" + + def as_dict(self) -> Dict[str, Any]: + """Recast this object into a dictionary.""" + + return { + k: v + for k, v in { + "vector": None if self.vector is None else self.vector.as_dict(), + "indexing": self.indexing, + "defaultId": ( + None if self.default_id is None else self.default_id.as_dict() + ), + }.items() + if v is not None + } + + def flatten(self) -> Dict[str, Any]: + """ + Recast this object as a flat key-value pair suitable for + use as kwargs in a create_collection method call. + """ + + _dimension: Optional[int] + _metric: Optional[str] + _indexing: Optional[Dict[str, Any]] + _default_id_type: Optional[str] + if self.vector is not None: + _dimension = self.vector.dimension + _metric = self.vector.metric + else: + _dimension = None + _metric = None + _indexing = self.indexing + if self.default_id is not None: + _default_id_type = self.default_id.default_id_type + else: + _default_id_type = None + + return { + k: v + for k, v in { + "dimension": _dimension, + "metric": _metric, + "indexing": _indexing, + "default_id_type": _default_id_type, + }.items() + if v is not None + } + + @staticmethod + def from_dict(raw_dict: Dict[str, Any]) -> CollectionOptions: + """ + Create an instance of CollectionOptions from a dictionary + such as one from the Data API. + """ + + return CollectionOptions( + vector=CollectionVectorOptions.from_dict(raw_dict.get("vector")), + indexing=raw_dict.get("indexing"), + default_id=CollectionDefaultIDOptions.from_dict(raw_dict.get("defaultId")), + raw_options=raw_dict, + ) + + +@dataclass +class CollectionDescriptor: + """ + A structure expressing full description of a collection as the Data API + returns it, i.e. its name and its `options` sub-structure. + + Attributes: + name: the name of the collection. + options: a CollectionOptions instance. + raw_descriptor: the raw response from the Data API. + """ + + name: str + options: CollectionOptions + raw_descriptor: Optional[Dict[str, Any]] + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"name={self.name.__repr__()}, " + f"options={self.options.__repr__()})" + ) + + def as_dict(self) -> Dict[str, Any]: + """ + Recast this object into a dictionary. + Empty `options` will not be returned at all. + """ + + return { + k: v + for k, v in { + "name": self.name, + "options": self.options.as_dict(), + }.items() + if v + } + + def flatten(self) -> Dict[str, Any]: + """ + Recast this object as a flat key-value pair suitable for + use as kwargs in a create_collection method call. + """ + + return { + **(self.options.flatten()), + **{"name": self.name}, + } + + @staticmethod + def from_dict(raw_dict: Dict[str, Any]) -> CollectionDescriptor: + """ + Create an instance of CollectionDescriptor from a dictionary + such as one from the Data API. + """ + + return CollectionDescriptor( + name=raw_dict["name"], + options=CollectionOptions.from_dict(raw_dict.get("options") or {}), + raw_descriptor=raw_dict, + ) diff --git a/tests/idiomatic/integration/test_ddl_async.py b/tests/idiomatic/integration/test_ddl_async.py index fee25c6b..e619945d 100644 --- a/tests/idiomatic/integration/test_ddl_async.py +++ b/tests/idiomatic/integration/test_ddl_async.py @@ -21,7 +21,7 @@ ASTRA_DB_SECONDARY_KEYSPACE, TEST_COLLECTION_NAME, ) -from astrapy.info import DatabaseInfo +from astrapy.info import CollectionDescriptor, DatabaseInfo from astrapy.constants import DefaultIdType, VectorMetric from astrapy.ids import ObjectId, UUID from astrapy import AsyncCollection, AsyncDatabase @@ -47,18 +47,28 @@ async def test_collection_lifecycle_async( ) lc_response = [col async for col in async_database.list_collections()] # - expected_coll_dict = { - "name": TEST_LOCAL_COLLECTION_NAME, - "dimension": 123, - "metric": "euclidean", - "indexing": {"deny": ["a", "b", "c"]}, - } - expected_coll_dict_b = { - "name": TEST_LOCAL_COLLECTION_NAME_B, - "indexing": {"allow": ["z"]}, - } - assert expected_coll_dict in lc_response - assert expected_coll_dict_b in lc_response + expected_coll_descriptor = CollectionDescriptor.from_dict( + { + "name": TEST_LOCAL_COLLECTION_NAME, + "options": { + "vector": { + "dimension": 123, + "metric": "euclidean", + }, + "indexing": {"deny": ["a", "b", "c"]}, + }, + }, + ) + expected_coll_descriptor_b = CollectionDescriptor.from_dict( + { + "name": TEST_LOCAL_COLLECTION_NAME_B, + "options": { + "indexing": {"allow": ["z"]}, + }, + }, + ) + assert expected_coll_descriptor in lc_response + assert expected_coll_descriptor_b in lc_response # col2 = await async_database.get_collection(TEST_LOCAL_COLLECTION_NAME) assert col1 == col2 @@ -79,7 +89,7 @@ async def test_collection_default_id_type_async( ID_TEST_COLLECTION_NAME_ROOT + DefaultIdType.UUID, default_id_type=DefaultIdType.UUID, ) - assert (await acol.options())["default_id_type"] == DefaultIdType.UUID + assert (await acol.options()).default_id.default_id_type == DefaultIdType.UUID await acol.insert_one({"role": "probe"}) doc = await acol.find_one({}) assert isinstance(doc["_id"], UUID) @@ -90,7 +100,7 @@ async def test_collection_default_id_type_async( ID_TEST_COLLECTION_NAME_ROOT + DefaultIdType.UUIDV6, default_id_type=DefaultIdType.UUIDV6, ) - assert (await acol.options())["default_id_type"] == DefaultIdType.UUIDV6 + assert (await acol.options()).default_id.default_id_type == DefaultIdType.UUIDV6 await acol.insert_one({"role": "probe"}) doc = await acol.find_one({}) assert isinstance(doc["_id"], UUID) @@ -102,7 +112,7 @@ async def test_collection_default_id_type_async( ID_TEST_COLLECTION_NAME_ROOT + DefaultIdType.UUIDV7, default_id_type=DefaultIdType.UUIDV7, ) - assert (await acol.options())["default_id_type"] == DefaultIdType.UUIDV7 + assert (await acol.options()).default_id.default_id_type == DefaultIdType.UUIDV7 await acol.insert_one({"role": "probe"}) doc = await acol.find_one({}) assert isinstance(doc["_id"], UUID) @@ -114,7 +124,9 @@ async def test_collection_default_id_type_async( ID_TEST_COLLECTION_NAME_ROOT + DefaultIdType.DEFAULT, default_id_type=DefaultIdType.DEFAULT, ) - assert (await acol.options())["default_id_type"] == DefaultIdType.DEFAULT + assert ( + await acol.options() + ).default_id.default_id_type == DefaultIdType.DEFAULT await acol.drop() time.sleep(2) @@ -122,7 +134,9 @@ async def test_collection_default_id_type_async( ID_TEST_COLLECTION_NAME_ROOT + DefaultIdType.OBJECTID, default_id_type=DefaultIdType.OBJECTID, ) - assert (await acol.options())["default_id_type"] == DefaultIdType.OBJECTID + assert ( + await acol.options() + ).default_id.default_id_type == DefaultIdType.OBJECTID await acol.insert_one({"role": "probe"}) doc = await acol.find_one({}) assert isinstance(doc["_id"], ObjectId) @@ -176,7 +190,8 @@ async def test_collection_options_async( async_collection: AsyncCollection, ) -> None: options = await async_collection.options() - assert options["name"] == async_collection.name + assert options.vector is not None + assert options.vector.dimension == 2 @pytest.mark.skipif( ASTRA_DB_SECONDARY_KEYSPACE is None, reason="No secondary keyspace provided" diff --git a/tests/idiomatic/integration/test_ddl_sync.py b/tests/idiomatic/integration/test_ddl_sync.py index 6197fad9..854698c7 100644 --- a/tests/idiomatic/integration/test_ddl_sync.py +++ b/tests/idiomatic/integration/test_ddl_sync.py @@ -21,7 +21,7 @@ ASTRA_DB_SECONDARY_KEYSPACE, TEST_COLLECTION_NAME, ) -from astrapy.info import DatabaseInfo +from astrapy.info import CollectionDescriptor, DatabaseInfo from astrapy.constants import DefaultIdType, VectorMetric from astrapy.ids import ObjectId, UUID from astrapy import Collection, Database @@ -47,18 +47,28 @@ def test_collection_lifecycle_sync( ) lc_response = list(sync_database.list_collections()) # - expected_coll_dict = { - "name": TEST_LOCAL_COLLECTION_NAME, - "dimension": 123, - "metric": "euclidean", - "indexing": {"deny": ["a", "b", "c"]}, - } - expected_coll_dict_b = { - "name": TEST_LOCAL_COLLECTION_NAME_B, - "indexing": {"allow": ["z"]}, - } - assert expected_coll_dict in lc_response - assert expected_coll_dict_b in lc_response + expected_coll_descriptor = CollectionDescriptor.from_dict( + { + "name": TEST_LOCAL_COLLECTION_NAME, + "options": { + "vector": { + "dimension": 123, + "metric": "euclidean", + }, + "indexing": {"deny": ["a", "b", "c"]}, + }, + }, + ) + expected_coll_descriptor_b = CollectionDescriptor.from_dict( + { + "name": TEST_LOCAL_COLLECTION_NAME_B, + "options": { + "indexing": {"allow": ["z"]}, + }, + }, + ) + assert expected_coll_descriptor in lc_response + assert expected_coll_descriptor_b in lc_response # col2 = sync_database.get_collection(TEST_LOCAL_COLLECTION_NAME) assert col1 == col2 @@ -79,7 +89,7 @@ def test_collection_default_id_type_sync( ID_TEST_COLLECTION_NAME_ROOT + DefaultIdType.UUID, default_id_type=DefaultIdType.UUID, ) - assert col.options()["default_id_type"] == DefaultIdType.UUID + assert col.options().default_id.default_id_type == DefaultIdType.UUID col.insert_one({"role": "probe"}) doc = col.find_one({}) assert isinstance(doc["_id"], UUID) @@ -90,7 +100,7 @@ def test_collection_default_id_type_sync( ID_TEST_COLLECTION_NAME_ROOT + DefaultIdType.UUIDV6, default_id_type=DefaultIdType.UUIDV6, ) - assert col.options()["default_id_type"] == DefaultIdType.UUIDV6 + assert col.options().default_id.default_id_type == DefaultIdType.UUIDV6 col.insert_one({"role": "probe"}) doc = col.find_one({}) assert isinstance(doc["_id"], UUID) @@ -102,7 +112,7 @@ def test_collection_default_id_type_sync( ID_TEST_COLLECTION_NAME_ROOT + DefaultIdType.UUIDV7, default_id_type=DefaultIdType.UUIDV7, ) - assert col.options()["default_id_type"] == DefaultIdType.UUIDV7 + assert col.options().default_id.default_id_type == DefaultIdType.UUIDV7 col.insert_one({"role": "probe"}) doc = col.find_one({}) assert isinstance(doc["_id"], UUID) @@ -114,7 +124,7 @@ def test_collection_default_id_type_sync( ID_TEST_COLLECTION_NAME_ROOT + DefaultIdType.DEFAULT, default_id_type=DefaultIdType.DEFAULT, ) - assert col.options()["default_id_type"] == DefaultIdType.DEFAULT + assert col.options().default_id.default_id_type == DefaultIdType.DEFAULT col.drop() time.sleep(2) @@ -122,7 +132,7 @@ def test_collection_default_id_type_sync( ID_TEST_COLLECTION_NAME_ROOT + DefaultIdType.OBJECTID, default_id_type=DefaultIdType.OBJECTID, ) - assert col.options()["default_id_type"] == DefaultIdType.OBJECTID + assert col.options().default_id.default_id_type == DefaultIdType.OBJECTID col.insert_one({"role": "probe"}) doc = col.find_one({}) assert isinstance(doc["_id"], ObjectId) @@ -172,7 +182,8 @@ def test_collection_options_sync( sync_collection: Collection, ) -> None: options = sync_collection.options() - assert options["name"] == sync_collection.name + assert options.vector is not None + assert options.vector.dimension == 2 @pytest.mark.skipif( ASTRA_DB_SECONDARY_KEYSPACE is None, reason="No secondary keyspace provided" diff --git a/tests/idiomatic/unit/test_collection_options.py b/tests/idiomatic/unit/test_collection_options.py index 347d9e02..140ef76a 100644 --- a/tests/idiomatic/unit/test_collection_options.py +++ b/tests/idiomatic/unit/test_collection_options.py @@ -16,115 +16,166 @@ Unit tests for the validation/parsing of collection options """ +from typing import Any, Dict, List, Tuple + import pytest -from astrapy.database import _recast_api_collection_dict +from astrapy.info import CollectionDescriptor @pytest.mark.describe("test of recasting the collection options from the api") def test_recast_api_collection_dict() -> None: - plain_raw = { - "name": "tablename", - "options": {}, - } - plain_expected = { - "name": "tablename", - } - assert _recast_api_collection_dict(plain_raw) == plain_expected - - plainplus_raw = { - "name": "tablename", - "options": { - "futuretopfield": "ftvalue", - }, - } - plainplus_expected = { - "name": "tablename", - "additional_options": { - "futuretopfield": "ftvalue", - }, - } - assert _recast_api_collection_dict(plainplus_raw) == plainplus_expected - - dim_raw = { - "name": "tablename", - "options": { - "vector": { - "dimension": 10, - }, - }, - } - dim_expected = { - "name": "tablename", - "dimension": 10, - } - assert _recast_api_collection_dict(dim_raw) == dim_expected - - dim_met_raw = { - "name": "tablename", - "options": { - "vector": { - "dimension": 10, + api_coll_descs: List[Tuple[Dict[str, Any], Dict[str, Any]]] = [ + # minimal: + ( + { + "name": "col_name", + }, + {"name": "col_name"}, + ), + # full: + ( + { + "name": "col_name", + "options": { + "vector": { + "dimension": 1024, + "metric": "cosine", + }, + "indexing": {"deny": ["a"]}, + "defaultId": {"type": "objectId"}, + }, + }, + { + "name": "col_name", + "dimension": 1024, "metric": "cosine", + "indexing": {"deny": ["a"]}, + "default_id_type": "objectId", }, - }, - } - dim_met_expected = { - "name": "tablename", - "dimension": 10, - "metric": "cosine", - } - assert _recast_api_collection_dict(dim_met_raw) == dim_met_expected - - dim_met_did_idx_raw = { - "name": "tablename", - "options": { - "defaultId": {"type": "objectId"}, - "indexing": { - "allow": ["a"], - }, - "vector": { - "dimension": 10, + ), + # partial/absent 'vector': + ( + { + "name": "col_name", + "options": { + "vector": { + "metric": "cosine", + }, + "indexing": {"deny": ["a"]}, + "defaultId": {"type": "objectId"}, + }, + }, + { + "name": "col_name", "metric": "cosine", + "indexing": {"deny": ["a"]}, + "default_id_type": "objectId", }, - }, - } - dim_met_did_idx_expected = { - "name": "tablename", - "dimension": 10, - "metric": "cosine", - "indexing": {"allow": ["a"]}, - "default_id_type": "objectId", - } - assert _recast_api_collection_dict(dim_met_did_idx_raw) == dim_met_did_idx_expected - - dim_met_didplus_idx_raw = { - "name": "tablename", - "options": { - "defaultId": { - "type": "objectId", - "futurefield": "fvalue", - }, - "indexing": { - "allow": ["a"], - }, - "vector": { - "dimension": 10, + ), + ( + { + "name": "col_name", + "options": { + "vector": { + "dimension": 1024, + }, + "indexing": {"deny": ["a"]}, + "defaultId": {"type": "objectId"}, + }, + }, + { + "name": "col_name", + "dimension": 1024, + "indexing": {"deny": ["a"]}, + "default_id_type": "objectId", + }, + ), + ( + { + "name": "col_name", + "options": { + "vector": {}, + "indexing": {"deny": ["a"]}, + "defaultId": {"type": "objectId"}, + }, + }, + { + "name": "col_name", + "indexing": {"deny": ["a"]}, + "default_id_type": "objectId", + }, + ), + ( + { + "name": "col_name", + "options": { + "indexing": {"deny": ["a"]}, + "defaultId": {"type": "objectId"}, + }, + }, + { + "name": "col_name", + "indexing": {"deny": ["a"]}, + "default_id_type": "objectId", + }, + ), + # no indexing: + ( + { + "name": "col_name", + "options": { + "vector": { + "dimension": 1024, + "metric": "cosine", + }, + "defaultId": {"type": "objectId"}, + }, + }, + { + "name": "col_name", + "dimension": 1024, + "metric": "cosine", + "default_id_type": "objectId", + }, + ), + # no defaultId: + ( + { + "name": "col_name", + "options": { + "vector": { + "dimension": 1024, + "metric": "cosine", + }, + "indexing": {"deny": ["a"]}, + }, + }, + { + "name": "col_name", + "dimension": 1024, + "metric": "cosine", + "indexing": {"deny": ["a"]}, + }, + ), + # no indexing + no defaultId: + ( + { + "name": "col_name", + "options": { + "vector": { + "dimension": 1024, + "metric": "cosine", + }, + }, + }, + { + "name": "col_name", + "dimension": 1024, "metric": "cosine", }, - }, - } - dim_met_didplus_idx_expected = { - "name": "tablename", - "dimension": 10, - "metric": "cosine", - "indexing": {"allow": ["a"]}, - "default_id_type": "objectId", - "additional_options": { - "defaultId": {"futurefield": "fvalue"}, - }, - } - assert ( - _recast_api_collection_dict(dim_met_didplus_idx_raw) - == dim_met_didplus_idx_expected - ) + ), + ] + for api_coll_desc, flattened_dict in api_coll_descs: + assert CollectionDescriptor.from_dict(api_coll_desc).as_dict() == api_coll_desc + assert CollectionDescriptor.from_dict(api_coll_desc).flatten() == flattened_dict diff --git a/tests/idiomatic/unit/test_imports.py b/tests/idiomatic/unit/test_imports.py index 875f87a4..ec2154e7 100644 --- a/tests/idiomatic/unit/test_imports.py +++ b/tests/idiomatic/unit/test_imports.py @@ -82,6 +82,10 @@ def test_imports() -> None: AdminDatabaseInfo, DatabaseInfo, CollectionInfo, + CollectionDefaultIDOptions, + CollectionVectorOptions, + CollectionOptions, + CollectionDescriptor, ) from astrapy.admin import ( # noqa: F401 Environment,