diff --git a/tests/sqlmodel/conftest.py b/tests/sqlmodel/conftest.py index aaa0120..63baef3 100644 --- a/tests/sqlmodel/conftest.py +++ b/tests/sqlmodel/conftest.py @@ -271,6 +271,7 @@ class TaskRead(TaskReadSub): assignee: Optional[UserReadSub] client: Optional[ClientRead] + def is_docker_running() -> bool: # pragma: no cover try: DockerClient() @@ -278,6 +279,7 @@ def is_docker_running() -> bool: # pragma: no cover except Exception: return False + async_engine = create_async_engine( "sqlite+aiosqlite:///:memory:", echo=True, future=True ) @@ -299,14 +301,16 @@ async def _setup_database(url: str) -> AsyncGenerator[AsyncSession]: @pytest_asyncio.fixture(scope="function") -async def async_session(request: pytest.FixtureRequest) -> AsyncGenerator[AsyncSession]: # pragma: no cover +async def async_session( + request: pytest.FixtureRequest, +) -> AsyncGenerator[AsyncSession]: # pragma: no cover dialect_marker = request.node.get_closest_marker("dialect") dialect = dialect_marker.args[0] if dialect_marker else "sqlite" - + if dialect == "postgresql" or dialect == "mysql": if not is_docker_running(): pytest.skip("Docker is required, but not running") - + if dialect == "postgresql": with PostgresContainer() as postgres: url = postgres.get_connection_url() @@ -314,7 +318,9 @@ async def async_session(request: pytest.FixtureRequest) -> AsyncGenerator[AsyncS yield session elif dialect == "mysql": with MySqlContainer() as mysql: - url = make_url(mysql.get_connection_url())._replace(drivername="mysql+aiomysql") + url = make_url(mysql.get_connection_url())._replace( + drivername="mysql+aiomysql" + ) async with _setup_database(url) as session: yield session else: