Skip to content

Commit

Permalink
Support multi string contains [databricks] (#11413)
Browse files Browse the repository at this point in the history
* Add combiner for string contains
* Fix bug: GpuLiteral data for string type shoule be a UTF8String instead of a java String

---------
Signed-off-by: Chong Gao <[email protected]>
Co-authored-by: Chong Gao <[email protected]>
  • Loading branch information
res-life authored Nov 19, 2024
1 parent 9b06ae3 commit b16d107
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 2 deletions.
27 changes: 27 additions & 0 deletions integration_tests/src/main/python/conditionals_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,30 @@ def test_case_when_all_then_values_are_scalars_with_nulls():
"tab",
sql_without_else,
conf = {'spark.rapids.sql.case_when.fuse': 'true'})

@pytest.mark.parametrize('combine_string_contains_enabled', [True, False])
def test_combine_string_contains_in_case_when(combine_string_contains_enabled):
data_gen = [("c1", string_gen)]
sql = """
SELECT
CASE
WHEN INSTR(c1, 'a') > 0 THEN 'a'
WHEN INSTR(c1, 'b') > 0 THEN 'b'
WHEN INSTR(c1, 'c') > 0 THEN 'c'
ELSE ''
END as output_1,
CASE
WHEN INSTR(c1, 'c') > 0 THEN 'c'
WHEN INSTR(c1, 'd') > 0 THEN 'd'
WHEN INSTR(c1, 'e') > 0 THEN 'e'
ELSE ''
END as output_2
from tab
"""
# spark.rapids.sql.combined.expressions.enabled is true by default
assert_gpu_and_cpu_are_equal_sql(
lambda spark : gen_df(spark, data_gen),
"tab",
sql,
{ "spark.rapids.sql.expression.combined.GpuContains" : combine_string_contains_enabled}
)
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import java.nio.charset.Charset
import java.text.DecimalFormatSymbols
import java.util.{EnumSet, Locale, Optional}

import scala.annotation.tailrec
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexFlag, RegexProgram, RoundMode, Scalar}
Expand All @@ -32,6 +34,7 @@ import com.nvidia.spark.rapids.jni.RegexRewriteUtils
import com.nvidia.spark.rapids.shims.{ShimExpression, SparkShimImpl}

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.rapids.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.UTF8String
Expand Down Expand Up @@ -391,7 +394,8 @@ case class GpuContains(left: Expression, right: Expression)
extends GpuBinaryExpressionArgsAnyScalar
with Predicate
with ImplicitCastInputTypes
with NullIntolerant {
with NullIntolerant
with GpuCombinable {

override def inputTypes: Seq[DataType] = Seq(StringType)

Expand All @@ -411,6 +415,106 @@ case class GpuContains(left: Expression, right: Expression)
doColumnar(expandedLhs, rhs)
}
}

/**
* Get a combiner that can be used to find candidates to combine
*/
override def getCombiner(): GpuExpressionCombiner = new ContainsCombiner(this)
}

case class GpuMultiContains(left: Expression, targets: Seq[UTF8String], output: StructType)
extends GpuExpression with ShimExpression {

override def otherCopyArgs: Seq[AnyRef] = Nil

override def dataType: DataType = output

override def nullable: Boolean = false

override def prettyName: String = "multi_contains"

override def columnarEval(batch: ColumnarBatch): GpuColumnVector = {
val targetsBytes = targets.map(t => t.getBytes).toArray
val boolCvs = withResource(ColumnVector.fromUTF8Strings(targetsBytes: _*)) { targetsCv =>
withResource(left.columnarEval(batch)) { lhs =>
lhs.getBase.stringContains(targetsCv)
}
}
withResource(boolCvs) { _ =>
val retView = ColumnView.makeStructView(batch.numRows(), boolCvs: _*)
GpuColumnVector.from(retView.copyToColumnVector(), dataType)
}
}

override def children: Seq[Expression] = Seq(left)
}

class ContainsCombiner(private val exp: GpuContains) extends GpuExpressionCombiner {
private var outputLocation = 0
/**
* A mapping between an expression and where in the output struct of
* the MultiGetJsonObject will the output be.
*/
private val toCombine = mutable.HashMap.empty[GpuExpressionEquals, Int]
addExpression(exp)

override def toString: String = s"ContainsCombiner $toCombine"

override def hashCode: Int = {
// We already know that we are Contains, and what we can combine is based
// on the string column being the same.
"Contains".hashCode + (exp.left.semanticHash() * 17)
}

/**
* only combine when targets are literals
*/
override def equals(o: Any): Boolean = o match {
case other: ContainsCombiner => exp.left.semanticEquals(other.exp.left) &&
exp.right.isInstanceOf[GpuLiteral] && other.exp.right.isInstanceOf[GpuLiteral]
case _ => false
}

override def addExpression(e: Expression): Unit = {
val localOutputLocation = outputLocation
outputLocation += 1
val key = GpuExpressionEquals(e)
if (!toCombine.contains(key)) {
toCombine.put(key, localOutputLocation)
}
}

override def useCount: Int = toCombine.size

private def fieldName(id: Int): String =
s"_mc_$id"

@tailrec
private def extractLiteral(exp: Expression): GpuLiteral = exp match {
case l: GpuLiteral => l
case a: Alias => extractLiteral(a.child)
case other => throw new RuntimeException("Unsupported expression in contains combiner, " +
"should be a literal type, actual type is " + other.getClass.getName)
}

private lazy val multiContains: GpuMultiContains = {
val input = toCombine.head._1.e.asInstanceOf[GpuContains].left
val fieldsNPaths = toCombine.toSeq.map {
case (k, id) =>
(id, k.e)
}.sortBy(_._1).map {
case (id, e: GpuContains) =>
val target = extractLiteral(e.right).value.asInstanceOf[UTF8String]
(StructField(fieldName(id), e.dataType, e.nullable), target)
}
val dt = StructType(fieldsNPaths.map(_._1))
GpuMultiContains(input, fieldsNPaths.map(_._2), dt)
}

override def getReplacementExpression(e: Expression): Expression = {
val localId = toCombine(GpuExpressionEquals(e))
GpuGetStructField(multiContains, localId, Some(fieldName(localId)))
}
}

case class GpuSubstring(str: Expression, pos: Expression, len: Expression)
Expand Down Expand Up @@ -1097,7 +1201,7 @@ class GpuRLikeMeta(
GpuRLike(lhs, rhs, patternStr)
}
case StartsWith(s) => GpuStartsWith(lhs, GpuLiteral(s, StringType))
case Contains(s) => GpuContains(lhs, GpuLiteral(s, StringType))
case Contains(s) => GpuContains(lhs, GpuLiteral(UTF8String.fromString(s), StringType))
case MultipleContains(ls) => GpuMultipleContains(lhs, ls)
case PrefixRange(s, length, start, end) =>
GpuLiteralRangePattern(lhs, GpuLiteral(s, StringType), length, start, end)
Expand Down

0 comments on commit b16d107

Please sign in to comment.