From 1e842b452087d24c18f7af5b1e97cd4aa081975b Mon Sep 17 00:00:00 2001
From: MithunR
Date: Tue, 26 Nov 2024 00:15:03 +0000
Subject: [PATCH] Fix `dpp_test.py` failures on [databricks] 14.3

Fixes #11536.

This commit fixes the tests in `dpp_test.py` that were failing on
Databricks 14.3. The failures were largely the result of an erroneous
shim implementation, which was fixed as part of #11750.

This commit accounts for the remaining failures, which result from a
`CollectLimitExec` appearing in certain DPP query plans (those that
include broadcast joins, for example). The tests have been made more
permissive by allowing the `CollectLimitExec` to run on the CPU.

The `CollectLimitExec`-based plans will be explored further as part of
https://github.com/NVIDIA/spark-rapids/issues/11764.

Signed-off-by: MithunR
---
 integration_tests/src/main/python/dpp_test.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)
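The core of the change is a version-gated fallback allow-list. A minimal
standalone sketch of the pattern (illustrative only, not part of the patch;
`test_example` is a made-up name, and the imports assume the integration-test
environment):

    from marks import ignore_order, allow_non_gpu
    from spark_session import is_databricks_version_or_later

    # The allow-list names CollectLimitExec only on Databricks >= 14.3, where
    # DPP plans are known to contain it; elsewhere the list stays empty, so
    # any CPU fallback still fails the test.
    dpp_fallback_execs = ["CollectLimitExec"] if is_databricks_version_or_later(14, 3) else []

    @ignore_order
    @allow_non_gpu(*dpp_fallback_execs)
    def test_example():
        ...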
diff --git a/integration_tests/src/main/python/dpp_test.py b/integration_tests/src/main/python/dpp_test.py
index b362a4175f3..3d5ee1a5afa 100644
--- a/integration_tests/src/main/python/dpp_test.py
+++ b/integration_tests/src/main/python/dpp_test.py
@@ -20,7 +20,7 @@
 from conftest import spark_tmp_table_factory
 from data_gen import *
 from marks import ignore_order, allow_non_gpu, datagen_overrides, disable_ansi_mode
-from spark_session import is_before_spark_320, with_cpu_session, is_before_spark_312, is_databricks_runtime, is_databricks113_or_later
+from spark_session import is_before_spark_320, with_cpu_session, is_before_spark_312, is_databricks_runtime, is_databricks113_or_later, is_databricks_version_or_later
 
 # non-positive values here can produce a degenerative join, so here we ensure that most values are
 # positive to ensure the join will produce rows. See https://github.com/NVIDIA/spark-rapids/issues/10147
@@ -167,10 +167,17 @@ def fn(spark):
     '''
 ]
 
+# On some Databricks versions (>= 14.3), certain query plans include a `CollectLimitExec`
+# when filtering partitions. This exec falls back to the CPU. These tests allow
+# `CollectLimitExec` to run on the CPU, provided everything else in the plan
+# executes as expected. Further details: https://github.com/NVIDIA/spark-rapids/issues/11764.
+dpp_fallback_execs = ["CollectLimitExec"] if is_databricks_version_or_later(14, 3) else []
+
 @disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 # When BroadcastExchangeExec is available on filtering side, and it can be reused:
 # DynamicPruningExpression(InSubqueryExec(value, GpuSubqueryBroadcastExec)))
 @ignore_order
+@allow_non_gpu(*dpp_fallback_execs)
 @datagen_overrides(seed=0, reason="https://github.com/NVIDIA/spark-rapids/issues/10147")
 @pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn)
 @pytest.mark.parametrize('s_index', list(range(len(_statements))), ids=idfn)
@@ -245,6 +252,7 @@ def test_dpp_bypass(spark_tmp_table_factory, store_format, s_index, aqe_enabled)
 # then Spark will plan an extra Aggregate to collect filtering values:
 # DynamicPruningExpression(InSubqueryExec(value, SubqueryExec(Aggregate(...))))
 @ignore_order
+@allow_non_gpu(*dpp_fallback_execs)
 @pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn)
 @pytest.mark.parametrize('s_index', list(range(len(_statements))), ids=idfn)
 @pytest.mark.parametrize('aqe_enabled', [
@@ -285,10 +293,11 @@ def test_dpp_skip(spark_tmp_table_factory, store_format, s_index, aqe_enabled):
                            non_exist_classes='DynamicPruningExpression',
                            conf=dict(_dpp_fallback_conf + [('spark.sql.adaptive.enabled', aqe_enabled)]))
 
 
+dpp_like_any_fallback_execs = ['FilterExec', 'CollectLimitExec'] if is_databricks_version_or_later(14, 3) else ['FilterExec']
 # GPU verification on https://issues.apache.org/jira/browse/SPARK-34436
 @ignore_order
-@allow_non_gpu('FilterExec')
+@allow_non_gpu(*dpp_like_any_fallback_execs)
 @pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn)
 @pytest.mark.parametrize('aqe_enabled', [
     'false',
@@ -327,6 +336,7 @@ def create_dim_table_for_like(spark):
 
 
 @disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
+@allow_non_gpu(*dpp_fallback_execs)
 # Test handling DPP expressions from a HashedRelation that rearranges columns
 @pytest.mark.parametrize('aqe_enabled', [
     'false',
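For context, `is_databricks_version_or_later` can be thought of as comparing
the runtime's (major, minor) version against the requested one. A rough sketch,
under the assumption that the version is read from the DATABRICKS_RUNTIME_VERSION
environment variable (the real helper in spark_session.py may obtain it
differently):

    import os

    def is_databricks_version_or_later(major, minor):
        # On a Databricks cluster, DATABRICKS_RUNTIME_VERSION looks like "14.3";
        # it is unset elsewhere, in which case the check is False.
        version = os.environ.get('DATABRICKS_RUNTIME_VERSION')
        if version is None:
            return False
        runtime_major, runtime_minor = (int(part) for part in version.split('.')[:2])
        return (runtime_major, runtime_minor) >= (major, minor)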