From 682570a3bdd288bb6c0cbd4cb2e3d2b3c1566102 Mon Sep 17 00:00:00 2001
From: Beto Dealmeida
Date: Thu, 21 Nov 2024 16:43:25 -0500
Subject: [PATCH 1/2] fix(dataset): use sqlglot for DML check

---
 superset/connectors/sqla/utils.py            | 12 ++---
 .../unit_tests/connectors/sqla/utils_test.py | 50 ++++++++++++++++++-
 2 files changed, 55 insertions(+), 7 deletions(-)

diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 18e5de48c10d6..84a6753f22861 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -38,7 +38,8 @@
 )
 from superset.models.core import Database
 from superset.result_set import SupersetResultSet
-from superset.sql_parse import ParsedQuery, Table
+from superset.sql.parse import SQLScript
+from superset.sql_parse import Table
 from superset.superset_typing import ResultSetColumnType
 
 if TYPE_CHECKING:
@@ -105,8 +106,8 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
     sql = dataset.get_template_processor().process_template(
         dataset.sql, **dataset.template_params_dict
     )
-    parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine)
-    if not db_engine_spec.is_readonly_query(parsed_query):
+    parsed_script = SQLScript(sql, engine=db_engine_spec.engine)
+    if parsed_script.has_mutation():
         raise SupersetSecurityException(
             SupersetError(
                 error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
@@ -114,8 +115,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
                 level=ErrorLevel.ERROR,
             )
         )
-    statements = parsed_query.get_statements()
-    if len(statements) > 1:
+    if len(parsed_script.statements) > 1:
         raise SupersetSecurityException(
             SupersetError(
                 error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
@@ -127,7 +127,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
         dataset.database,
         dataset.catalog,
         dataset.schema,
-        statements[0],
+        sql,
     )
 
 
diff --git a/tests/unit_tests/connectors/sqla/utils_test.py b/tests/unit_tests/connectors/sqla/utils_test.py
index 75d5a1fe32914..0da3ab7e95a9d 100644
--- a/tests/unit_tests/connectors/sqla/utils_test.py
+++ b/tests/unit_tests/connectors/sqla/utils_test.py
@@ -15,9 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pytest
 from pytest_mock import MockerFixture
 
-from superset.connectors.sqla.utils import get_columns_description
+from superset.connectors.sqla.utils import (
+    get_columns_description,
+    get_virtual_table_metadata,
+)
+from superset.exceptions import SupersetSecurityException
 
 
 # Returns column descriptions when given valid database, catalog, schema, and query
@@ -89,3 +94,46 @@ def test_returns_column_descriptions(mocker: MockerFixture) -> None:
             "is_dttm": False,
         },
     ]
+
+
+def test_get_virtual_table_metadata(mocker: MockerFixture) -> None:
+    """
+    Test the `get_virtual_table_metadata` function.
+    """
+    mocker.patch(
+        "superset.connectors.sqla.utils.get_columns_description",
+        return_value=[{"name": "one", "type": "INTEGER"}],
+    )
+    dataset = mocker.MagicMock(
+        sql="with source as ( select 1 as one ) select * from source",
+    )
+    dataset.database.db_engine_spec.engine = "postgresql"
+    dataset.get_template_processor().process_template.return_value = dataset.sql
+
+    assert get_virtual_table_metadata(dataset) == [{"name": "one", "type": "INTEGER"}]
+
+
+def test_get_virtual_table_metadata_mutating(mocker: MockerFixture) -> None:
+    """
+    Test the `get_virtual_table_metadata` function with mutating SQL.
+    """
+    dataset = mocker.MagicMock(sql="DROP TABLE sample_data")
+    dataset.database.db_engine_spec.engine = "postgresql"
+    dataset.get_template_processor().process_template.return_value = dataset.sql
+
+    with pytest.raises(SupersetSecurityException) as excinfo:
+        get_virtual_table_metadata(dataset)
+    assert str(excinfo.value) == "Only `SELECT` statements are allowed"
+
+
+def test_get_virtual_table_metadata_multiple(mocker: MockerFixture) -> None:
+    """
+    Test the `get_virtual_table_metadata` function with multiple statements.
+    """
+    dataset = mocker.MagicMock(sql="SELECT 1; SELECT 2")
+    dataset.database.db_engine_spec.engine = "postgresql"
+    dataset.get_template_processor().process_template.return_value = dataset.sql
+
+    with pytest.raises(SupersetSecurityException) as excinfo:
+        get_virtual_table_metadata(dataset)
+    assert str(excinfo.value) == "Only single queries supported"

From a8306da41f9737b3ec5936a1cb678688e374b7df Mon Sep 17 00:00:00 2001
From: Beto Dealmeida
Date: Thu, 21 Nov 2024 18:52:15 -0500
Subject: [PATCH 2/2] Fix test

---
 tests/integration_tests/charts/api_tests.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py
index 784c4651ada29..fc00dd3bc2955 100644
--- a/tests/integration_tests/charts/api_tests.py
+++ b/tests/integration_tests/charts/api_tests.py
@@ -1964,6 +1964,7 @@ def test_gets_owned_created_favorited_by_me_filter(self):
         assert rv.status_code == 200
 
         data = json.loads(rv.data.decode("utf-8"))
+        data["result"].sort(key=lambda x: x["datasource_id"])
         assert data["result"][0]["slice_name"] == "name0"
         assert data["result"][0]["datasource_id"] == 1
 