From bdc26f8c49270065dc2079d5e4b7c7477371172b Mon Sep 17 00:00:00 2001 From: johncmerfeld Date: Tue, 19 Nov 2024 08:51:17 -0600 Subject: [PATCH] alter where bool cast is --- airflow/utils/db.py | 7 +++---- tests/utils/test_db.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 41e4e411e0b1a..c6b4a9edde19a 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -1458,7 +1458,8 @@ def check_query_exists(query_stmt: Select, *, session: Session) -> bool: :meta private: """ count_stmt = select(literal(True)).select_from(query_stmt.order_by(None).subquery()) - return session.scalar(count_stmt) + # we must cast to bool because scalar() can return None + return bool(session.scalar(count_stmt)) def exists_query(*where: ClauseElement, session: Session) -> bool: @@ -1557,9 +1558,7 @@ def __setstate__(self, state: Any) -> None: self._session = get_current_task_instance_session() def __bool__(self) -> bool: - if check := check_query_exists(self._select_asc, session=self._session) is not None: - return check - return False + return check_query_exists(self._select_asc, session=self._session) def __eq__(self, other: Any) -> bool: if not isinstance(other, collections.abc.Sequence): diff --git a/tests/utils/test_db.py b/tests/utils/test_db.py index 1d3412d361412..b05932ab689b2 100644 --- a/tests/utils/test_db.py +++ b/tests/utils/test_db.py @@ -37,6 +37,7 @@ from airflow.models import Base as airflow_base from airflow.settings import engine from airflow.utils.db import ( + LazySelectSequence, _get_alembic_config, check_migrations, compare_server_default, @@ -56,6 +57,12 @@ pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] +class EmptyLazySelectSequence(LazySelectSequence): + _data = [] + + def __init__(self): + super().__init__(None, None, session="MockSession") + class TestDb: def test_database_schema_and_sqlalchemy_model_are_in_sync(self): import airflow.models @@ -251,3 +258,10 @@ def test_alembic_configuration(self): import airflow assert config.config_file_name == os.path.join(os.path.dirname(airflow.__file__), "alembic.ini") + + @mock.patch("sqlalchemy.orm.Session.scalar") + def test_bool_lazy_select_sequence(self, mock_scalar): + mock_scalar.return_value = None + + lss = EmptyLazySelectSequence() + assert not bool(lss) \ No newline at end of file