From 3860ac5d2313d438c25e5c46ed4e2b8b2b5227e3 Mon Sep 17 00:00:00 2001
From: Chao Sun
Date: Wed, 19 Jan 2022 14:07:30 -0800
Subject: [PATCH] [SPARK-37957][SQL] Correctly pass deterministic flag for V2
 scalar functions

### What changes were proposed in this pull request?

Pass the `isDeterministic` flag to `ApplyFunctionExpression`, `Invoke` and `StaticInvoke` when processing V2 scalar functions.

### Why are the changes needed?

A V2 scalar function can be declared as non-deterministic. However, Spark currently doesn't pass this flag when converting the function to a catalyst expression, which can lead to incorrect results when certain optimizations (such as constant folding) are applied to it.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added a unit test.

Closes #35243 from sunchao/SPARK-37957.

Authored-by: Chao Sun
Signed-off-by: Chao Sun
---
 .../sql/catalyst/analysis/Analyzer.scala      |   6 +-
 .../expressions/ApplyFunctionExpression.scala |   2 +
 .../expressions/objects/objects.scala         |  12 +-
 .../catalog/functions/JavaLongAdd.java        |   2 +-
 .../catalog/functions/JavaRandomAdd.java      | 110 ++++++++++++++++++
 .../catalog/functions/JavaStrLen.java         |   2 +-
 .../connector/DataSourceV2FunctionSuite.scala |  33 +++++-
 7 files changed, 158 insertions(+), 9 deletions(-)
 create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaRandomAdd.java

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 89c7b5f9c5d44..42bfa24698b34 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2299,12 +2299,14 @@ class Analyzer(override val catalogManager: CatalogManager)
           case Some(m) if Modifier.isStatic(m.getModifiers) =>
             StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
               MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes,
-              propagateNull = false, returnNullable = scalarFunc.isResultNullable)
+              propagateNull = false, returnNullable = scalarFunc.isResultNullable,
+              isDeterministic = scalarFunc.isDeterministic)
           case Some(_) =>
             val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
             Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
               arguments, methodInputTypes = declaredInputTypes, propagateNull = false,
-              returnNullable = scalarFunc.isResultNullable)
+              returnNullable = scalarFunc.isResultNullable,
+              isDeterministic = scalarFunc.isDeterministic)
           case _ =>
             // TODO: handle functions defined in Scala too - in Scala, even if a
             // subclass do not override the default method in parent interface
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala
index b33b9ed57f112..da4000f53e3e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala
@@ -31,6 +31,8 @@ case class ApplyFunctionExpression(
   override def name: String = function.name()
   override def dataType: DataType = function.resultType()
   override def inputTypes: Seq[AbstractDataType] = function.inputTypes().toSeq
+  override lazy val deterministic: Boolean = function.isDeterministic &&
+    children.forall(_.deterministic)

   private lazy val reusedRow = new SpecificInternalRow(function.inputTypes())
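For readers less familiar with the connector API, here is a minimal sketch (not part of the patch) of what such a non-deterministic function can look like when defined in Scala. It mirrors the `JavaRandomAdd` fixture added below and goes through the default `produceResult` path, since the TODO above notes that the magic-method route is not yet wired up for Scala-defined functions:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.types.{DataType, DataTypes}

// Illustrative only: a Scala counterpart of the JavaRandomAdd fixture below.
class ScalaRandomAdd extends ScalarFunction[java.lang.Integer] {
  private val rand = new java.util.Random()

  override def inputTypes(): Array[DataType] = Array(DataTypes.IntegerType)
  override def resultType(): DataType = DataTypes.IntegerType
  override def name(): String = "rand_add"

  // The flag that ApplyFunctionExpression.deterministic now consults.
  override def isDeterministic(): Boolean = false

  override def produceResult(input: InternalRow): java.lang.Integer =
    input.getInt(0) + rand.nextInt()
}
```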
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 50e214011b616..6d251b6d1007d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -240,6 +240,8 @@ object SerializerSupport {
  *                       without invoking the function.
  * @param returnNullable When false, indicating the invoked method will always return
  *                       non-null value.
+ * @param isDeterministic Whether the method invocation is deterministic or not. If false, Spark
+ *                        will not apply certain optimizations such as constant folding.
  */
 case class StaticInvoke(
     staticObject: Class[_],
@@ -248,7 +250,8 @@ case class StaticInvoke(
     arguments: Seq[Expression] = Nil,
     inputTypes: Seq[AbstractDataType] = Nil,
     propagateNull: Boolean = true,
-    returnNullable: Boolean = true) extends InvokeLike {
+    returnNullable: Boolean = true,
+    isDeterministic: Boolean = true) extends InvokeLike {

   val objectName = staticObject.getName.stripSuffix("$")
   val cls = if (staticObject.getName == objectName) {
@@ -259,6 +262,7 @@ case class StaticInvoke(

   override def nullable: Boolean = needNullCheck || returnNullable
   override def children: Seq[Expression] = arguments
+  override lazy val deterministic: Boolean = isDeterministic && arguments.forall(_.deterministic)

   lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
   @transient lazy val method = findMethod(cls, functionName, argClasses)
@@ -340,6 +344,8 @@ case class StaticInvoke(
  *                       without invoking the function.
  * @param returnNullable When false, indicating the invoked method will always return
  *                       non-null value.
+ * @param isDeterministic Whether the method invocation is deterministic or not. If false, Spark
+ *                        will not apply certain optimizations such as constant folding.
  */
 case class Invoke(
     targetObject: Expression,
@@ -348,12 +354,14 @@ case class Invoke(
     arguments: Seq[Expression] = Nil,
     methodInputTypes: Seq[AbstractDataType] = Nil,
     propagateNull: Boolean = true,
-    returnNullable : Boolean = true) extends InvokeLike {
+    returnNullable : Boolean = true,
+    isDeterministic: Boolean = true) extends InvokeLike {

   lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)

   override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
   override def children: Seq[Expression] = targetObject +: arguments
+  override lazy val deterministic: Boolean = isDeterministic && arguments.forall(_.deterministic)

   override def inputTypes: Seq[AbstractDataType] = if (methodInputTypes.nonEmpty) {
     Seq(targetObject.dataType) ++ methodInputTypes
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java
index e5b9c7f5bafaa..75ef5275684d6 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java
@@ -66,7 +66,7 @@ public String description() {
     return "long_add";
   }

-  private abstract static class JavaLongAddBase implements ScalarFunction<Long> {
+  public abstract static class JavaLongAddBase implements ScalarFunction<Long> {
     private final boolean isResultNullable;

     JavaLongAddBase(boolean isResultNullable) {
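To make the new `Invoke`/`StaticInvoke` semantics concrete, a small illustrative snippet follows (these are catalyst-internal APIs, assumed to be on the test classpath; the snippet is a sketch of the behaviour, not code from the patch). `java.lang.Math.random()` is a real static, non-deterministic method, and `Invoke` behaves the same way for instance methods:

```scala
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.types.DoubleType

// Build the expression the way the analyzer now does for magic methods.
val call = StaticInvoke(
  classOf[java.lang.Math], DoubleType, "random",
  propagateNull = false, returnNullable = false,
  isDeterministic = false) // the parameter this patch introduces

// false: isDeterministic && arguments.forall(_.deterministic)
assert(!call.deterministic)
```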
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaRandomAdd.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaRandomAdd.java
new file mode 100644
index 0000000000000..b315fafd8ece8
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaRandomAdd.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.connector.catalog.functions;
+
+import java.util.Random;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
+import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.IntegerType;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * Test V2 function which adds a random number to the input integer.
+ */
+public class JavaRandomAdd implements UnboundFunction {
+  private final BoundFunction fn;
+
+  public JavaRandomAdd(BoundFunction fn) {
+    this.fn = fn;
+  }
+
+  @Override
+  public String name() {
+    return "rand";
+  }
+
+  @Override
+  public BoundFunction bind(StructType inputType) {
+    if (inputType.fields().length != 1) {
+      throw new UnsupportedOperationException("Expect exactly one argument");
+    }
+    if (inputType.fields()[0].dataType() instanceof IntegerType) {
+      return fn;
+    }
+    throw new UnsupportedOperationException("Expect IntegerType");
+  }
+
+  @Override
+  public String description() {
+    return "rand_add: add a random integer to the input\n" +
+        "rand_add(int) -> int";
+  }
+
+  public abstract static class JavaRandomAddBase implements ScalarFunction<Integer> {
+    @Override
+    public DataType[] inputTypes() {
+      return new DataType[] { DataTypes.IntegerType };
+    }
+
+    @Override
+    public DataType resultType() {
+      return DataTypes.IntegerType;
+    }
+
+    @Override
+    public String name() {
+      return "rand_add";
+    }
+
+    @Override
+    public boolean isDeterministic() {
+      return false;
+    }
+  }
+
+  public static class JavaRandomAddDefault extends JavaRandomAddBase {
+    private final Random rand = new Random();
+
+    @Override
+    public Integer produceResult(InternalRow input) {
+      return input.getInt(0) + rand.nextInt();
+    }
+  }
+
+  public static class JavaRandomAddMagic extends JavaRandomAddBase {
+    private final Random rand = new Random();
+
+    public int invoke(int input) {
+      return input + rand.nextInt();
+    }
+  }
+
+  public static class JavaRandomAddStaticMagic extends JavaRandomAddBase {
+    private static final Random rand = new Random();
+
+    public static int invoke(int input) {
+      return input + rand.nextInt();
+    }
+  }
+}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
index 1b1689668e1f6..dade2a113ef45 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
@@ -49,7 +49,7 @@ public BoundFunction bind(StructType inputType) {
       return fn;
     }

-    throw new UnsupportedOperationException("Except StringType");
+    throw new UnsupportedOperationException("Expect StringType");
   }

   @Override
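Why this matters in practice: if the call were treated as deterministic, rules such as constant folding could evaluate it once at optimization time, so every row would see the same "random" value. A hedged sketch of the observable effect, reusing the `testcat.ns.rand_add` function that the test below registers:

```scala
// Assumes a catalog "testcat" with ns.rand_add registered as in the test below.
val df = spark.sql("SELECT id, testcat.ns.rand_add(42) FROM range(3)")
// With the fix, the projected rand_add call reports deterministic = false,
// so the optimizer keeps the per-row invocation instead of folding it
// into a single literal shared by all rows.
df.show()
```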
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
index d5417be0f229f..ace66199f3e4a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
@@ -20,16 +20,18 @@ package org.apache.spark.sql.connector
 import java.util
 import java.util.Collections

-import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaLongAdd, JavaStrLen}
-import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.{JavaLongAddDefault, JavaLongAddMagic, JavaLongAddMismatchMagic, JavaLongAddStaticMagic}
+import test.org.apache.spark.sql.connector.catalog.functions._
+import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd._
+import test.org.apache.spark.sql.connector.catalog.functions.JavaRandomAdd._
 import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._
 import org.apache.spark.SparkException
-import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode.{FALLBACK, NO_CODEGEN}
 import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog, SupportsNamespaces}
 import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction, _}
+import org.apache.spark.sql.execution.ProjectExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -365,6 +367,31 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
     }
   }

+  test("SPARK-37957: pass deterministic flag when creating V2 function expression") {
+    def checkDeterministic(df: DataFrame): Unit = {
+      val result = df.queryExecution.executedPlan.find(_.isInstanceOf[ProjectExec])
+      assert(result.isDefined, s"Expect to find ProjectExec")
+      assert(!result.get.asInstanceOf[ProjectExec].projectList.exists(_.deterministic),
+        "Expect expressions in projectList to be non-deterministic")
+    }
+
+    catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+    Seq(new JavaRandomAddDefault, new JavaRandomAddMagic,
+        new JavaRandomAddStaticMagic).foreach { fn =>
+      addFunction(Identifier.of(Array("ns"), "rand_add"), new JavaRandomAdd(fn))
+      checkDeterministic(sql("SELECT testcat.ns.rand_add(42)"))
+    }
+
+    // A function call is non-deterministic if one of its arguments is non-deterministic
+    Seq(new JavaLongAddDefault(true), new JavaLongAddMagic(true),
+        new JavaLongAddStaticMagic(true)).foreach { fn =>
+      addFunction(Identifier.of(Array("ns"), "add"), new JavaLongAdd(fn))
+      addFunction(Identifier.of(Array("ns"), "rand_add"),
+        new JavaRandomAdd(new JavaRandomAddDefault))
+      checkDeterministic(sql("SELECT testcat.ns.add(10, testcat.ns.rand_add(42))"))
+    }
+  }
+
   private case class StrLen(impl: BoundFunction) extends UnboundFunction {
     override def description(): String =
       """strlen: returns the length of the input string