From de7116423adf0602b62642782f27152cd40398eb Mon Sep 17 00:00:00 2001 From: Stefano Lottini Date: Mon, 1 Apr 2024 11:51:08 +0200 Subject: [PATCH] Added async support for admin, as alternate methods on original classes (#272) AstraDBADmin: methods async_database_info async_create_database async_list_databases async_drop_database AstraDBDatabaseAdmin: methods async_list_namespaces async_create_namespace async_drop_namespace async_info async_drop Utility functions: async_fetch_raw_database_info_from_id_token async_fetch_database_info Added to core/ops.py as needed: async_get_databases async_get_database async_create_database async_terminate_database async_create_keyspace async_delete_keyspace All idiomatic addtions come with tests, docstrings --- astrapy/admin.py | 659 +++++++++++++++++++++- astrapy/core/ops.py | 229 +++++++- astrapy/exceptions.py | 23 + tests/idiomatic/integration/test_admin.py | 282 ++++++++- 4 files changed, 1180 insertions(+), 13 deletions(-) diff --git a/astrapy/admin.py b/astrapy/admin.py index 69f66210..b9ce49d8 100644 --- a/astrapy/admin.py +++ b/astrapy/admin.py @@ -14,6 +14,7 @@ from __future__ import annotations +import asyncio import logging import re import time @@ -32,6 +33,7 @@ base_timeout_info, to_dataapi_timeout_exception, ops_recast_method_sync, + ops_recast_method_async, ) @@ -189,6 +191,41 @@ def fetch_raw_database_info_from_id_token( raise to_dataapi_timeout_exception(texc) +async def async_fetch_raw_database_info_from_id_token( + id: str, + *, + token: str, + environment: str = Environment.PROD, + max_time_ms: Optional[int] = None, +) -> Dict[str, Any]: + """ + Fetch database information through the DevOps API and return it in + full, exactly like the API gives it back. + Async version of the function, for use in an asyncio context. + + Args: + id: e. g. "01234567-89ab-cdef-0123-456789abcdef". + token: a valid token to access the database information. + max_time_ms: a timeout, in milliseconds, for waiting on a response. + + Returns: + The full response from the DevOps API about the database. + """ + + astra_db_ops = AstraDBOps( + token=token, + dev_ops_url=DEV_OPS_URL_MAP[environment], + ) + try: + gd_response = await astra_db_ops.async_get_database( + database=id, + timeout_info=base_timeout_info(max_time_ms), + ) + return gd_response + except httpx.TimeoutException as texc: + raise to_dataapi_timeout_exception(texc) + + def fetch_database_info( api_endpoint: str, token: str, namespace: str, max_time_ms: Optional[int] = None ) -> Optional[DatabaseInfo]: @@ -231,6 +268,49 @@ def fetch_database_info( return None +async def async_fetch_database_info( + api_endpoint: str, token: str, namespace: str, max_time_ms: Optional[int] = None +) -> Optional[DatabaseInfo]: + """ + Fetch database information through the DevOps API. + Async version of the function, for use in an asyncio context. + + Args: + api_endpoint: a full API endpoint for the Data Api. + token: a valid token to access the database information. + namespace: the desired namespace that will be used in the result. + max_time_ms: a timeout, in milliseconds, for waiting on a response. + + Returns: + A DatabaseInfo object. + If the API endpoint fails to be parsed, None is returned. + For valid-looking endpoints, if something goes wrong an exception is raised. + """ + + parsed_endpoint = parse_api_endpoint(api_endpoint) + if parsed_endpoint: + gd_response = await async_fetch_raw_database_info_from_id_token( + id=parsed_endpoint.database_id, + token=token, + environment=parsed_endpoint.environment, + max_time_ms=max_time_ms, + ) + raw_info = gd_response["info"] + if namespace not in raw_info["keyspaces"]: + raise DevOpsAPIException(f"Namespace {namespace} not found on DB.") + else: + return DatabaseInfo( + id=parsed_endpoint.database_id, + region=parsed_endpoint.region, + namespace=namespace, + name=raw_info["name"], + environment=parsed_endpoint.environment, + raw_info=raw_info, + ) + else: + return None + + def _recast_as_admin_database_info( admin_database_info_dict: Dict[str, Any], *, @@ -483,6 +563,59 @@ def list_databases( ], ) + @ops_recast_method_async + async def async_list_databases( + self, + *, + max_time_ms: Optional[int] = None, + ) -> CommandCursor[AdminDatabaseInfo]: + """ + Get the list of databases, as obtained with a request to the DevOps API. + Async version of the method, for use in an asyncio context. + + Args: + max_time_ms: a timeout, in milliseconds, for the API request. + + Returns: + A CommandCursor to iterate over the detected databases, + represented as AdminDatabaseInfo objects. + Note that the return type is not an awaitable, rather + a regular iterable, e.g. for use in ordinary "for" loops. + + Example: + >>> async def check_if_db_exists(db_id: str) -> bool: + ... db_cursor = await my_astra_db_admin.async_list_databases() + ... db_list = list(dd_cursor) + ... return db_id in db_list + ... + >>> asyncio.run(check_if_db_exists("xyz")) + True + >>> asyncio.run(check_if_db_exists("01234567-...")) + False + """ + + logger.info("getting databases, async") + gd_list_response = await self._astra_db_ops.async_get_databases( + timeout_info=base_timeout_info(max_time_ms) + ) + logger.info("finished getting databases, async") + if not isinstance(gd_list_response, list): + raise DevOpsAPIException( + "Faulty response from get-databases DevOps API command.", + ) + else: + # we know this is a list of dicts which need a little adjusting + return CommandCursor( + address=self._astra_db_ops.base_url, + items=[ + _recast_as_admin_database_info( + db_dict, + environment=self.environment, + ) + for db_dict in gd_list_response + ], + ) + @ops_recast_method_sync def database_info( self, id: str, *, max_time_ms: Optional[int] = None @@ -508,12 +641,53 @@ def database_info( 'eu-west-1' """ - logger.info(f"getting database info for {id}") + logger.info(f"getting database info for '{id}'") gd_response = self._astra_db_ops.get_database( database=id, timeout_info=base_timeout_info(max_time_ms), ) - logger.info(f"finished getting database info for {id}") + logger.info(f"finished getting database info for '{id}'") + if not isinstance(gd_response, dict): + raise DevOpsAPIException( + "Faulty response from get-database DevOps API command.", + ) + else: + return _recast_as_admin_database_info( + gd_response, + environment=self.environment, + ) + + @ops_recast_method_async + async def async_database_info( + self, id: str, *, max_time_ms: Optional[int] = None + ) -> AdminDatabaseInfo: + """ + Get the full information on a given database, through a request to the DevOps API. + This is an awaitable method suitable for use within an asyncio event loop. + + Args: + id: the ID of the target database, e. g. + "01234567-89ab-cdef-0123-456789abcdef". + max_time_ms: a timeout, in milliseconds, for the API request. + + Returns: + An AdminDatabaseInfo object. + + Example: + >>> async def check_if_db_active(db_id: str) -> bool: + ... db_info = await my_astra_db_admin.async_database_info(db_id) + ... return db_info.status == "ACTIVE" + ... + >>> asyncio.run(check_if_db_active("01234567-...")) + True + """ + + logger.info(f"getting database info for '{id}', async") + gd_response = await self._astra_db_ops.async_get_database( + database=id, + timeout_info=base_timeout_info(max_time_ms), + ) + logger.info(f"finished getting database info for '{id}', async") if not isinstance(gd_response, dict): raise DevOpsAPIException( "Faulty response from get-database DevOps API command.", @@ -600,10 +774,11 @@ def create_database( while last_status_seen in {STATUS_PENDING, STATUS_INITIALIZING}: logger.info(f"sleeping to poll for status of '{new_database_id}'") time.sleep(DATABASE_POLL_SLEEP_TIME) - last_status_seen = self.database_info( + last_db_info = self.database_info( id=new_database_id, max_time_ms=timeout_manager.remaining_timeout_ms(), - ).status + ) + last_status_seen = last_db_info.status if last_status_seen != STATUS_ACTIVE: raise DevOpsAPIException( f"Database {name} entered unexpected status {last_status_seen} after PENDING" @@ -620,6 +795,106 @@ def create_database( else: raise DevOpsAPIException("Could not create the database.") + @ops_recast_method_async + async def async_create_database( + self, + name: str, + *, + cloud_provider: str, + region: str, + namespace: Optional[str] = None, + wait_until_active: bool = True, + max_time_ms: Optional[int] = None, + ) -> AstraDBDatabaseAdmin: + """ + Create a database as requested, optionally waiting for it to be ready. + This is an awaitable method suitable for use within an asyncio event loop. + + Args: + name: the desired name for the database. + namespace: name for the one namespace the database starts with. + If omitted, DevOps API will use its default. + cloud_provider: one of 'aws', 'gcp' or 'azure'. + region: any of the available cloud regions. + wait_until_active: if True (default), the method returns only after + the newly-created database is in ACTIVE state (a few minutes, + usually). If False, it will return right after issuing the + creation request to the DevOps API, and it will be responsibility + of the caller to check the database status before working with it. + max_time_ms: a timeout, in milliseconds, for the whole requested + operation to complete. + Note that a timeout is no guarantee that the creation request + has not reached the API server. + + Returns: + An AstraDBDatabaseAdmin instance. + + Example: + >>> asyncio.run( + ... my_astra_db_admin.async_create_database( + ... "new_database", + ... cloud_provider="aws", + ... region="ap-south-1", + .... ) + ... ) + AstraDBDatabaseAdmin(id=...) + """ + + database_definition = { + k: v + for k, v in { + "name": name, + "tier": "serverless", + "cloudProvider": cloud_provider, + "region": region, + "capacityUnits": 1, + "dbType": "vector", + "keyspace": namespace, + }.items() + if v is not None + } + timeout_manager = MultiCallTimeoutManager( + overall_max_time_ms=max_time_ms, exception_type="devops_api" + ) + logger.info(f"creating database {name}/({cloud_provider}, {region}), async") + cd_response = await self._astra_db_ops.async_create_database( + database_definition=database_definition, + timeout_info=base_timeout_info(max_time_ms), + ) + logger.info( + "devops api returned from creating database " + f"{name}/({cloud_provider}, {region}), async" + ) + if cd_response is not None and "id" in cd_response: + new_database_id = cd_response["id"] + if wait_until_active: + last_status_seen = STATUS_PENDING + while last_status_seen in {STATUS_PENDING, STATUS_INITIALIZING}: + logger.info( + f"sleeping to poll for status of '{new_database_id}', async" + ) + await asyncio.sleep(DATABASE_POLL_SLEEP_TIME) + last_db_info = await self.async_database_info( + id=new_database_id, + max_time_ms=timeout_manager.remaining_timeout_ms(), + ) + last_status_seen = last_db_info.status + if last_status_seen != STATUS_ACTIVE: + raise DevOpsAPIException( + f"Database {name} entered unexpected status {last_status_seen} after PENDING" + ) + # return the database instance + logger.info( + f"finished creating database '{new_database_id}' = " + f"{name}/({cloud_provider}, {region}), async" + ) + return AstraDBDatabaseAdmin.from_astra_db_admin( + id=new_database_id, + astra_db_admin=self, + ) + else: + raise DevOpsAPIException("Could not create the database.") + @ops_recast_method_sync def drop_database( self, @@ -701,6 +976,84 @@ def drop_database( f"Could not issue a successful terminate-database DevOps API request for {id}." ) + @ops_recast_method_async + async def async_drop_database( + self, + id: str, + *, + wait_until_active: bool = True, + max_time_ms: Optional[int] = None, + ) -> Dict[str, Any]: + """ + Drop a database, i.e. delete it completely and permanently with all its data. + Async version of the method, for use in an asyncio context. + + Args: + id: The ID of the database to drop, e. g. + "01234567-89ab-cdef-0123-456789abcdef". + wait_until_active: if True (default), the method returns only after + the database has actually been deleted (generally a few minutes). + If False, it will return right after issuing the + drop request to the DevOps API, and it will be responsibility + of the caller to check the database status/availability + after that, if desired. + max_time_ms: a timeout, in milliseconds, for the whole requested + operation to complete. + Note that a timeout is no guarantee that the deletion request + has not reached the API server. + + Returns: + A dictionary of the form {"ok": 1} in case of success. + Otherwise, an exception is raised. + + Example: + >>> asyncio.run( + ... my_astra_db_admin.async_drop_database("01234567-...") + ... ) + {'ok': 1} + """ + + timeout_manager = MultiCallTimeoutManager( + overall_max_time_ms=max_time_ms, exception_type="devops_api" + ) + logger.info(f"dropping database '{id}', async") + te_response = await self._astra_db_ops.async_terminate_database( + database=id, + timeout_info=base_timeout_info(max_time_ms), + ) + logger.info(f"devops api returned from dropping database '{id}', async") + if te_response == id: + if wait_until_active: + last_status_seen: Optional[str] = STATUS_TERMINATING + _db_name: Optional[str] = None + while last_status_seen == STATUS_TERMINATING: + logger.info(f"sleeping to poll for status of '{id}', async") + await asyncio.sleep(DATABASE_POLL_SLEEP_TIME) + # + detected_databases = [ + a_db_info + for a_db_info in await self.async_list_databases( + max_time_ms=timeout_manager.remaining_timeout_ms(), + ) + if a_db_info.id == id + ] + if detected_databases: + last_status_seen = detected_databases[0].status + _db_name = detected_databases[0].info.name + else: + last_status_seen = None + if last_status_seen is not None: + _name_desc = f" ({_db_name})" if _db_name else "" + raise DevOpsAPIException( + f"Database {id}{_name_desc} entered unexpected status {last_status_seen} after PENDING" + ) + logger.info(f"finished dropping database '{id}', async") + return {"ok": 1} + else: + raise DevOpsAPIException( + f"Could not issue a successful terminate-database DevOps API request for {id}." + ) + def get_database_admin(self, id: str) -> AstraDBDatabaseAdmin: """ Create an AstraDBDatabaseAdmin object for admin work within a certain database. @@ -870,8 +1223,35 @@ def create_namespace(self, name: str, *pargs: Any, **kwargs: Any) -> Dict[str, A @abstractmethod def drop_namespace(self, name: str, *pargs: Any, **kwargs: Any) -> Dict[str, Any]: """ - Drop (delete) a namespace from the database, - returning {'ok': 1} if successful. + Drop (delete) a namespace from the database, returning {'ok': 1} if successful. + """ + ... + + @abstractmethod + async def async_list_namespaces(self, *pargs: Any, **kwargs: Any) -> List[str]: + """ + Get a list of namespaces for the database. + (Async version of the method.) + """ + ... + + @abstractmethod + async def async_create_namespace( + self, name: str, *pargs: Any, **kwargs: Any + ) -> Dict[str, Any]: + """ + Create a namespace in the database, returning {'ok': 1} if successful. + (Async version of the method.) + """ + ... + + @abstractmethod + async def async_drop_namespace( + self, name: str, *pargs: Any, **kwargs: Any + ) -> Dict[str, Any]: + """ + Drop (delete) a namespace from the database, returning {'ok': 1} if successful. + (Async version of the method.) """ ... @@ -1182,6 +1562,37 @@ def info(self, *, max_time_ms: Optional[int] = None) -> AdminDatabaseInfo: logger.info(f"finished getting info ('{self.id}')") return req_response # type: ignore[no-any-return] + async def async_info( + self, *, max_time_ms: Optional[int] = None + ) -> AdminDatabaseInfo: + """ + Query the DevOps API for the full info on this database. + Async version of the method, for use in an asyncio context. + + Args: + max_time_ms: a timeout, in milliseconds, for the DevOps API request. + + Returns: + An AdminDatabaseInfo object. + + Example: + >>> async def wait_until_active(db_admin: AstraDBDatabaseAdmin) -> None: + ... while True: + ... info = await db_admin.async_info() + ... if info.status == "ACTIVE": + ... return + ... + >>> asyncio.run(wait_until_active(admin_for_my_db)) + """ + + logger.info(f"getting info ('{self.id}'), async") + req_response = await self._astra_db_admin.async_database_info( + id=self.id, + max_time_ms=max_time_ms, + ) + logger.info(f"finished getting info ('{self.id}'), async") + return req_response # type: ignore[no-any-return] + def list_namespaces(self, *, max_time_ms: Optional[int] = None) -> List[str]: """ Query the DevOps API for a list of the namespaces in the database. @@ -1205,6 +1616,40 @@ def list_namespaces(self, *, max_time_ms: Optional[int] = None) -> List[str]: else: return info.raw_info["info"]["keyspaces"] # type: ignore[no-any-return] + async def async_list_namespaces( + self, *, max_time_ms: Optional[int] = None + ) -> List[str]: + """ + Query the DevOps API for a list of the namespaces in the database. + Async version of the method, for use in an asyncio context. + + Args: + max_time_ms: a timeout, in milliseconds, for the DevOps API request. + + Returns: + A list of the namespaces, each a string, in no particular order. + + Example: + >>> async def check_if_ns_exists( + ... db_admin: AstraDBDatabaseAdmin, namespace: str + ... ) -> bool: + ... ns_list = await db_admin.async_list_namespaces() + ... return namespace in ns_list + ... + >>> asyncio.run(check_if_ns_exists(admin_for_my_db, "dragons")) + False + >>> asyncio.run(check_if_db_exists(admin_for_my_db, "app_namespace")) + True + """ + + logger.info(f"getting namespaces ('{self.id}'), async") + info = await self.async_info(max_time_ms=max_time_ms) + logger.info(f"finished getting namespaces ('{self.id}'), async") + if info.raw_info is None: + raise DevOpsAPIException("Could not get the namespace list.") + else: + return info.raw_info["info"]["keyspaces"] # type: ignore[no-any-return] + @ops_recast_method_sync def create_namespace( self, @@ -1280,6 +1725,83 @@ def create_namespace( f"Could not issue a successful create-namespace DevOps API request for {name}." ) + # the 'override' is because the error-recast decorator washes out the signature + @ops_recast_method_async + async def async_create_namespace( # type: ignore[override] + self, + name: str, + *, + wait_until_active: bool = True, + max_time_ms: Optional[int] = None, + ) -> Dict[str, Any]: + """ + Create a namespace in this database as requested, + optionally waiting for it to be ready. + Async version of the method, for use in an asyncio context. + + Args: + name: the namespace name. If supplying a namespace that exists + already, the method call proceeds as usual, no errors are + raised, and the whole invocation is a no-op. + wait_until_active: if True (default), the method returns only after + the target database is in ACTIVE state again (a few + seconds, usually). If False, it will return right after issuing the + creation request to the DevOps API, and it will be responsibility + of the caller to check the database status/namespace availability + before working with it. + max_time_ms: a timeout, in milliseconds, for the whole requested + operation to complete. + Note that a timeout is no guarantee that the creation request + has not reached the API server. + + Returns: + A dictionary of the form {"ok": 1} in case of success. + Otherwise, an exception is raised. + + Example: + >>> asyncio.run( + ... my_db_admin.async_create_namespace("app_namespace") + ... ) + {'ok': 1} + """ + + timeout_manager = MultiCallTimeoutManager( + overall_max_time_ms=max_time_ms, exception_type="devops_api" + ) + logger.info(f"creating namespace '{name}' on '{self.id}', async") + cn_response = await self._astra_db_admin._astra_db_ops.async_create_keyspace( + database=self.id, + keyspace=name, + timeout_info=base_timeout_info(max_time_ms), + ) + logger.info( + f"devops api returned from creating namespace " + f"'{name}' on '{self.id}', async" + ) + if cn_response is not None and name == cn_response.get("name"): + if wait_until_active: + last_status_seen = STATUS_MAINTENANCE + while last_status_seen == STATUS_MAINTENANCE: + logger.info(f"sleeping to poll for status of '{self.id}', async") + await asyncio.sleep(DATABASE_POLL_NAMESPACE_SLEEP_TIME) + last_db_info = await self.async_info( + max_time_ms=timeout_manager.remaining_timeout_ms(), + ) + last_status_seen = last_db_info.status + if last_status_seen != STATUS_ACTIVE: + raise DevOpsAPIException( + f"Database entered unexpected status {last_status_seen} after MAINTENANCE." + ) + # is the namespace found? + if name not in await self.async_list_namespaces(): + raise DevOpsAPIException("Could not create the namespace.") + logger.info(f"finished creating namespace '{name}' on '{self.id}', async") + return {"ok": 1} + else: + raise DevOpsAPIException( + f"Could not issue a successful create-namespace DevOps API request for {name}." + ) + @ops_recast_method_sync def drop_namespace( self, @@ -1354,6 +1876,82 @@ def drop_namespace( f"Could not issue a successful delete-namespace DevOps API request for {name}." ) + # the 'override' is because the error-recast decorator washes out the signature + @ops_recast_method_async + async def async_drop_namespace( # type: ignore[override] + self, + name: str, + *, + wait_until_active: bool = True, + max_time_ms: Optional[int] = None, + ) -> Dict[str, Any]: + """ + Delete a namespace from the database, optionally waiting for it + to become active again. + Async version of the method, for use in an asyncio context. + + Args: + name: the namespace to delete. If it does not exist in this database, + an error is raised. + wait_until_active: if True (default), the method returns only after + the target database is in ACTIVE state again (a few + seconds, usually). If False, it will return right after issuing the + deletion request to the DevOps API, and it will be responsibility + of the caller to check the database status/namespace availability + before working with it. + max_time_ms: a timeout, in milliseconds, for the whole requested + operation to complete. + Note that a timeout is no guarantee that the deletion request + has not reached the API server. + + Returns: + A dictionary of the form {"ok": 1} in case of success. + Otherwise, an exception is raised. + + Example: + >>> asyncio.run( + ... my_db_admin.async_drop_namespace("app_namespace") + ... ) + {'ok': 1} + """ + + timeout_manager = MultiCallTimeoutManager( + overall_max_time_ms=max_time_ms, exception_type="devops_api" + ) + logger.info(f"dropping namespace '{name}' on '{self.id}', async") + dk_response = await self._astra_db_admin._astra_db_ops.async_delete_keyspace( + database=self.id, + keyspace=name, + timeout_info=base_timeout_info(max_time_ms), + ) + logger.info( + f"devops api returned from dropping namespace " + f"'{name}' on '{self.id}', async" + ) + if dk_response == name: + if wait_until_active: + last_status_seen = STATUS_MAINTENANCE + while last_status_seen == STATUS_MAINTENANCE: + logger.info(f"sleeping to poll for status of '{self.id}', async") + await asyncio.sleep(DATABASE_POLL_NAMESPACE_SLEEP_TIME) + last_db_info = await self.async_info( + max_time_ms=timeout_manager.remaining_timeout_ms(), + ) + last_status_seen = last_db_info.status + if last_status_seen != STATUS_ACTIVE: + raise DevOpsAPIException( + f"Database entered unexpected status {last_status_seen} after MAINTENANCE." + ) + # is the namespace found? + if name in await self.async_list_namespaces(): + raise DevOpsAPIException("Could not drop the namespace.") + logger.info(f"finished dropping namespace '{name}' on '{self.id}', async") + return {"ok": 1} + else: + raise DevOpsAPIException( + f"Could not issue a successful delete-namespace DevOps API request for {name}." + ) + def drop( self, *, @@ -1405,6 +2003,55 @@ def drop( ) logger.info(f"finished dropping this database ('{self.id}')") + async def async_drop( + self, + *, + wait_until_active: bool = True, + max_time_ms: Optional[int] = None, + ) -> Dict[str, Any]: + """ + Drop this database, i.e. delete it completely and permanently with all its data. + Async version of the method, for use in an asyncio context. + + This method wraps the `drop_database` method of the AstraDBAdmin class, + where more information may be found. + + Args: + wait_until_active: if True (default), the method returns only after + the database has actually been deleted (generally a few minutes). + If False, it will return right after issuing the + drop request to the DevOps API, and it will be responsibility + of the caller to check the database status/availability + after that, if desired. + max_time_ms: a timeout, in milliseconds, for the whole requested + operation to complete. + Note that a timeout is no guarantee that the deletion request + has not reached the API server. + + Returns: + A dictionary of the form {"ok": 1} in case of success. + Otherwise, an exception is raised. + + Example: + >>> asyncio.run(my_db_admin.async_drop()) + {'ok': 1} + + Note: + Once the method succeeds, methods on this object -- such as `info()`, + or `list_namespaces()` -- can still be invoked: however, this hardly + makes sense as the underlying actual database is no more. + It is responsibility of the developer to design a correct flow + which avoids using a deceased database any further. + """ + + logger.info(f"dropping this database ('{self.id}'), async") + return await self._astra_db_admin.async_drop_database( # type: ignore[no-any-return] + id=self.id, + wait_until_active=wait_until_active, + max_time_ms=max_time_ms, + ) + logger.info(f"finished dropping this database ('{self.id}'), async") + def get_database( self, *, diff --git a/astrapy/core/ops.py b/astrapy/core/ops.py index 3e2dcd90..8a63c860 100644 --- a/astrapy/core/ops.py +++ b/astrapy/core/ops.py @@ -18,7 +18,13 @@ from typing import Any, cast, Dict, Optional, TypedDict import httpx -from astrapy.core.api import APIRequestError, api_request, raw_api_request +from astrapy.core.api import ( + APIRequestError, + api_request, + async_api_request, + raw_api_request, + async_raw_api_request, +) from astrapy.core.utils import ( http_methods, @@ -45,8 +51,9 @@ class AstraDBOpsConstructorParams(TypedDict): class AstraDBOps: - # Initialize the shared httpx client as a class attribute + # Initialize the shared httpx clients as class attributes client = httpx.Client() + async_client = httpx.AsyncClient() def __init__( self, @@ -140,6 +147,31 @@ def _ops_request( ) return raw_response + async def _async_ops_request( + self, + method: str, + path: str, + options: Optional[Dict[str, Any]] = None, + json_data: Optional[Dict[str, Any]] = None, + timeout_info: TimeoutInfoWideType = None, + ) -> httpx.Response: + _options = {} if options is None else options + + raw_response = await async_raw_api_request( + client=self.async_client, + base_url=self.base_url, + auth_header=DEFAULT_DEV_OPS_AUTH_HEADER, + token=self.token, + method=method, + json_data=json_data, + url_params=_options, + path=path, + caller_name=self.caller_name, + caller_version=self.caller_version, + timeout=to_httpx_timeout(timeout_info), + ) + return raw_response + def _json_ops_request( self, method: str, @@ -166,6 +198,32 @@ def _json_ops_request( ) return response + async def _async_json_ops_request( + self, + method: str, + path: str, + options: Optional[Dict[str, Any]] = None, + json_data: Optional[Dict[str, Any]] = None, + timeout_info: TimeoutInfoWideType = None, + ) -> OPS_API_RESPONSE: + _options = {} if options is None else options + + response = await async_api_request( + client=self.async_client, + base_url=self.base_url, + auth_header="Authorization", + token=self.token, + method=method, + json_data=json_data, + url_params=_options, + path=path, + skip_error_check=False, + caller_name=None, + caller_version=None, + timeout=to_httpx_timeout(timeout_info), + ) + return response + def get_databases( self, options: Optional[Dict[str, Any]] = None, @@ -189,6 +247,29 @@ def get_databases( return response + async def async_get_databases( + self, + options: Optional[Dict[str, Any]] = None, + timeout_info: TimeoutInfoWideType = None, + ) -> OPS_API_RESPONSE: + """ + Retrieve a list of databases - async version of the method. + + Args: + options (dict, optional): Additional options for the request. + + Returns: + list: a JSON list of dictionaries, one per database. + """ + response = await self._async_json_ops_request( + method=http_methods.GET, + path="/databases", + options=options, + timeout_info=timeout_info, + ) + + return response + def create_database( self, database_definition: Optional[Dict[str, Any]] = None, @@ -219,6 +300,36 @@ def create_database( else: raise ValueError(f"[HTTP {r.status_code}] {r.text}") + async def async_create_database( + self, + database_definition: Optional[Dict[str, Any]] = None, + timeout_info: TimeoutInfoWideType = None, + ) -> Dict[str, str]: + """ + Create a new database - async version of the method. + + Args: + database_definition (dict, optional): A dictionary defining the properties of the database to be created. + timeout_info: either a float (seconds) or a TimeoutInfo dict (see) + + Returns: + dict: A dictionary such as: {"id": the ID of the created database} + Raises an error if not successful. + """ + r = await self._async_ops_request( + method=http_methods.POST, + path="/databases", + json_data=database_definition, + timeout_info=timeout_info, + ) + + if r.status_code == 201: + return {"id": r.headers["Location"]} + elif r.status_code >= 400 and r.status_code < 500: + raise APIRequestError(r, payload=database_definition) + else: + raise ValueError(f"[HTTP {r.status_code}] {r.text}") + def terminate_database( self, database: str = "", timeout_info: TimeoutInfoWideType = None ) -> str: @@ -247,6 +358,34 @@ def terminate_database( return None + async def async_terminate_database( + self, database: str = "", timeout_info: TimeoutInfoWideType = None + ) -> str: + """ + Terminate an existing database - async version of the method. + + Args: + database (str): The identifier of the database to terminate. + timeout_info: either a float (seconds) or a TimeoutInfo dict (see) + + Returns: + str: The identifier of the terminated database, or None if termination was unsuccessful. + """ + r = await self._async_ops_request( + method=http_methods.POST, + path=f"/databases/{database}/terminate", + timeout_info=timeout_info, + ) + + if r.status_code == 202: + return database + elif r.status_code >= 400 and r.status_code < 500: + raise APIRequestError(r, payload=None) + else: + raise ValueError(f"[HTTP {r.status_code}] {r.text}") + + return None + def get_database( self, database: str = "", @@ -273,6 +412,32 @@ def get_database( ), ) + async def async_get_database( + self, + database: str = "", + options: Optional[Dict[str, Any]] = None, + timeout_info: TimeoutInfoWideType = None, + ) -> API_RESPONSE: + """ + Retrieve details of a specific database - async version of the method. + + Args: + database (str): The identifier of the database to retrieve. + options (dict, optional): Additional options for the request. + + Returns: + dict: A JSON response containing the details of the specified database. + """ + return cast( + API_RESPONSE, + await self._async_json_ops_request( + method=http_methods.GET, + path=f"/databases/{database}", + options=options, + timeout_info=timeout_info, + ), + ) + def create_keyspace( self, database: str = "", @@ -303,6 +468,36 @@ def create_keyspace( else: raise ValueError(f"[HTTP {r.status_code}] {r.text}") + async def async_create_keyspace( + self, + database: str = "", + keyspace: str = "", + timeout_info: TimeoutInfoWideType = None, + ) -> Dict[str, str]: + """ + Create a keyspace in a specified database - async version of the method. + + Args: + database (str): The identifier of the database where the keyspace will be created. + keyspace (str): The name of the keyspace to create. + timeout_info: either a float (seconds) or a TimeoutInfo dict (see) + + Returns: + {"ok": 1} if successful. Raises errors otherwise. + """ + r = await self._async_ops_request( + method=http_methods.POST, + path=f"/databases/{database}/keyspaces/{keyspace}", + timeout_info=timeout_info, + ) + + if r.status_code == 201: + return {"name": keyspace} + elif r.status_code >= 400 and r.status_code < 500: + raise APIRequestError(r, payload=None) + else: + raise ValueError(f"[HTTP {r.status_code}] {r.text}") + def delete_keyspace( self, database: str = "", @@ -333,6 +528,36 @@ def delete_keyspace( else: raise ValueError(f"[HTTP {r.status_code}] {r.text}") + async def async_delete_keyspace( + self, + database: str = "", + keyspace: str = "", + timeout_info: TimeoutInfoWideType = None, + ) -> str: + """ + Delete a keyspace from a database - async version of the method. + + Args: + database (str): The identifier of the database to terminate. + keyspace (str): The name of the keyspace to create. + timeout_info: either a float (seconds) or a TimeoutInfo dict (see) + + Returns: + str: The identifier of the deleted keyspace. Otherwise raises an error. + """ + r = await self._async_ops_request( + method=http_methods.DELETE, + path=f"/databases/{database}/keyspaces/{keyspace}", + timeout_info=timeout_info, + ) + + if r.status_code == 202: + return keyspace + elif r.status_code >= 400 and r.status_code < 500: + raise APIRequestError(r, payload=None) + else: + raise ValueError(f"[HTTP {r.status_code}] {r.text}") + def park_database( self, database: str = "", timeout_info: TimeoutInfoWideType = None ) -> OPS_API_RESPONSE: diff --git a/astrapy/exceptions.py b/astrapy/exceptions.py index 661dbd50..899d6f2e 100644 --- a/astrapy/exceptions.py +++ b/astrapy/exceptions.py @@ -715,6 +715,29 @@ def _wrapped_sync(*pargs: Any, **kwargs: Any) -> Any: return _wrapped_sync +def ops_recast_method_async( + method: Callable[..., Awaitable[Any]] +) -> Callable[..., Awaitable[Any]]: + """ + Decorator for an async DevOps method liable to generate the core APIRequestError. + That exception is intercepted and recast as DevOpsAPIException. + Moreover, timeouts are also caught and converted into Data API timeouts. + """ + + @wraps(method) + async def _wrapped_async(*pargs: Any, **kwargs: Any) -> Any: + try: + return await method(*pargs, **kwargs) + except APIRequestError as exc: + raise DevOpsAPIResponseException.from_response( + command=exc.payload, raw_response=exc.response.json() + ) + except httpx.TimeoutException as texc: + raise to_dataapi_timeout_exception(texc) + + return _wrapped_async + + def base_timeout_info(max_time_ms: Optional[int]) -> Union[TimeoutInfo, None]: if max_time_ms is not None: return {"base": max_time_ms / 1000.0} diff --git a/tests/idiomatic/integration/test_admin.py b/tests/idiomatic/integration/test_admin.py index 70e46944..5b09c9d2 100644 --- a/tests/idiomatic/integration/test_admin.py +++ b/tests/idiomatic/integration/test_admin.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Awaitable, Callable, List, Optional, Tuple import pytest @@ -72,11 +72,21 @@ def wait_until_true( time.sleep(poll_interval) +async def await_until_true( + poll_interval: int, max_seconds: int, acondition: Callable[..., Awaitable[bool]] +) -> None: + ini_time = time.time() + while not (await acondition()): + if time.time() - ini_time > max_seconds: + raise ValueError("Timed out on condition.") + time.sleep(poll_interval) + + @pytest.mark.skipif(not DO_IDIOMATIC_ADMIN_TESTS, reason="Admin tests are suppressed") class TestAdmin: @pytest.mark.parametrize("env_token", admin_test_envs_tokens()) - @pytest.mark.describe("test of the full tour with AstraDBDatabaseAdmin") - def test_astra_db_database_admin(self, env_token: Tuple[str, str]) -> None: + @pytest.mark.describe("test of the full tour with AstraDBDatabaseAdmin, sync") + def test_astra_db_database_admin_sync(self, env_token: Tuple[str, str]) -> None: """ Test plan (it has to be a single giant test to use one DB throughout): - create client -> get_admin @@ -187,8 +197,10 @@ def test_astra_db_database_admin(self, env_token: Tuple[str, str]) -> None: assert created_db_id not in db_ids @pytest.mark.parametrize("env_token", admin_test_envs_tokens()) - @pytest.mark.describe("test of the full tour with AstraDBAdmin and client methods") - def test_astra_db_admin(self, env_token: Tuple[str, str]) -> None: + @pytest.mark.describe( + "test of the full tour with AstraDBAdmin and client methods, sync" + ) + def test_astra_db_admin_sync(self, env_token: Tuple[str, str]) -> None: """ Test plan (it has to be a single giant test to use the two DBs throughout): - create client -> get_admin @@ -307,3 +319,263 @@ def _waiter2() -> bool: max_seconds=DATABASE_TIMEOUT, condition=_waiter2, ) + + @pytest.mark.parametrize("env_token", admin_test_envs_tokens()) + @pytest.mark.describe("test of the full tour with AstraDBDatabaseAdmin, async") + async def test_astra_db_database_admin_async( + self, env_token: Tuple[str, str] + ) -> None: + """ + Test plan (it has to be a single giant test to use one DB throughout): + - create client -> get_admin + - create a db (wait) + - with the AstraDBDatabaseAdmin: + - info + - list namespaces, check + - create 2 namespaces (wait, nonwait) + - list namespaces, check + - get_database -> create_collection/list_collection_names + - get_async_database, check if == previous + - drop namespaces (wait, nonwait) + - list namespaces, check + - drop database (wait) + - check DB not existings + """ + env, token = env_token + db_name = f"test_database_{env}" + db_provider = os.environ[f"{env.upper()}_ADMIN_TEST_ASTRA_DB_PROVIDER"] + db_region = os.environ[f"{env.upper()}_ADMIN_TEST_ASTRA_DB_REGION"] + + # create client, get admin + client: DataAPIClient + if env == "prod": + client = DataAPIClient(token) + else: + client = DataAPIClient(token, environment=env) + admin = client.get_admin() + + # create a db (wait) + db_admin = await admin.async_create_database( + name=db_name, + namespace="custom_namespace", + wait_until_active=True, + cloud_provider=db_provider, + region=db_region, + ) + + # info with the AstraDBDatabaseAdmin + created_db_id = db_admin.id + assert (await db_admin.async_info()).id == created_db_id + + # list nss + namespaces1 = set(await db_admin.async_list_namespaces()) + assert namespaces1 == {"custom_namespace"} + + # create two namespaces + w_create_ns_response = await db_admin.async_create_namespace( + "waited_ns", + wait_until_active=True, + ) + assert w_create_ns_response == {"ok": 1} + + nw_create_ns_response = await db_admin.async_create_namespace( + "nonwaited_ns", + wait_until_active=False, + ) + assert nw_create_ns_response == {"ok": 1} + + async def _awaiter1() -> bool: + return "nonwaited_ns" in (await db_admin.async_list_namespaces()) + + await await_until_true( + poll_interval=NAMESPACE_POLL_SLEEP_TIME, + max_seconds=NAMESPACE_TIMEOUT, + acondition=_awaiter1, + ) + + namespaces3 = set(await db_admin.async_list_namespaces()) + assert namespaces3 - namespaces1 == {"waited_ns", "nonwaited_ns"} + + # get db and use it + adb = db_admin.get_async_database() + await adb.create_collection("canary_coll") + assert "canary_coll" in (await adb.list_collection_names()) + + # check sync db is the same + assert db_admin.get_database().to_async() == adb + + # drop nss, wait, nonwait + w_drop_ns_response = await db_admin.async_drop_namespace( + "waited_ns", + wait_until_active=True, + ) + assert w_drop_ns_response == {"ok": 1} + + nw_drop_ns_response = await db_admin.async_drop_namespace( + "nonwaited_ns", + wait_until_active=False, + ) + assert nw_drop_ns_response == {"ok": 1} + + async def _awaiter2() -> bool: + ns_list = await db_admin.async_list_namespaces() + return "nonwaited_ns" not in ns_list + + await await_until_true( + poll_interval=NAMESPACE_POLL_SLEEP_TIME, + max_seconds=NAMESPACE_TIMEOUT, + acondition=_awaiter2, + ) + + # check nss after dropping two of them + namespaces1b = set(await db_admin.async_list_namespaces()) + assert namespaces1b == namespaces1 + + async def _awaiter3() -> bool: + a_info = await db_admin.async_info() + return a_info.status == "ACTIVE" # type: ignore[no-any-return] + + # drop db and check. We wait a little due to "nontransactional cluster md" + await await_until_true( + poll_interval=PRE_DROP_SAFETY_POLL_INTERVAL, + max_seconds=PRE_DROP_SAFETY_TIMEOUT, + acondition=_awaiter3, + ) + db_drop_response = await db_admin.async_drop() + assert db_drop_response == {"ok": 1} + + db_ids = {db.id for db in (await admin.async_list_databases())} + assert created_db_id not in db_ids + + @pytest.mark.parametrize("env_token", admin_test_envs_tokens()) + @pytest.mark.describe( + "test of the full tour with AstraDBAdmin and client methods, async" + ) + async def test_astra_db_admin_async(self, env_token: Tuple[str, str]) -> None: + """ + Test plan (it has to be a single giant test to use the two DBs throughout): + - create client -> get_admin + - create two dbs (wait, nonwait) + - list the two database ids, check + - get check info on one such db through admin + - with the client: + 4 get_dbs (a/sync, by id+region/api_endpoint), check if equal + and test their list_collections + - get_db_admin from the admin for one of the dbs + - create ns + - get_database -> list_collection_names + - get_async_database and check == with above + - drop dbs, (wait, nonwait) + """ + env, token = env_token + db_name_w = f"test_database_w_{env}" + db_name_nw = f"test_database_nw_{env}" + db_provider = os.environ[f"{env.upper()}_ADMIN_TEST_ASTRA_DB_PROVIDER"] + db_region = os.environ[f"{env.upper()}_ADMIN_TEST_ASTRA_DB_REGION"] + + # create client and get admin + client: DataAPIClient + if env == "prod": + client = DataAPIClient(token) + else: + client = DataAPIClient(token, environment=env) + admin = client.get_admin() + + # create the two dbs + db_admin_nw = await admin.async_create_database( + name=db_name_nw, + wait_until_active=False, + cloud_provider=db_provider, + region=db_region, + ) + created_db_id_nw = db_admin_nw.id + db_admin_w = await admin.async_create_database( + name=db_name_w, + wait_until_active=True, + cloud_provider=db_provider, + region=db_region, + ) + created_db_id_w = db_admin_w.id + + async def _awaiter1() -> bool: + db_ids = {db.id for db in (await admin.async_list_databases())} + return created_db_id_nw in db_ids + + await await_until_true( + poll_interval=DATABASE_POLL_SLEEP_TIME, + max_seconds=DATABASE_TIMEOUT, + acondition=_awaiter1, + ) + + # list, check ids + db_ids = {db.id for db in (await admin.async_list_databases())} + assert {created_db_id_nw, created_db_id_w} - db_ids == set() + + # get info through admin + db_w_info = await admin.async_database_info(created_db_id_w) + assert db_w_info.id == created_db_id_w + + # get and compare dbs obtained by the client + synthetic_api_endpoint = API_ENDPOINT_TEMPLATE_MAP[env].format( + database_id=created_db_id_w, + region=db_region, + ) + adb_w_d = client.get_async_database(created_db_id_w) + adb_w_r = client.get_async_database(created_db_id_w, region=db_region) + adb_w_e = client.get_async_database_by_api_endpoint(synthetic_api_endpoint) + db_w_d = client.get_database(created_db_id_w) + db_w_r = client.get_database(created_db_id_w, region=db_region) + db_w_e = client.get_database_by_api_endpoint(synthetic_api_endpoint) + assert isinstance(await adb_w_d.list_collection_names(), list) + assert adb_w_r == adb_w_d + assert adb_w_e == adb_w_d + assert db_w_d.to_async() == adb_w_d + assert db_w_r.to_async() == adb_w_d + assert db_w_e.to_async() == adb_w_d + + # get db admin from the admin and use it + db_w_admin = admin.get_database_admin(created_db_id_w) + await db_w_admin.async_create_namespace("additional_namespace") + adb_w_from_admin = db_w_admin.get_async_database() + assert isinstance(await adb_w_from_admin.list_collection_names(), list) + db_w_from_admin = db_w_admin.get_database() + assert db_w_from_admin.to_async() == adb_w_from_admin + + # drop databases: the w one through the admin, the nw using its db-admin + # (this covers most cases if combined with the + # (w, using db-admin) of test_astra_db_database_admin) + assert db_admin_nw == admin.get_database_admin(created_db_id_nw) + # drop db and check. We wait a little due to "nontransactional cluster md" + + async def _awaiter2() -> bool: + a_info = await db_admin_nw.async_info() + return a_info.status == "ACTIVE" # type: ignore[no-any-return] + + async def _awaiter3() -> bool: + a_info = await db_admin_w.async_info() + return a_info.status == "ACTIVE" # type: ignore[no-any-return] + + await await_until_true( + poll_interval=PRE_DROP_SAFETY_POLL_INTERVAL, + max_seconds=PRE_DROP_SAFETY_TIMEOUT, + acondition=_awaiter2, + ) + await await_until_true( + poll_interval=PRE_DROP_SAFETY_POLL_INTERVAL, + max_seconds=PRE_DROP_SAFETY_TIMEOUT, + acondition=_awaiter3, + ) + drop_nw_response = await db_admin_nw.async_drop(wait_until_active=False) + assert drop_nw_response == {"ok": 1} + drop_w_response = await admin.async_drop_database(created_db_id_w) + assert drop_w_response == {"ok": 1} + + async def _awaiter4() -> bool: + db_ids = {db.id for db in (await admin.async_list_databases())} + return created_db_id_nw not in db_ids + + await await_until_true( + poll_interval=DATABASE_POLL_SLEEP_TIME, + max_seconds=DATABASE_TIMEOUT, + acondition=_awaiter4, + )