Skip to content

Commit

Permalink
Add drop database command
Browse files Browse the repository at this point in the history
  • Loading branch information
DataDaoDe committed Jan 27, 2025
1 parent cf8428c commit 61fefca
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 12 deletions.
33 changes: 29 additions & 4 deletions advanced_alchemy/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,8 +402,17 @@ async def _dump_tables() -> None:
@database_group.command(name="create", help="Create a new database.")
@bind_key_option
@no_prompt_option
def create_database(bind_key: str | None, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction]
@click.option(
"--encoding",
"encoding",
help="Set the encoding for the created database",
type=str,
required=False,
default="utf8",
)
def create_database(bind_key: str | None, no_prompt: bool, encoding: str) -> None: # pyright: ignore[reportUnusedFunction]
from rich.prompt import Confirm

from advanced_alchemy.utils.databases import create_database as _create_database

ctx = click.get_current_context()
Expand All @@ -416,10 +425,26 @@ def create_database(bind_key: str | None, no_prompt: bool) -> None: # pyright:
True if no_prompt else Confirm.ask(f"[bold]Are you sure you want to create a new database `{db_name}`?[/]")
)
if input_confirmed:
_create_database(sqlalchemy_config)
_create_database(sqlalchemy_config, encoding)

@database_group.command(name="drop", help="Drop the current database.")
def drop_database() -> None: # pyright: ignore[reportUnusedFunction]
raise NotImplementedError
@bind_key_option
@no_prompt_option
def drop_database(bind_key: str | None, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction]
from rich.prompt import Confirm

from advanced_alchemy.utils.databases import drop_database as _drop_database

ctx = click.get_current_context()
sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)

db_name = sqlalchemy_config.get_engine().url.database

console.rule("[yellow]Starting database deletion process[/]", align="left")
input_confirmed = (
True if no_prompt else Confirm.ask(f"[bold]Are you sure you want to drop database `{db_name}`?[/]")
)
if input_confirmed:
_drop_database(sqlalchemy_config)

return database_group
113 changes: 105 additions & 8 deletions advanced_alchemy/utils/databases.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
from __future__ import annotations

import asyncio
import os
from copy import copy
from pathlib import Path
from typing import TYPE_CHECKING

from sqlalchemy import Engine, text
from sqlalchemy.ext.asyncio import AsyncEngine

from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig
from advanced_alchemy.config.sync import SQLAlchemySyncConfig
from advanced_alchemy.utils.dataclass import Empty

if TYPE_CHECKING:
from sqlalchemy.engine.url import URL

from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig
from advanced_alchemy.config.sync import SQLAlchemySyncConfig


def set_engine_database_url(url: URL, database: str | None) -> URL:
if hasattr(url, "_replace"):
Expand Down Expand Up @@ -44,14 +48,13 @@ def create_database(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig, encodi
config.engine_config.isolation_level = "AUTOCOMMIT"

engine = config.create_engine_callable(str(url))

if isinstance(engine, Engine):
create_sync_database(engine, database, encoding)
if isinstance(engine, AsyncEngine):
asyncio.run(create_database_async(engine, database, encoding))
else:
asyncio.run(create_async_database(engine, database, encoding))
create_database_sync(engine, database, encoding)


def create_sync_database(engine: Engine, database: str | None, encoding: str = "utf8") -> None:
def create_database_sync(engine: Engine, database: str | None, encoding: str = "utf8") -> None:
dialect_name = engine.url.get_dialect().name
if dialect_name == "postgresql":
with engine.begin() as conn:
Expand All @@ -77,7 +80,7 @@ def create_sync_database(engine: Engine, database: str | None, encoding: str = "
engine.dispose()


async def create_async_database(engine: AsyncEngine, database: str | None, encoding: str = "utf8") -> None:
async def create_database_async(engine: AsyncEngine, database: str | None, encoding: str = "utf8") -> None:
dialect_name = engine.url.get_dialect().name
if dialect_name == "postgresql":
async with engine.begin() as conn:
Expand All @@ -101,3 +104,97 @@ async def create_async_database(engine: AsyncEngine, database: str | None, encod
await conn.execute(text(sql))

await engine.dispose()


def drop_database(config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig) -> None:
url = config.get_engine().url
database = url.database
dialect_name = url.get_dialect().name
dialect_driver = url.get_dialect().driver

if dialect_name == "postgresql":
url = set_engine_database_url(url, database="postgres")
elif dialect_name == "mssql":
url = set_engine_database_url(url, database="master")
elif dialect_name == "cockroachdb":
url = set_engine_database_url(url, database="defaultdb")
elif dialect_name != "sqlite":
url = set_engine_database_url(url, database=None)

if dialect_name == "mssql" and dialect_driver in {"pymssql", "pyodbc"}:
if config.engine_config.connect_args is Empty:
config.engine_config.connect_args = {}
else:
config.engine_config.connect_args["autocommit"] = True # type: ignore # noqa: PGH003
elif dialect_name == "postgresql" and dialect_driver in {
"asyncpg",
"pg8000",
"psycopg",
"psycopg2",
"psycopg2cffi",
}:
config.engine_config.isolation_level = "AUTOCOMMIT"

engine = config.create_engine_callable(str(url))
if isinstance(engine, AsyncEngine):
asyncio.run(drop_database_async(engine, database))
else:
drop_database_sync(engine, database)


async def drop_database_async(engine: AsyncEngine, database: str | None) -> None:
dialect_name = engine.url.get_dialect().name
if dialect_name == "sqlite" and database != ":memory:":
if database:
Path(database).unlink()
elif dialect_name == "postgresql":
async with engine.begin() as conn:
# Disconnect all users from the database we are dropping.
version = conn.dialect.server_version_info
pid_column = ("pid" if version >= (9, 2) else "procpid") if version else "procpid"
sql = f"""
SELECT pg_terminate_backend(pg_stat_activity.{pid_column})
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '{database}'
AND {pid_column} <> pg_backend_pid();
""" # noqa: S608
await conn.execute(text(sql))

# Drop the database.
sql = f"DROP DATABASE {database}"
await conn.execute(text(sql))
else:
async with engine.begin() as conn:
sql = f"DROP DATABASE {database}"
await conn.execute(text(sql))

await engine.dispose()


def drop_database_sync(engine: Engine, database: str | None) -> None:
dialect_name = engine.url.get_dialect().name
if dialect_name == "sqlite" and database != ":memory:":
if database:
Path(database).unlink()
elif dialect_name == "postgresql":
with engine.begin() as conn:
# Disconnect all users from the database we are dropping.
version = conn.dialect.server_version_info
pid_column = ("pid" if version >= (9, 2) else "procpid") if version else "procpid"
sql = f"""
SELECT pg_terminate_backend(pg_stat_activity.{pid_column})
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '{database}'
AND {pid_column} <> pg_backend_pid();
""" # noqa: S608
conn.execute(text(sql))

# Drop the database.
sql = f"DROP DATABASE {database}"
conn.execute(text(sql))
else:
with engine.begin() as conn:
sql = f"DROP DATABASE {database}"
conn.execute(text(sql))

engine.dispose()

0 comments on commit 61fefca

Please sign in to comment.