From 899fad4710bef174684deee64314ac483c16c494 Mon Sep 17 00:00:00 2001 From: Marko Date: Tue, 20 Aug 2024 10:56:19 +0200 Subject: [PATCH] [SPARK-49043][SQL] Fix interpreted codepath group by on map containing collated strings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Added ordering for PhysicalMapType in `PhysicalDataType.scala`. ### Why are the changes needed? This feature is needed to compare maps for equality in group-by queries when they contain collated strings. It was already functional in the codegen path. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added tests to `CollationSuite.scala` ### Was this patch authored or co-authored using generative AI tooling? No Closes #47521 from ilicmarkodb/fix_group_by_on_map. Lead-authored-by: Marko Co-authored-by: Marko Ilić Co-authored-by: Marko Ilic Signed-off-by: Max Gekk --- .../sql/catalyst/types/PhysicalDataType.scala | 70 ++++++++++- .../resources/sql-tests/results/mode.sql.out | 10 +- .../org/apache/spark/sql/CollationSuite.scala | 117 ++++++++++++++++++ 3 files changed, 185 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala index f80aee4c8cbea..03389f14afa01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala @@ -21,7 +21,7 @@ import scala.reflect.runtime.universe.TypeTag import scala.reflect.runtime.universe.typeTag import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, InterpretedOrdering, SortOrder} -import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, SQLOrderingUtil} +import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, MapData, SQLOrderingUtil} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteExactNumeric, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalExactNumeric, DecimalType, DoubleExactNumeric, DoubleType, FloatExactNumeric, FloatType, FractionalType, IntegerExactNumeric, IntegerType, IntegralType, LongExactNumeric, LongType, MapType, NullType, NumericType, ShortExactNumeric, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType} @@ -234,10 +234,72 @@ case object PhysicalLongType extends PhysicalLongType case class PhysicalMapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) extends PhysicalDataType { - override private[sql] def ordering = - throw QueryExecutionErrors.orderedOperationUnsupportedByDataTypeError("PhysicalMapType") - override private[sql] type InternalType = Any + // maps are not orderable, we use `ordering` just to support group by queries + override private[sql] def ordering = interpretedOrdering + override private[sql] type InternalType = MapData @transient private[sql] lazy val tag = typeTag[InternalType] + + @transient + private[sql] lazy val interpretedOrdering: Ordering[MapData] = new Ordering[MapData] { + private[this] val keyOrdering = + PhysicalDataType(keyType).ordering.asInstanceOf[Ordering[Any]] + private[this] val valuesOrdering = + PhysicalDataType(valueType).ordering.asInstanceOf[Ordering[Any]] + + override def compare(left: MapData, right: MapData): Int = { + val lengthLeft = left.numElements() + val lengthRight = right.numElements() + val keyArrayLeft = left.keyArray() + val valueArrayLeft = left.valueArray() + val keyArrayRight = right.keyArray() + val valueArrayRight = right.valueArray() + val minLength = math.min(lengthLeft, lengthRight) + var i = 0 + while (i < minLength) { + var comp = compareElements(keyArrayLeft, keyArrayRight, keyType, i, keyOrdering) + if (comp != 0) { + return comp + } + comp = compareElements(valueArrayLeft, valueArrayRight, valueType, i, valuesOrdering) + if (comp != 0) { + return comp + } + + i += 1 + } + + if (lengthLeft < lengthRight) { + -1 + } else if (lengthLeft > lengthRight) { + 1 + } else { + 0 + } + } + + private def compareElements( + arrayLeft: ArrayData, + arrayRight: ArrayData, + dataType: DataType, + position: Int, + ordering: Ordering[Any]): Int = { + val isNullLeft = arrayLeft.isNullAt(position) + val isNullRight = arrayRight.isNullAt(position) + + if (isNullLeft && isNullRight) { + 0 + } else if (isNullLeft) { + -1 + } else if (isNullRight) { + 1 + } else { + ordering.compare( + arrayLeft.get(position, dataType), + arrayRight.get(position, dataType) + ) + } + } + } } class PhysicalNullType() extends PhysicalDataType with PhysicalPrimitiveType { diff --git a/sql/core/src/test/resources/sql-tests/results/mode.sql.out b/sql/core/src/test/resources/sql-tests/results/mode.sql.out index 9eac2c40e3eea..ad7d59eeb1634 100644 --- a/sql/core/src/test/resources/sql-tests/results/mode.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/mode.sql.out @@ -182,15 +182,9 @@ struct> -- !query SELECT mode(col, true) FROM VALUES (map(1, 'a')) AS tab(col) -- !query schema -struct<> +struct> -- !query output -org.apache.spark.SparkIllegalArgumentException -{ - "errorClass" : "_LEGACY_ERROR_TEMP_2005", - "messageParameters" : { - "dataType" : "PhysicalMapType" - } -} +{1:"a"} -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 3757284d7d3e3..5e7feec149c97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -1028,6 +1028,123 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } + for (collation <- Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI", "")) { + for (codeGen <- Seq("NO_CODEGEN", "CODEGEN_ONLY")) { + val collationSetup = if (collation.isEmpty) "" else " COLLATE " + collation + val supportsBinaryEquality = collation.isEmpty || collation == "UNICODE" || + CollationFactory.fetchCollation(collation).supportsBinaryEquality + + test(s"Group by on map containing$collationSetup strings ($codeGen)") { + val tableName = "t" + + withTable(tableName) { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codeGen) { + sql(s"create table $tableName" + + s" (m map)") + sql(s"insert into $tableName values (map('aaa', 'AAA'))") + sql(s"insert into $tableName values (map('AAA', 'aaa'))") + sql(s"insert into $tableName values (map('aaa', 'AAA'))") + sql(s"insert into $tableName values (map('bbb', 'BBB'))") + sql(s"insert into $tableName values (map('aAA', 'AaA'))") + sql(s"insert into $tableName values (map('BBb', 'bBB'))") + sql(s"insert into $tableName values (map('aaaa', 'AAA'))") + + val df = sql(s"select count(*) from $tableName group by m") + if (supportsBinaryEquality) { + checkAnswer(df, Seq(Row(2), Row(1), Row(1), Row(1), Row(1), Row(1))) + } else { + checkAnswer(df, Seq(Row(4), Row(2), Row(1))) + } + } + } + } + + test(s"Group by on map containing structs with $collationSetup strings ($codeGen)") { + val tableName = "t" + + withTable(tableName) { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codeGen) { + sql(s"create table $tableName" + + s" (m map, " + + s"struct>)") + sql(s"insert into $tableName values " + + s"(map(struct('aaa', 'bbb'), struct('ccc', 'ddd')))") + sql(s"insert into $tableName values " + + s"(map(struct('Aaa', 'BBB'), struct('cCC', 'dDd')))") + sql(s"insert into $tableName values " + + s"(map(struct('AAA', 'BBb'), struct('cCc', 'DDD')))") + sql(s"insert into $tableName values " + + s"(map(struct('aaa', 'bbB'), struct('CCC', 'DDD')))") + + val df = sql(s"select count(*) from $tableName group by m") + if (supportsBinaryEquality) { + checkAnswer(df, Seq(Row(1), Row(1), Row(1), Row(1))) + } else { + checkAnswer(df, Seq(Row(4))) + } + } + } + } + + test(s"Group by on map containing arrays with$collationSetup strings ($codeGen)") { + val tableName = "t" + + withTable(tableName) { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codeGen) { + sql(s"create table $tableName " + + s"(m map, array>)") + sql(s"insert into $tableName values (map(array('aaa', 'bbb'), array('ccc', 'ddd')))") + sql(s"insert into $tableName values (map(array('AAA', 'BbB'), array('Ccc', 'ddD')))") + sql(s"insert into $tableName values (map(array('AAA', 'BbB', 'Ccc'), array('ddD')))") + sql(s"insert into $tableName values (map(array('aAa', 'Bbb'), array('CCC', 'DDD')))") + sql(s"insert into $tableName values (map(array('AAa', 'BBb'), array('cCC', 'DDd')))") + sql(s"insert into $tableName values (map(array('AAA', 'BBB', 'CCC'), array('DDD')))") + + val df = sql(s"select count(*) from $tableName group by m") + if (supportsBinaryEquality) { + checkAnswer(df, Seq(Row(1), Row(1), Row(1), Row(1), Row(1), Row(1))) + } else { + checkAnswer(df, Seq(Row(4), Row(2))) + } + } + } + } + + test(s"Check that order by on map with$collationSetup strings fails ($codeGen)") { + val tableName = "t" + withTable(tableName) { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codeGen) { + sql(s"create table $tableName" + + s" (m map, " + + s" c integer)") + sql(s"insert into $tableName values (map('aaa', 'AAA'), 1)") + sql(s"insert into $tableName values (map('BBb', 'bBB'), 2)") + + // `collationSetupError` is created because "COLLATE UTF8_BINARY" is omitted in data + // type in checkError + val collationSetupError = if (collation != "UTF8_BINARY") collationSetup else "" + val query = s"select c from $tableName order by m" + val ctx = "m" + checkError( + exception = intercept[AnalysisException](sql(query)), + errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + parameters = Map( + "functionName" -> "`sortorder`", + "dataType" -> s"\"MAP\"", + "sqlExpr" -> "\"m ASC NULLS FIRST\"" + ), + context = ExpectedContext( + fragment = ctx, + start = query.length - ctx.length, + stop = query.length - 1 + ) + ) + } + } + } + } + } + test("Support operations on complex types containing collated strings") { checkAnswer(sql("select reverse('abc' collate utf8_lcase)"), Seq(Row("cba"))) checkAnswer(sql(