Skip to content

Commit

Permalink
alter where bool cast is
Browse files Browse the repository at this point in the history
  • Loading branch information
johncmerfeld committed Nov 19, 2024
1 parent 9ffaaf3 commit bdc26f8
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
7 changes: 3 additions & 4 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -1458,7 +1458,8 @@ def check_query_exists(query_stmt: Select, *, session: Session) -> bool:
:meta private:
"""
count_stmt = select(literal(True)).select_from(query_stmt.order_by(None).subquery())
return session.scalar(count_stmt)
# we must cast to bool because scalar() can return None
return bool(session.scalar(count_stmt))


def exists_query(*where: ClauseElement, session: Session) -> bool:
Expand Down Expand Up @@ -1557,9 +1558,7 @@ def __setstate__(self, state: Any) -> None:
self._session = get_current_task_instance_session()

def __bool__(self) -> bool:
if check := check_query_exists(self._select_asc, session=self._session) is not None:
return check
return False
return check_query_exists(self._select_asc, session=self._session)

def __eq__(self, other: Any) -> bool:
if not isinstance(other, collections.abc.Sequence):
Expand Down
14 changes: 14 additions & 0 deletions tests/utils/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from airflow.models import Base as airflow_base
from airflow.settings import engine
from airflow.utils.db import (
LazySelectSequence,
_get_alembic_config,
check_migrations,
compare_server_default,
Expand All @@ -56,6 +57,12 @@
pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode]


class EmptyLazySelectSequence(LazySelectSequence):
_data = []

def __init__(self):
super().__init__(None, None, session="MockSession")

class TestDb:
def test_database_schema_and_sqlalchemy_model_are_in_sync(self):
import airflow.models
Expand Down Expand Up @@ -251,3 +258,10 @@ def test_alembic_configuration(self):
import airflow

assert config.config_file_name == os.path.join(os.path.dirname(airflow.__file__), "alembic.ini")

@mock.patch("sqlalchemy.orm.Session.scalar")
def test_bool_lazy_select_sequence(self, mock_scalar):
mock_scalar.return_value = None

lss = EmptyLazySelectSequence()
assert not bool(lss)

0 comments on commit bdc26f8

Please sign in to comment.