From bbf0aaec5da77a7130495eab41fabfc8089ab1d2 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Fri, 18 Aug 2023 11:59:43 -0700 Subject: [PATCH] chore: rename `get_iterable` (#24994) --- superset/connectors/base/models.py | 4 ++-- superset/daos/base.py | 10 +++++----- superset/utils/core.py | 9 ++++++--- tests/integration_tests/utils_tests.py | 10 +++++----- 4 files changed, 18 insertions(+), 15 deletions(-) diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py index bd730c2406efd..d5386c7a66c3e 100644 --- a/superset/connectors/base/models.py +++ b/superset/connectors/base/models.py @@ -334,7 +334,7 @@ def data_for_slices( # pylint: disable=too-many-locals form_data = slc.form_data # pull out all required metrics from the form_data for metric_param in METRIC_FORM_DATA_PARAMS: - for metric in utils.get_iterable(form_data.get(metric_param) or []): + for metric in utils.as_list(form_data.get(metric_param) or []): metric_names.add(utils.get_metric_name(metric)) if utils.is_adhoc_metric(metric): column = metric.get("column") or {} @@ -377,7 +377,7 @@ def data_for_slices( # pylint: disable=too-many-locals if utils.is_adhoc_column(column) else column for column_param in COLUMN_FORM_DATA_PARAMS - for column in utils.get_iterable(form_data.get(column_param) or []) + for column in utils.as_list(form_data.get(column_param) or []) ] column_names.update(_columns) diff --git a/superset/daos/base.py b/superset/daos/base.py index c96275c414a62..a69f07da5dc89 100644 --- a/superset/daos/base.py +++ b/superset/daos/base.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import Any, Generic, get_args, TypeVar +from typing import Any, cast, Generic, get_args, TypeVar from flask_appbuilder.models.filters import BaseFilter from flask_appbuilder.models.sqla import Model @@ -30,7 +30,7 @@ DAOUpdateFailedError, ) from superset.extensions import db -from superset.utils.core import get_iterable +from superset.utils.core import as_list T = TypeVar("T", bound=Model) @@ -197,7 +197,7 @@ def update( return item # type: ignore @classmethod - def delete(cls, items: T | list[T], commit: bool = True) -> None: + def delete(cls, item_or_items: T | list[T], commit: bool = True) -> None: """ Delete the specified item(s) including their associated relationships. @@ -214,9 +214,9 @@ def delete(cls, items: T | list[T], commit: bool = True) -> None: :raises DAODeleteFailedError: If the deletion failed :see: https://docs.sqlalchemy.org/en/latest/orm/queryguide/dml.html """ - + items = cast(list[T], as_list(item_or_items)) try: - for item in get_iterable(items): + for item in items: db.session.delete(item) if commit: diff --git a/superset/utils/core.py b/superset/utils/core.py index 8b1cc1a4856b2..3e297b5c90273 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1578,12 +1578,15 @@ def split( yield string[i:] -def get_iterable(x: Any) -> list[Any]: +T = TypeVar("T") + + +def as_list(x: T | list[T]) -> list[T]: """ - Get an iterable (list) representation of the object. + Wrap an object in a list if it's not a list. :param x: The object - :returns: An iterable representation + :returns: A list wrapping the object if it's not already a list """ return x if isinstance(x, list) else [x] diff --git a/tests/integration_tests/utils_tests.py b/tests/integration_tests/utils_tests.py index 86d8bf6e68501..6648d72c61788 100644 --- a/tests/integration_tests/utils_tests.py +++ b/tests/integration_tests/utils_tests.py @@ -54,7 +54,7 @@ format_timedelta, GenericDataType, get_form_data_token, - get_iterable, + as_list, get_email_address_list, get_stacktrace, json_int_dttm_ser, @@ -749,10 +749,10 @@ def test_get_or_create_db_existing_invalid_uri(self): database = get_or_create_db("test_db", "sqlite:///superset.db") assert database.sqlalchemy_uri == "sqlite:///superset.db" - def test_get_iterable(self): - self.assertListEqual(get_iterable(123), [123]) - self.assertListEqual(get_iterable([123]), [123]) - self.assertListEqual(get_iterable("foo"), ["foo"]) + def test_as_list(self): + self.assertListEqual(as_list(123), [123]) + self.assertListEqual(as_list([123]), [123]) + self.assertListEqual(as_list("foo"), ["foo"]) @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") def test_build_extra_filters(self):