Skip to content

Commit

Permalink
fix: add mutator to get_columns_description (#29885)
Browse files Browse the repository at this point in the history
  • Loading branch information
eschutho authored Aug 9, 2024
1 parent fb6efb9 commit 38d64e8
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 61 deletions.
5 changes: 3 additions & 2 deletions superset/connectors/sqla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,9 @@ def get_columns_description(
with database.get_raw_connection(catalog=catalog, schema=schema) as conn:
cursor = conn.cursor()
query = database.apply_limit_to_sql(query, limit=1)
cursor.execute(query)
db_engine_spec.execute(cursor, query, database)
mutated_query = database.mutate_sql_based_on_config(query)
cursor.execute(mutated_query)
db_engine_spec.execute(cursor, mutated_query, database)
result = db_engine_spec.fetch_data(cursor, limit=1)
result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
return result_set.columns
Expand Down
127 changes: 68 additions & 59 deletions tests/integration_tests/datasource_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,24 @@ def create_test_table_context(database: Database):
engine.execute(f"DROP TABLE {full_table_name}")


@contextmanager
def create_and_cleanup_table(table=None):
    """Persist a ``SqlaTable`` for the duration of a ``with`` block.

    If ``table`` is not supplied, a table named ``dummy_sql_table`` backed by
    the SQL ``select 123 as intcol, 'abc' as strcol`` is built on the example
    database and default schema. The table is added and committed before it
    is yielded, and it is deleted (followed by a commit) when the block exits,
    whether or not the block raised.
    """
    target = (
        SqlaTable(
            table_name="dummy_sql_table",
            database=get_example_database(),
            schema=get_example_default_schema(),
            sql="select 123 as intcol, 'abc' as strcol",
        )
        if table is None
        else table
    )
    session = db.session
    session.add(target)
    session.commit()
    try:
        yield target
    finally:
        # Always remove the row so later tests start from a clean state
        session.delete(target)
        session.commit()


class TestDatasource(SupersetTestCase):
def setUp(self):
db.session.begin(subtransactions=True)
Expand Down Expand Up @@ -123,37 +141,22 @@ def test_always_filter_main_dttm(self):
sql=sql,
)

db.session.add(table)
db.session.commit()
with create_and_cleanup_table(table):
table.always_filter_main_dttm = False
result = str(table.get_sqla_query(**query_obj).sqla_query.whereclause)
assert "default_dttm" not in result and "additional_dttm" in result

table.always_filter_main_dttm = False
result = str(table.get_sqla_query(**query_obj).sqla_query.whereclause)
assert "default_dttm" not in result and "additional_dttm" in result

table.always_filter_main_dttm = True
result = str(table.get_sqla_query(**query_obj).sqla_query.whereclause)
assert "default_dttm" in result and "additional_dttm" in result

db.session.delete(table)
db.session.commit()
table.always_filter_main_dttm = True
result = str(table.get_sqla_query(**query_obj).sqla_query.whereclause)
assert "default_dttm" in result and "additional_dttm" in result

def test_external_metadata_for_virtual_table(self):
self.login(ADMIN_USERNAME)
table = SqlaTable(
table_name="dummy_sql_table",
database=get_example_database(),
schema=get_example_default_schema(),
sql="select 123 as intcol, 'abc' as strcol",
)
db.session.add(table)
db.session.commit()

table = self.get_table(name="dummy_sql_table")
url = f"/datasource/external_metadata/table/{table.id}/"
resp = self.get_json_resp(url)
assert {o.get("column_name") for o in resp} == {"intcol", "strcol"}
db.session.delete(table)
db.session.commit()
with create_and_cleanup_table() as table:
url = f"/datasource/external_metadata/table/{table.id}/"
resp = self.get_json_resp(url)
assert {o.get("column_name") for o in resp} == {"intcol", "strcol"}

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_external_metadata_by_name_for_physical_table(self):
Expand All @@ -178,31 +181,42 @@ def test_external_metadata_by_name_for_physical_table(self):

def test_external_metadata_by_name_for_virtual_table(self):
self.login(ADMIN_USERNAME)
table = SqlaTable(
table_name="dummy_sql_table",
database=get_example_database(),
schema=get_example_default_schema(),
sql="select 123 as intcol, 'abc' as strcol",
)
db.session.add(table)
db.session.commit()
with create_and_cleanup_table() as tbl:
params = prison.dumps(
{
"datasource_type": "table",
"database_name": tbl.database.database_name,
"schema_name": tbl.schema,
"table_name": tbl.table_name,
"normalize_columns": tbl.normalize_columns,
"always_filter_main_dttm": tbl.always_filter_main_dttm,
}
)
url = f"/datasource/external_metadata_by_name/?q={params}"
resp = self.get_json_resp(url)
assert {o.get("column_name") for o in resp} == {"intcol", "strcol"}

tbl = self.get_table(name="dummy_sql_table")
params = prison.dumps(
{
"datasource_type": "table",
"database_name": tbl.database.database_name,
"schema_name": tbl.schema,
"table_name": tbl.table_name,
"normalize_columns": tbl.normalize_columns,
"always_filter_main_dttm": tbl.always_filter_main_dttm,
}
)
url = f"/datasource/external_metadata_by_name/?q={params}"
resp = self.get_json_resp(url)
assert {o.get("column_name") for o in resp} == {"intcol", "strcol"}
db.session.delete(tbl)
db.session.commit()
def test_external_metadata_by_name_for_virtual_table_uses_mutator(self):
    """Check that external metadata by name reflects ``SQL_QUERY_MUTATOR``.

    The mutator swaps the table's SQL for a query that exposes different
    column names, so the endpoint is expected to report the mutated columns
    (``intcol`` and ``mutated_strcol``) instead of the table's own.
    """
    self.login(ADMIN_USERNAME)
    # Remember whatever mutator is configured so it can be put back afterwards
    original_mutator = app.config.get("SQL_QUERY_MUTATOR")
    with create_and_cleanup_table() as tbl:
        app.config["SQL_QUERY_MUTATOR"] = (
            lambda sql, **kwargs: "SELECT 456 as intcol, 'def' as mutated_strcol"
        )
        try:
            params = prison.dumps(
                {
                    "datasource_type": "table",
                    "database_name": tbl.database.database_name,
                    "schema_name": tbl.schema,
                    "table_name": tbl.table_name,
                    "normalize_columns": tbl.normalize_columns,
                    "always_filter_main_dttm": tbl.always_filter_main_dttm,
                }
            )
            url = f"/datasource/external_metadata_by_name/?q={params}"
            resp = self.get_json_resp(url)
            assert {o.get("column_name") for o in resp} == {
                "intcol",
                "mutated_strcol",
            }
        finally:
            # Restore even when the request or assertion fails; otherwise the
            # overriding mutator would leak into every test that runs later.
            app.config["SQL_QUERY_MUTATOR"] = original_mutator

def test_external_metadata_by_name_from_sqla_inspector(self):
self.login(ADMIN_USERNAME)
Expand Down Expand Up @@ -278,15 +292,10 @@ def test_external_metadata_for_virtual_table_template_params(self):
sql="select {{ foo }} as intcol",
template_params=json.dumps({"foo": "123"}),
)
db.session.add(table)
db.session.commit()

