From 3677ff69a24769518a67ee7e8c8eb1991eb8a8b9 Mon Sep 17 00:00:00 2001
From: Chong Gao
Date: Mon, 26 Aug 2024 11:16:39 +0800
Subject: [PATCH] Add combiner for string contains

Signed-off-by: Chong Gao
---
Note: this patch was reflowed from a copy whose line breaks had been collapsed, so
leading whitespace is reconstructed; use `git apply --ignore-whitespace` if a hunk is
rejected. ContainsCombiner.addExpression and some comments were revised, so the
post-image blob id on the stringFunctions.scala index line is stale.

 .../src/main/python/conditionals_test.py      |  18 +++
 .../spark/sql/rapids/stringFunctions.scala    | 122 +++++++++++++++++-
 2 files changed, 136 insertions(+), 4 deletions(-)

diff --git a/integration_tests/src/main/python/conditionals_test.py b/integration_tests/src/main/python/conditionals_test.py
index b95ed53f398..a2fb24a7434 100644
--- a/integration_tests/src/main/python/conditionals_test.py
+++ b/integration_tests/src/main/python/conditionals_test.py
@@ -379,3 +379,21 @@ def test_case_when_all_then_values_are_scalars_with_nulls():
         "tab",
         sql_without_else,
         conf = {'spark.rapids.sql.case_when.fuse': 'true'})
+
+@pytest.mark.parametrize('combine_string_contains_enabled', ['true', 'false'])
+def test_combine_string_contains_in_case_when(combine_string_contains_enabled):
+    data_gen = [("c1", string_gen)]
+    sql = """
+        SELECT
+            INSTR(c1, 'substring1') > 0,
+            INSTR(c1, 'substring2') > 0,
+            INSTR(c1, 'substring3') > 0
+        from tab
+        """
+    # spark.rapids.sql.combined.expressions.enabled is true by default
+    assert_gpu_and_cpu_are_equal_sql(
+        lambda spark : gen_df(spark, data_gen),
+        "tab",
+        sql,
+        { "spark.rapids.sql.expression.combined.GpuContains" : combine_string_contains_enabled}
+    )
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
index c8a90dc80ad..3ae66ada010 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
@@ -20,6 +20,8 @@ import java.nio.charset.Charset
 import java.text.DecimalFormatSymbols
 import java.util.{Locale, Optional}
 
+import scala.annotation.tailrec
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexProgram, RoundMode, Scalar}
@@ -32,6 +34,7 @@ import com.nvidia.spark.rapids.jni.RegexRewriteUtils
 import com.nvidia.spark.rapids.shims.{ShimExpression, SparkShimImpl}
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.rapids.catalyst.expressions.{GpuCombinable, GpuExpressionCombiner, GpuExpressionEquals}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.unsafe.types.UTF8String
@@ -388,10 +391,11 @@ case class GpuConcatWs(children: Seq[Expression])
 }
 
 case class GpuContains(left: Expression, right: Expression)
-    extends GpuBinaryExpressionArgsAnyScalar
-        with Predicate
-        with ImplicitCastInputTypes
-        with NullIntolerant {
+  extends GpuBinaryExpressionArgsAnyScalar
+  with Predicate
+  with ImplicitCastInputTypes
+  with NullIntolerant
+  with GpuCombinable {
 
   override def inputTypes: Seq[DataType] = Seq(StringType)
 
@@ -411,6 +415,116 @@ case class GpuContains(left: Expression, right: Expression)
       doColumnar(expandedLhs, rhs)
     }
   }
