Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Collection options is a dataclass and not a dict anymore #269

Merged
merged 3 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,10 @@ from astrapy.info import (
AdminDatabaseInfo,
DatabaseInfo,
CollectionInfo,
CollectionDefaultIDOptions,
CollectionVectorOptions,
CollectionOptions,
CollectionDescriptor,
)
```

Expand Down
40 changes: 19 additions & 21 deletions astrapy/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
BulkWriteResult,
)
from astrapy.cursors import AsyncCursor, Cursor
from astrapy.info import CollectionInfo
from astrapy.info import CollectionInfo, CollectionOptions


if TYPE_CHECKING:
Expand Down Expand Up @@ -347,7 +347,7 @@ def set_caller(
caller_version=caller_version,
)

def options(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]:
def options(self, *, max_time_ms: Optional[int] = None) -> CollectionOptions:
"""
Get the collection options, i.e. its configuration as read from the database.

Expand All @@ -359,22 +359,21 @@ def options(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]:
max_time_ms: a timeout, in milliseconds, for the underlying HTTP request.

Returns:
a dictionary expressing the collection as a set of key-value pairs
matching the arguments of a `create_collection` call.
a CollectionOptions instance describing the collection.
(See also the database `list_collections` method.)

Example:
>>> my_coll.options()
{'name': 'my_v_collection', 'dimension': 3, 'metric': 'cosine'}
CollectionOptions(vector=CollectionVectorOptions(dimension=3, metric='cosine'))
"""

self_dicts = [
coll_dict
for coll_dict in self.database.list_collections(max_time_ms=max_time_ms)
if coll_dict["name"] == self.name
self_descriptors = [
coll_desc
for coll_desc in self.database.list_collections(max_time_ms=max_time_ms)
if coll_desc.name == self.name
]
if self_dicts:
return self_dicts[0] # type: ignore[no-any-return]
if self_descriptors:
return self_descriptors[0].options # type: ignore[no-any-return]
else:
raise CollectionNotFoundException(
text=f"Collection {self.namespace}.{self.name} not found.",
Expand Down Expand Up @@ -2411,7 +2410,7 @@ def set_caller(
caller_version=caller_version,
)

async def options(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]:
async def options(self, *, max_time_ms: Optional[int] = None) -> CollectionOptions:
"""
Get the collection options, i.e. its configuration as read from the database.

Expand All @@ -2423,24 +2422,23 @@ async def options(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]:
max_time_ms: a timeout, in milliseconds, for the underlying HTTP request.

Returns:
a dictionary expressing the collection as a set of key-value pairs
matching the arguments of a `create_collection` call.
a CollectionOptions instance describing the collection.
(See also the database `list_collections` method.)

Example:
>>> asyncio.run(my_async_coll.options())
{'name': 'my_v_collection', 'dimension': 3, 'metric': 'cosine'}
CollectionOptions(vector=CollectionVectorOptions(dimension=3, metric='cosine'))
"""

self_dicts = [
coll_dict
async for coll_dict in self.database.list_collections(
self_descriptors = [
coll_desc
async for coll_desc in self.database.list_collections(
max_time_ms=max_time_ms
)
if coll_dict["name"] == self.name
if coll_desc.name == self.name
]
if self_dicts:
return self_dicts[0] # type: ignore[no-any-return]
if self_descriptors:
return self_descriptors[0].options # type: ignore[no-any-return]
else:
raise CollectionNotFoundException(
text=f"Collection {self.namespace}.{self.name} not found.",
Expand Down
67 changes: 15 additions & 52 deletions astrapy/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
base_timeout_info,
)
from astrapy.cursors import AsyncCommandCursor, CommandCursor
from astrapy.info import DatabaseInfo
from astrapy.info import DatabaseInfo, CollectionDescriptor
from astrapy.admin import parse_api_endpoint, fetch_database_info

if TYPE_CHECKING:
Expand Down Expand Up @@ -70,41 +70,6 @@ def _validate_create_collection_options(
)


def _recast_api_collection_dict(api_coll_dict: Dict[str, Any]) -> Dict[str, Any]:
_name = api_coll_dict["name"]
_options = api_coll_dict.get("options") or {}
_v_options0 = _options.get("vector") or {}
_indexing = _options.get("indexing") or {}
_v_dimension = _v_options0.get("dimension")
_v_metric = _v_options0.get("metric")
_default_id = _options.get("defaultId")
# defaultId may potentially in the future have other subfields than 'type':
if _default_id:
_default_id_type = _default_id.get("type")
_rest_default_id = {k: v for k, v in _default_id.items() if k != "type"}
else:
_default_id_type = None
_rest_default_id = None
_additional_options = {
**{
k: v
for k, v in _options.items()
if k not in {"vector", "indexing", "defaultId"}
},
**({"defaultId": _rest_default_id} if _rest_default_id else {}),
}
recast_dict0 = {
"name": _name,
"dimension": _v_dimension,
"metric": _v_metric,
"indexing": _indexing,
"default_id_type": _default_id_type,
"additional_options": _additional_options,
}
recast_dict = {k: v for k, v in recast_dict0.items() if v}
return recast_dict


class Database:
"""
A Data API database. This is the entry-point object for doing database-level
Expand Down Expand Up @@ -592,7 +557,7 @@ def list_collections(
*,
namespace: Optional[str] = None,
max_time_ms: Optional[int] = None,
) -> CommandCursor[Dict[str, Any]]:
) -> CommandCursor[CollectionDescriptor]:
"""
List all collections in a given namespace for this database.

Expand All @@ -602,20 +567,19 @@ def list_collections(
max_time_ms: a timeout, in milliseconds, for the underlying HTTP request.

Returns:
a `CommandCursor` to iterate over dictionaries, each
expressing a collection as a set of key-value pairs
matching the arguments of a `create_collection` call.
a `CommandCursor` to iterate over CollectionDescriptor instances,
each corresponding to a collection.

Example:
>>> ccur = my_db.list_collections()
>>> ccur
<astrapy.cursors.CommandCursor object at ...>
>>> list(ccur)
[{'name': 'my_v_col'}]
[CollectionDescriptor(name='my_v_col', options=CollectionOptions())]
>>> for coll_desc in my_db.list_collections():
...     print(coll_desc)
...
{'name': 'my_v_col'}
CollectionDescriptor(name='my_v_col', options=CollectionOptions())
"""

if namespace:
Expand All @@ -631,11 +595,11 @@ def list_collections(
raw_response=gc_response,
)
else:
# we know this is a list of dicts which need a little adjusting
# we know this is a list of dicts, to marshal into "descriptors"
return CommandCursor(
address=self._astra_db.base_url,
items=[
_recast_api_collection_dict(col_dict)
CollectionDescriptor.from_dict(col_dict)
for col_dict in gc_response["status"]["collections"]
],
)
Expand Down Expand Up @@ -1286,7 +1250,7 @@ def list_collections(
*,
namespace: Optional[str] = None,
max_time_ms: Optional[int] = None,
) -> AsyncCommandCursor[Dict[str, Any]]:
) -> AsyncCommandCursor[CollectionDescriptor]:
"""
List all collections in a given namespace for this database.

Expand All @@ -1296,9 +1260,8 @@ def list_collections(
max_time_ms: a timeout, in milliseconds, for the underlying HTTP request.

Returns:
an `AsyncCommandCursor` to iterate over dictionaries, each
expressing a collection as a set of key-value pairs
matching the arguments of a `create_collection` call.
an `AsyncCommandCursor` to iterate over CollectionDescriptor instances,
each corresponding to a collection.

Example:
>>> async def a_list_colls(adb: AsyncDatabase) -> None:
Expand All @@ -1310,8 +1273,8 @@ def list_collections(
...
>>> asyncio.run(a_list_colls(my_async_db))
* a_ccur: <astrapy.cursors.AsyncCommandCursor object at ...>
* list: [{'name': 'my_v_col'}]
* coll: {'name': 'my_v_col'}
* list: [CollectionDescriptor(name='my_v_col', options=CollectionOptions())]
* coll: CollectionDescriptor(name='my_v_col', options=CollectionOptions())
"""

_client: AsyncAstraDB
Expand All @@ -1329,11 +1292,11 @@ def list_collections(
raw_response=gc_response,
)
else:
# we know this is a list of dicts which need a little adjusting
# we know this is a list of dicts, to marshal into "descriptors"
return AsyncCommandCursor(
address=self._astra_db.base_url,
items=[
_recast_api_collection_dict(col_dict)
CollectionDescriptor.from_dict(col_dict)
for col_dict in gc_response["status"]["collections"]
],
)
Expand Down
Loading
Loading