[SPARK-49202][PS] Apply ArrayBinarySearch for histogram
### What changes were proposed in this pull request?
 Apply `ArrayBinarySearch` for histogram

### Why are the changes needed?
This expression is dedicated to histogram computation and supports codegen, unlike the Python UDF it replaces. An example plan:

```

(5) Project [codegen id : 1]
Output [2]: [__group_id#37, cast(
  CASE
    WHEN ((__value#38 >= 1.0) AND (__value#38 <= 12.0)) THEN
      CASE
        WHEN (__value#38 = 12.0) THEN 11
        WHEN (static_invoke(ArrayExpressionUtils.binarySearchNullSafe(
            [1.0,1.9166666666666665,2.833333333333333,3.75,4.666666666666666,5.583333333333333,6.5,7.416666666666666,8.333333333333332,9.25,10.166666666666666,11.083333333333332,12.0],
            __value#38)) > 0)
          THEN static_invoke(ArrayExpressionUtils.binarySearchNullSafe(
            [1.0,1.9166666666666665,2.833333333333333,3.75,4.666666666666666,5.583333333333333,6.5,7.416666666666666,8.333333333333332,9.25,10.166666666666666,11.083333333333332,12.0],
            __value#38))
        ELSE (-static_invoke(ArrayExpressionUtils.binarySearchNullSafe(
            [1.0,1.9166666666666665,2.833333333333333,3.75,4.666666666666666,5.583333333333333,6.5,7.416666666666666,8.333333333333332,9.25,10.166666666666666,11.083333333333332,12.0],
            __value#38)) - 2)
      END
    WHEN isnan(__value#38) THEN
      cast(raise_error(USER_RAISED_EXCEPTION, map(keys: [errorMessage], values: [Histogram encountered NaN value.]), NullType) as int)
    ELSE
      cast(raise_error(USER_RAISED_EXCEPTION, map(errorMessage, printf(value %s out of the bins bounds: [%s, %s], __value#38, 1.0, 12.0)), NullType) as int)
  END as double) AS __bucket#46]
Input [2]: [__group_id#37, __value#38]
```
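
For reference, the bucket arithmetic encoded in the `CASE` above can be sketched in plain Python, assuming `binarySearchNullSafe` follows `java.util.Arrays.binarySearch` semantics (the match index if found, otherwise `-(insertion point) - 1`; the top edge `bins[-1]` is special-cased to the last bucket before this logic runs):

```python
import bisect

def bucket_for(bins, value):
    # Emulate java.util.Arrays.binarySearch on sorted bin edges:
    # the match index if found, else -(insertion point) - 1.
    lo = bisect.bisect_left(bins, value)
    if lo < len(bins) and bins[lo] == value:
        index = lo
    else:
        index = -lo - 1
    # Mirror of the merged helper: a non-negative index hit a bin edge
    # directly; otherwise -index - 2 recovers (insertion point - 1).
    return index if index >= 0 else -index - 2

bins = [1.0, 2.0, 3.0, 4.0]
assert bucket_for(bins, 1.5) == 0  # [1.0, 2.0)
assert bucket_for(bins, 2.0) == 1  # an edge opens its own bucket
assert bucket_for(bins, 3.7) == 2  # [3.0, 4.0); 4.0 itself is special-cased
```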

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI and manual checks

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#47970 from zhengruifeng/ps_apply_binary_search.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
zhengruifeng committed Sep 4, 2024
1 parent 39d4bd8 commit 339d1c9
Showing 3 changed files with 29 additions and 21 deletions.
34 changes: 13 additions & 21 deletions python/pyspark/pandas/plot/core.py
```diff
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 
-import bisect
 import importlib
 import math
 
@@ -25,7 +24,6 @@
 from pandas.core.dtypes.inference import is_integer
 
 from pyspark.sql import functions as F, Column
-from pyspark.sql.types import DoubleType
 from pyspark.pandas.spark import functions as SF
 from pyspark.pandas.missing import unsupported_function
 from pyspark.pandas.config import get_option
@@ -182,33 +180,27 @@ def compute_hist(psdf, bins):
         colnames = sdf.columns
         bucket_names = ["__{}_bucket".format(colname) for colname in colnames]
 
-        # TODO(SPARK-49202): register this function in scala side
-        @F.udf(returnType=DoubleType())
-        def binary_search_for_buckets(value):
-            # Given bins = [1.0, 2.0, 3.0, 4.0]
-            # the intervals are:
-            #   [1.0, 2.0) -> 0.0
-            #   [2.0, 3.0) -> 1.0
-            #   [3.0, 4.0] -> 2.0 (the last bucket is a closed interval)
-            if value < bins[0] or value > bins[-1]:
-                raise ValueError(f"value {value} out of the bins bounds: [{bins[0]}, {bins[-1]}]")
-
-            if value == bins[-1]:
-                idx = len(bins) - 2
-            else:
-                idx = bisect.bisect(bins, value) - 1
-            return float(idx)
+        # refers to org.apache.spark.ml.feature.Bucketizer#binarySearchForBuckets
+        def binary_search_for_buckets(value: Column):
+            index = SF.binary_search(F.lit(bins), value)
+            bucket = F.when(index >= 0, index).otherwise(-index - 2)
+            unboundErrMsg = F.lit(f"value %s out of the bins bounds: [{bins[0]}, {bins[-1]}]")
+            return (
+                F.when(value == F.lit(bins[-1]), F.lit(len(bins) - 2))
+                .when(value.between(F.lit(bins[0]), F.lit(bins[-1])), bucket)
+                .otherwise(F.raise_error(F.printf(unboundErrMsg, value)))
+            )
 
         output_df = (
             sdf.select(
                 F.posexplode(
                     F.array([F.col(colname).cast("double") for colname in colnames])
                 ).alias("__group_id", "__value")
             )
             # to match handleInvalid="skip" in Bucketizer
-            .where(F.col("__value").isNotNull() & ~F.col("__value").isNaN()).select(
+            .where(F.col("__value").isNotNull() & ~F.col("__value").isNaN())
+            .select(
                 F.col("__group_id"),
-                binary_search_for_buckets(F.col("__value")).alias("__bucket"),
+                binary_search_for_buckets(F.col("__value")).cast("double").alias("__bucket"),
             )
         )
```
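The rewritten `binary_search_for_buckets` stays entirely in Column expressions, so bucketing no longer round-trips through a Python UDF and can participate in codegen. A minimal sketch of the same logic on a toy DataFrame, assuming an active `SparkSession` named `spark` and the `SF.binary_search` helper added below:

```python
from pyspark.sql import functions as F
from pyspark.pandas.spark import functions as SF

bins = [1.0, 2.0, 3.0, 4.0]
df = spark.createDataFrame([(1.0,), (2.5,), (4.0,)], ["v"])

index = SF.binary_search(F.lit(bins), F.col("v"))
bucket = F.when(index >= 0, index).otherwise(-index - 2)
df.select(
    F.when(F.col("v") == F.lit(bins[-1]), F.lit(len(bins) - 2))
    .when(F.col("v").between(F.lit(bins[0]), F.lit(bins[-1])), bucket)
    .otherwise(F.raise_error(F.lit("out of bounds")))
    .cast("double")
    .alias("bucket")
).show()
# Expected: 1.0 -> 0.0, 2.5 -> 1.0, 4.0 -> 2.0 (the top edge closes the last bucket)
```
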
13 changes: 13 additions & 0 deletions python/pyspark/pandas/spark/functions.py
```diff
@@ -187,6 +187,19 @@ def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
         return Column(sc._jvm.PythonSQLUtils.collect_top_k(col._jc, num, reverse))
 
 
+def binary_search(col: Column, value: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
+
+        return _invoke_function_over_columns("array_binary_search", col, value)
+
+    else:
+        from pyspark import SparkContext
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.binary_search(col._jc, value._jc))
+
+
 def make_interval(unit: str, e: Union[Column, int, float]) -> Column:
     unit_mapping = {
         "YEAR": "years",
```
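`binary_search` follows the module's usual dual dispatch: under Spark Connect (`is_remote()`) it invokes the internal `array_binary_search` function through the Connect column builder; otherwise it crosses the Py4J gateway into the `PythonSQLUtils` shim added below. A minimal usage sketch of the return convention, assuming `java.util.Arrays.binarySearch` semantics and an active session:

```python
from pyspark.sql import functions as F
from pyspark.pandas.spark import functions as SF

arr = F.lit([1.0, 2.0, 3.0])
spark.range(1).select(
    SF.binary_search(arr, F.lit(2.0)).alias("hit"),   # exact match -> 1
    SF.binary_search(arr, F.lit(2.5)).alias("miss"),  # -(insertion point) - 1 -> -3
).show()
```
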
3 changes: 3 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
```diff
@@ -152,6 +152,9 @@ private[sql] object PythonSQLUtils extends Logging {
   def collect_top_k(e: Column, num: Int, reverse: Boolean): Column =
     Column.internalFn("collect_top_k", e, lit(num), lit(reverse))
 
+  def binary_search(e: Column, value: Column): Column =
+    Column.internalFn("array_binary_search", e, value)
+
   def pandasProduct(e: Column, ignoreNA: Boolean): Column =
     Column.internalFn("pandas_product", e, lit(ignoreNA))
```
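
`Column.internalFn` resolves the name against Spark's internal function registry, which is meant for expressions that are not part of the public SQL surface. A hedged way to confirm the codegen claim from the description (exact plan text varies by version):

```python
from pyspark.sql import functions as F
from pyspark.pandas.spark import functions as SF

spark.range(1).select(
    SF.binary_search(F.lit([1.0, 2.0]), F.lit(1.5)).alias("idx")
).explain(mode="formatted")
# The Project should appear inside WholeStageCodegen and call
# static_invoke(ArrayExpressionUtils.binarySearchNullSafe(...)).
```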
