diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs index e55425642..d47172dfa 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs @@ -24,7 +24,7 @@ use arrow::{ }, datatypes::{validate_decimal_precision, Decimal128Type, Int64Type}, }; -use arrow_array::{Array, ArrowNativeTypeOp, Decimal128Array}; +use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array}; use arrow_schema::DataType; use datafusion::{ execution::FunctionRegistry, @@ -129,6 +129,10 @@ pub fn create_comet_physical_fun( let func = Arc::new(spark_chr); make_comet_scalar_udf!("chr", func, without data_type) } + "isnan" => { + let func = Arc::new(spark_isnan); + make_comet_scalar_udf!("isnan", func, without data_type) + } sha if sha2_functions.contains(&sha) => { // Spark requires hex string as the result of sha2 functions, we have to wrap the // result of digest functions as hex string @@ -634,3 +638,49 @@ fn spark_decimal_div( let result = result.with_data_type(DataType::Decimal128(p3, s3)); Ok(ColumnarValue::Array(Arc::new(result))) } + +fn spark_isnan(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> { + fn set_nulls_to_false(is_nan: BooleanArray) -> ColumnarValue { + match is_nan.nulls() { + Some(nulls) => { + let is_not_null = nulls.inner(); + ColumnarValue::Array(Arc::new(BooleanArray::new( + is_nan.values() & is_not_null, + None, + ))) + } + None => ColumnarValue::Array(Arc::new(is_nan)), + } + } + let value = &args[0]; + match value { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Float64 => { + let array = array.as_any().downcast_ref::<Float64Array>().unwrap(); + let is_nan = BooleanArray::from_unary(array, |x| x.is_nan()); + Ok(set_nulls_to_false(is_nan)) + } + DataType::Float32 => { + let array = array.as_any().downcast_ref::<Float32Array>().unwrap(); + let is_nan = BooleanArray::from_unary(array, |x| x.is_nan()); + 
Ok(set_nulls_to_false(is_nan)) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function isnan", + other, + ))), + }, + ColumnarValue::Scalar(a) => match a { + ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some( + a.map(|x| x.is_nan()).unwrap_or(false), + )))), + ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some( + a.map(|x| x.is_nan()).unwrap_or(false), + )))), + _ => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function isnan", + value.data_type(), + ))), + }, + } +} diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 14b6f18d0..1d4ce3736 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -119,6 +119,7 @@ The following Spark expressions are currently available. Any known compatibility | Cos | | | Exp | | | Floor | | +| IsNaN | | | Log | log(0) will produce `-Infinity` unlike Spark which returns `null` | | Log2 | log2(0) will produce `-Infinity` unlike Spark which returns `null` | | Log10 | log10(0) will produce `-Infinity` unlike Spark which returns `null` | diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 3465315d1..4c4530db3 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1476,6 +1476,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim None } + case IsNaN(child) => + val childExpr = exprToProtoInternal(child, inputs) + val optExpr = + scalarExprToProtoWithReturnType("isnan", BooleanType, childExpr) + + optExprWithInfo(optExpr, expr, child) + case SortOrder(child, direction, nullOrdering, _) => val childExpr = exprToProtoInternal(child, inputs) diff --git 
a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index e80549988..498a305ed 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1722,4 +1722,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + + test("isnan") { + Seq("true", "false").foreach { dictionary => + withSQLConf("parquet.enable.dictionary" -> dictionary) { + withParquetTable( + Seq(Some(1.0), Some(Double.NaN), None).map(i => Tuple1(i)), + "tbl", + withDictionary = dictionary.toBoolean) { + checkSparkAnswerAndOperator("SELECT isnan(_1), isnan(cast(_1 as float)) FROM tbl") + // Use inside a nullable statement to make sure isnan has correct behavior for null input + checkSparkAnswerAndOperator( + "SELECT CASE WHEN (_1 > 0) THEN NULL ELSE isnan(_1) END FROM tbl") + } + } + } + } }