From 07cbba64e3ea82c169bfaa02b3a92e91207919b1 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 26 Jun 2024 13:39:53 +0800 Subject: [PATCH] [SPARK-48706][PYTHON] Python UDF in higher order functions should not throw internal error ### What changes were proposed in this pull request? This PR fixes the error messages and classes when Python UDFs are used in higher order functions. ### Why are the changes needed? To show the proper user-facing exceptions with error classes. ### Does this PR introduce _any_ user-facing change? Yes, previously it threw internal error such as: ```python from pyspark.sql.functions import transform, udf, col, array spark.range(1).select(transform(array("id"), lambda x: udf(lambda y: y)(x))).collect() ``` Before: ``` py4j.protocol.Py4JJavaError: An error occurred while calling o74.collectToPython. : org.apache.spark.SparkException: Job aborted due to stage failure: Task 15 in stage 0.0 failed 1 times, most recent failure: Lost task 15.0 in stage 0.0 (TID 15) (ip-192-168-123-103.ap-northeast-2.compute.internal executor driver): org.apache.spark.SparkException: [INTERNAL_ERROR] Cannot evaluate expression: (lambda x_0#3L)#2 SQLSTATE: XX000 at org.apache.spark.SparkException$.internalError(SparkException.scala:92) at org.apache.spark.SparkException$.internalError(SparkException.scala:96) ``` After: ``` pyspark.errors.exceptions.captured.AnalysisException: [INVALID_LAMBDA_FUNCTION_CALL.UNEVALUABLE] Invalid lambda function call. Python UDFs should be used in a lambda function at a higher order function. However, "(lambda x_0#3L)" was a Python UDF. SQLSTATE: 42K0D; Project [transform(array(id#0L), lambdafunction((lambda x_0#3L)#2, lambda x_0#3L, false)) AS transform(array(id), lambdafunction((lambda x_0#3L), namedlambdavariable()))#4] +- Range (0, 1, step=1, splits=Some(16)) ``` ### How was this patch tested? Unittest was added ### Was this patch authored or co-authored using generative AI tooling? No. 
Closes #47079 from HyukjinKwon/SPARK-48706. Authored-by: Hyukjin Kwon Signed-off-by: Kent Yao --- .../main/resources/error/error-conditions.json | 5 +++++ .../sql/catalyst/analysis/CheckAnalysis.scala | 8 ++++++++ .../sql/execution/python/PythonUDFSuite.scala | 16 ++++++++++++++-- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index bf251f057af59..72f358f87d624 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -4482,6 +4482,11 @@ "INSERT INTO with IF NOT EXISTS in the PARTITION spec." ] }, + "LAMBDA_FUNCTION_WITH_PYTHON_UDF" : { + "message" : [ + "Lambda function with Python UDF <funcName> in a higher order function." + ] + }, "LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC" : { "message" : [ "Referencing a lateral column alias <lca> in the aggregate function <aggFunc>." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index bd8f8fe9f6528..9f3eee5198a16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -254,6 +254,14 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB hof.invalidFormat(checkRes) } + case hof: HigherOrderFunction + if hof.resolved && hof.functions + .exists(_.exists(_.isInstanceOf[PythonUDF])) => + val u = hof.functions.flatMap(_.find(_.isInstanceOf[PythonUDF])).head + hof.failAnalysis( + errorClass = "UNSUPPORTED_FEATURE.LAMBDA_FUNCTION_WITH_PYTHON_UDF", + messageParameters = Map("funcName" -> toSQLExpr(u))) + // If an attribute can't be resolved as a map key of string type, either the key should be + // surrounded with single quotes, or there is 
a typo in the attribute name. case GetMapValue(map, key: Attribute) if isMapWithStringKey(map) && !key.resolved => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala index 3101281251b1b..2e56ad0ab4160 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.execution.python -import org.apache.spark.sql.{IntegratedUDFTestUtils, QueryTest} -import org.apache.spark.sql.functions.count +import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest} +import org.apache.spark.sql.functions.{array, count, transform} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.LongType @@ -112,4 +112,16 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { val pandasTestUDF = TestGroupedAggPandasUDF(name = udfName) assert(df.agg(pandasTestUDF(df("id"))).schema.fieldNames.exists(_.startsWith(udfName))) } + + test("SPARK-48706: Negative test case for Python UDF in higher order functions") { + assume(shouldTestPythonUDFs) + checkError( + exception = intercept[AnalysisException] { + spark.range(1).select(transform(array("id"), x => pythonTestUDF(x))).collect() + }, + errorClass = "UNSUPPORTED_FEATURE.LAMBDA_FUNCTION_WITH_PYTHON_UDF", + parameters = Map("funcName" -> "\"pyUDF(namedlambdavariable())\""), + context = ExpectedContext( + "transform", s".*${this.getClass.getSimpleName}.*")) + } }