table = self.get_table(name="dummy_sql_table_with_template_params")
url = f"/datasource/external_metadata/table/{table.id}/"
resp = self.get_json_resp(url)
assert {o.get("column_name") for o in resp} == {"intcol"}
db.session.delete(table)
db.session.commit()
with create_and_cleanup_table(table) as tbl:
url = f"/datasource/external_metadata/table/{tbl.id}/"
resp = self.get_json_resp(url)
assert {o.get("column_name") for o in resp} == {"intcol"}

def test_external_metadata_for_malicious_virtual_table(self):
self.login(ADMIN_USERNAME)
Expand Down
91 changes: 91 additions & 0 deletions tests/unit_tests/connectors/sqla/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from pytest_mock import MockerFixture

from superset.connectors.sqla.utils import get_columns_description


def test_returns_column_descriptions(mocker: MockerFixture) -> None:
    """
    Test that ``get_columns_description`` executes the limited, mutated query
    and returns one column description per entry in the cursor description.
    """
    database = mocker.MagicMock()
    cursor = mocker.MagicMock()
    db_engine_spec = mocker.MagicMock()

    original_query = "SELECT * FROM table"
    limited_query = "SELECT * FROM table LIMIT 1"
    # Deliberately different from the limited query so the assertions below
    # can tell which statement was actually sent to the cursor.
    mutated_query = "-- mutated\nSELECT * FROM table LIMIT 1"

    cursor_description = (
        ("foo", "string"),
        ("bar", "string"),
        ("baz", "string"),
        ("type_generic", "string"),
        ("is_dttm", "boolean"),
    )
    cursor.description = cursor_description

    database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value = cursor
    database.db_engine_spec = db_engine_spec
    database.apply_limit_to_sql.return_value = limited_query
    database.mutate_sql_based_on_config.return_value = mutated_query
    db_engine_spec.fetch_data.return_value = [("col1", "col1", "STRING", None, False)]
    db_engine_spec.get_datatype.return_value = "STRING"
    db_engine_spec.get_column_spec.return_value.is_dttm = False
    db_engine_spec.get_column_spec.return_value.generic_type = "STRING"

    columns = get_columns_description(database, "catalog", "schema", original_query)

    # The limit is applied first, then the mutator sees the limited query,
    # and only the mutated statement is executed.
    database.apply_limit_to_sql.assert_called_once_with(original_query, limit=1)
    database.mutate_sql_based_on_config.assert_called_once_with(limited_query)
    cursor.execute.assert_called_once_with(mutated_query)
    db_engine_spec.execute.assert_called_once_with(cursor, mutated_query, database)

    assert columns == [
        {
            "column_name": name,
            "name": name,
            "type": "STRING",
            "type_generic": "STRING",
            "is_dttm": False,
        }
        for name, _ in cursor_description
    ]

0 comments on commit 38d64e8

Please sign in to comment.