+
+  /**
+   * Get a combiner that can be used to find candidates to combine
+   */
+  override def getCombiner(): GpuExpressionCombiner = new ContainsCombiner(this)
+}
+
+/**
+ * Checks one string input against several literal targets in a single expression.
+ *
+ * `columnarEval` wraps the columns returned by `stringContains` into a struct column
+ * whose type is `output`.
+ */
+case class GpuMultiContains(left: Expression, targets: Seq[UTF8String], output: StructType)
+  extends GpuExpression with ShimExpression {
+
+  override def otherCopyArgs: Seq[AnyRef] = Nil
+
+  override def dataType: DataType = output
+
+  override def nullable: Boolean = false
+
+  override def prettyName: String = "multi_contains"
+
+  override def columnarEval(batch: ColumnarBatch): GpuColumnVector = {
+    val targetsBytes = targets.map(t => t.getBytes).toArray
+    withResource(ColumnVector.fromUTF8Strings(targetsBytes : _*)) { targetsCv =>
+      withResource(left.columnarEval(batch)) { lhs =>
+        withResource(lhs.getBase.stringContains(targetsCv)) { boolCvs =>
+          GpuColumnVector.from(ColumnVector.makeStruct(batch.numRows(), boolCvs: _*), dataType)
+        }
+      }
+    }
+  }
+  override def children: Seq[Expression] = Seq(left)
+}
+
+/**
+ * Combiner for [[GpuContains]] expressions that share the same `left` input and have
+ * literal targets. The combined expressions are replaced by struct-field reads from a
+ * single [[GpuMultiContains]].
+ */
+class ContainsCombiner(private val exp: GpuContains) extends GpuExpressionCombiner {
+  // Next free ordinal in the output struct; advanced only when a new expression is added.
+  private var outputLocation = 0
+  /**
+   * A mapping between an expression and where in the output struct of
+   * the GpuMultiContains will the output be.
+   */
+  private val toCombine = mutable.HashMap.empty[GpuExpressionEquals, Int]
+  addExpression(exp)
+
+  override def toString: String = s"ContainsCombiner $toCombine"
+
+  override def hashCode: Int = {
+    // We already know that we are Contains, and what we can combine is based
+    // on the string column being the same.
+    "Contains".hashCode + (exp.left.semanticHash() * 17)
+  }
+
+  /**
+   * only combine when targets are literals
+   */
+  override def equals(o: Any): Boolean = o match {
+    case other: ContainsCombiner => exp.left.semanticEquals(other.exp.left) &&
+      exp.right.isInstanceOf[GpuLiteral] && other.exp.right.isInstanceOf[GpuLiteral]
+    case _ => false
+  }
+
+  override def addExpression(e: Expression): Unit = {
+    val key = GpuExpressionEquals(e)
+    if (!toCombine.contains(key)) {
+      // Only consume an ordinal for a new expression. The ordinal is later used as the
+      // field index into the struct built from toCombine, so the ids must stay dense.
+      toCombine.put(key, outputLocation)
+      outputLocation += 1
+    }
+  }
+
+  override def useCount: Int = toCombine.size
+
+  private def fieldName(id: Int): String =
+    s"_mc_$id"
+
+  @tailrec
+  private def extractLiteral(exp: Expression): GpuLiteral = exp match {
+    case l: GpuLiteral => l
+    case a: Alias => extractLiteral(a.child)
+    case other => throw new RuntimeException("Unsupported expression in contains combiner, " +
+        "should be a literal type, actual type is " + other.getClass.getName)
+  }
+
+  private lazy val multiContains: GpuMultiContains = {
+    val input = toCombine.head._1.e.asInstanceOf[GpuContains].left
+    val fieldsAndTargets = toCombine.toSeq.map {
+      case (k, id) =>
+        (id, k.e)
+    }.sortBy(_._1).map {
+      case (id, e: GpuContains) =>
+        val target = extractLiteral(e.right).value.asInstanceOf[UTF8String]
+        (StructField(fieldName(id), e.dataType, e.nullable), target)
+    }
+    val dt = StructType(fieldsAndTargets.map(_._1))
+    GpuMultiContains(input, fieldsAndTargets.map(_._2), dt)
+  }
+
+  override def getReplacementExpression(e: Expression): Expression = {
+    val localId = toCombine(GpuExpressionEquals(e))
+    GpuGetStructField(multiContains, localId, Some(fieldName(localId)))
+  }
 }
 
 case class GpuSubstring(str: Expression, pos: Expression, len: Expression)