Fix dpp_test.py failures on [databricks] 14.3 (#11768)
Fixes #11536.

This commit fixes the tests in `dpp_test.py` that were failing on
Databricks 14.3.

The failures were largely the result of an erroneous shim
implementation, which was fixed as part of #11750.

This commit addresses the remaining failures, which result from a
`CollectLimitExec` appearing in certain DPP query plans (those that
include broadcast joins, for example). The tests have been made more
permissive, allowing `CollectLimitExec` to fall back to the CPU.

The `CollectLimitExec` based plans will be further explored as part of
#11764.
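
The permissive pattern boils down to a fallback list that is non-empty
only on Databricks >= 14.3, unpacked into the `allow_non_gpu` mark. A
minimal sketch, using the helper names from the diff below (the test
body is a hypothetical stand-in):

```python
from marks import allow_non_gpu, ignore_order
from spark_session import is_databricks_version_or_later

# Only Databricks >= 14.3 plans a CollectLimitExec while filtering
# partitions, so only there is it allowed to fall back to the CPU;
# elsewhere the list is empty and the test stays strict.
dpp_fallback_execs = ["CollectLimitExec"] if is_databricks_version_or_later(14, 3) else []

@ignore_order
@allow_non_gpu(*dpp_fallback_execs)  # unpacks to a no-op when the list is empty
def test_dpp_example():
    ...  # hypothetical body; all other execs must still run on the GPU
```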

Signed-off-by: MithunR <[email protected]>
mythrocks authored Nov 26, 2024
1 parent ff0ca0f commit ed02cfe
Showing 1 changed file with 12 additions and 2 deletions.
integration_tests/src/main/python/dpp_test.py
@@ -20,7 +20,7 @@
 from conftest import spark_tmp_table_factory
 from data_gen import *
 from marks import ignore_order, allow_non_gpu, datagen_overrides, disable_ansi_mode
-from spark_session import is_before_spark_320, with_cpu_session, is_before_spark_312, is_databricks_runtime, is_databricks113_or_later
+from spark_session import is_before_spark_320, with_cpu_session, is_before_spark_312, is_databricks_runtime, is_databricks113_or_later, is_databricks_version_or_later
 
 # non-positive values here can produce a degenerative join, so here we ensure that most values are
 # positive to ensure the join will produce rows. See https://github.com/NVIDIA/spark-rapids/issues/10147
@@ -167,10 +167,17 @@ def fn(spark):
 '''
 ]
 
+# On some Databricks versions (>= 14.3), certain query plans include a `CollectLimitExec`
+# when filtering partitions. This exec falls back to the CPU. These tests allow
+# `CollectLimit` to run on the CPU, provided everything else in the plan executes as expected.
+# Further details are at https://github.com/NVIDIA/spark-rapids/issues/11764.
+dpp_fallback_execs = ["CollectLimitExec"] if is_databricks_version_or_later(14, 3) else []
+
 @disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 # When BroadcastExchangeExec is available on filtering side, and it can be reused:
 # DynamicPruningExpression(InSubqueryExec(value, GpuSubqueryBroadcastExec)))
 @ignore_order
+@allow_non_gpu(*dpp_fallback_execs)
 @datagen_overrides(seed=0, reason="https://github.com/NVIDIA/spark-rapids/issues/10147")
 @pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn)
 @pytest.mark.parametrize('s_index', list(range(len(_statements))), ids=idfn)
@@ -245,6 +252,7 @@ def test_dpp_bypass(spark_tmp_table_factory, store_format, s_index, aqe_enabled)
 # then Spark will plan an extra Aggregate to collect filtering values:
 # DynamicPruningExpression(InSubqueryExec(value, SubqueryExec(Aggregate(...))))
 @ignore_order
+@allow_non_gpu(*dpp_fallback_execs)
 @pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn)
 @pytest.mark.parametrize('s_index', list(range(len(_statements))), ids=idfn)
 @pytest.mark.parametrize('aqe_enabled', [
@@ -285,10 +293,11 @@ def test_dpp_skip(spark_tmp_table_factory, store_format, s_index, aqe_enabled):
 non_exist_classes='DynamicPruningExpression',
 conf=dict(_dpp_fallback_conf + [('spark.sql.adaptive.enabled', aqe_enabled)]))
 
+dpp_like_any_fallback_execs = ['FilterExec', 'CollectLimitExec'] if is_databricks_version_or_later(14, 3) else ['FilterExec']
 
 # GPU verification on https://issues.apache.org/jira/browse/SPARK-34436
 @ignore_order
-@allow_non_gpu('FilterExec')
+@allow_non_gpu(*dpp_like_any_fallback_execs)
 @pytest.mark.parametrize('store_format', ['parquet', 'orc'], ids=idfn)
 @pytest.mark.parametrize('aqe_enabled', [
 'false',
@@ -327,6 +336,7 @@ def create_dim_table_for_like(spark):
 
 
 @disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
+@allow_non_gpu(*dpp_fallback_execs)
 # Test handling DPP expressions from a HashedRelation that rearranges columns
 @pytest.mark.parametrize('aqe_enabled', [
 'false',
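For context, here is a hedged sketch of how a test in this file combines
the fallback list with the plan-capture assertion seen in the hunks
above. The `assert_cpu_and_gpu_are_equal_collect_with_capture` helper
and its `exist_classes` keyword are assumed from the integration-test
`asserts` module; the query and table names are hypothetical stand-ins:

```python
from asserts import assert_cpu_and_gpu_are_equal_collect_with_capture

@ignore_order
@allow_non_gpu(*dpp_fallback_execs)
def test_dpp_sketch(spark_tmp_table_factory):
    def query(spark):
        # Hypothetical fact/dim join whose partition filter should be
        # pruned dynamically via DPP.
        return spark.sql(
            "SELECT f.value FROM fact f JOIN dim d ON f.key = d.key WHERE d.filter = 3")

    # DynamicPruningExpression must appear in the captured GPU plan, while
    # any exec named in dpp_fallback_execs (CollectLimitExec on
    # Databricks >= 14.3) may legally run on the CPU.
    assert_cpu_and_gpu_are_equal_collect_with_capture(
        query,
        exist_classes='DynamicPruningExpression',
        conf={'spark.sql.adaptive.enabled': 'false'})
```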
