diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 9aab4bd450d1..db47c622188d 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -18,12 +18,15 @@ //! "core" DataFusion functions mod nullif; +mod nvl; // create UDFs make_udf_function!(nullif::NullIfFunc, NULLIF, nullif); +make_udf_function!(nvl::NVLFunc, NVL, nvl); // Export the functions out of this package, both as expr_fn as well as a list of functions export_functions!( - (nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression.") + (nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression."), + (nvl, arg_1 arg_2, "returns value2 if value1 is NULL; otherwise it returns value1") ); diff --git a/datafusion/functions/src/core/nvl.rs b/datafusion/functions/src/core/nvl.rs new file mode 100644 index 000000000000..6d6ad1cdeb21 --- /dev/null +++ b/datafusion/functions/src/core/nvl.rs @@ -0,0 +1,277 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::DataType; +use datafusion_common::{internal_err, Result, DataFusionError}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use arrow::compute::kernels::zip::zip; +use arrow::compute::is_not_null; +use arrow::array::Array; + +#[derive(Debug)] +pub(super) struct NVLFunc { + signature: Signature, + aliases: Vec, +} + +/// Currently supported types by the nvl/ifnull function. +/// The order of these types correspond to the order on which coercion applies +/// This should thus be from least informative to most informative +static SUPPORTED_NVL_TYPES: &[DataType] = &[ + DataType::Boolean, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + DataType::Utf8, + DataType::LargeUtf8, +]; + +impl NVLFunc { + pub fn new() -> Self { + Self { + signature: + Signature::uniform(2, SUPPORTED_NVL_TYPES.to_vec(), + Volatility::Immutable, + ), + aliases: vec![String::from("ifnull")], + } + } +} + +impl ScalarUDFImpl for NVLFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "nvl" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + // NVL has two args and they might get coerced, get a preview of this + let coerced_types = datafusion_expr::type_coercion::functions::data_types(arg_types, &self.signature); + coerced_types.map(|typs| typs[0].clone()) + .map_err(|e| e.context("Failed to coerce arguments for NVL") + ) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + nvl_func(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +fn nvl_func(args: &[ColumnarValue]) -> Result { + if args.len() != 2 { + return internal_err!( + "{:?} args were supplied but NVL/IFNULL takes exactly two args", + args.len() + ); + } + let (lhs_array, rhs_array) = match (&args[0], &args[1]) { + (ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => { + (lhs.clone(), rhs.to_array_of_size(lhs.len())?) + } + (ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => { + (lhs.clone(), rhs.clone()) + } + (ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)) => { + (lhs.to_array_of_size(rhs.len())?, rhs.clone()) + } + (ColumnarValue::Scalar(lhs), ColumnarValue::Scalar(rhs)) => { + let mut current_value = lhs; + if lhs.is_null() { + current_value = rhs; + } + return Ok(ColumnarValue::Scalar(current_value.clone())); + } + }; + let to_apply = is_not_null(&lhs_array)?; + let value = zip(&to_apply, &lhs_array, &rhs_array)?; + Ok(ColumnarValue::Array(value)) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::array::*; + + use super::*; + use datafusion_common::{Result, ScalarValue}; + + #[test] + fn nvl_int32() -> Result<()> { + let a = Int32Array::from(vec![ + Some(1), + Some(2), + None, + None, + Some(3), + None, + None, + Some(4), + Some(5), + ]); + let a = ColumnarValue::Array(Arc::new(a)); + + let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(6i32))); + + let result = nvl_func(&[a, lit_array])?; + let result = result.into_array(0).expect("Failed to convert to array"); + + let expected = Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + Some(6), + Some(6), + Some(3), + Some(6), + Some(6), + Some(4), + Some(5), + ])) as ArrayRef; + assert_eq!(expected.as_ref(), result.as_ref()); + Ok(()) + } + + #[test] + // Ensure that arrays with no nulls can also invoke nvl() correctly + fn nvl_int32_nonulls() -> Result<()> { + let a = Int32Array::from(vec![1, 3, 10, 7, 8, 1, 2, 4, 5]); + let a = ColumnarValue::Array(Arc::new(a)); + + let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(20i32))); + + let result = nvl_func(&[a, lit_array])?; + let result = result.into_array(0).expect("Failed to convert to array"); + + let expected = Arc::new(Int32Array::from(vec![ + Some(1), + Some(3), + Some(10), + Some(7), + Some(8), + Some(1), + Some(2), + Some(4), + Some(5), + ])) as ArrayRef; + assert_eq!(expected.as_ref(), result.as_ref()); + Ok(()) + } + + #[test] + fn nvl_boolean() -> Result<()> { + let a = BooleanArray::from(vec![Some(true), Some(false), None]); + let a = ColumnarValue::Array(Arc::new(a)); + + let lit_array = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))); + + let result = nvl_func(&[a, lit_array])?; + let result = result.into_array(0).expect("Failed to convert to array"); + + let expected = + Arc::new(BooleanArray::from(vec![Some(true), Some(false), Some(false)])) as ArrayRef; + + assert_eq!(expected.as_ref(), result.as_ref()); + Ok(()) + } + + #[test] + fn nvl_string() -> Result<()> { + let a = StringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]); + let a = ColumnarValue::Array(Arc::new(a)); + + let lit_array = ColumnarValue::Scalar(ScalarValue::from("bax")); + + let result = nvl_func(&[a, lit_array])?; + let result = result.into_array(0).expect("Failed to convert to array"); + + let expected = Arc::new(StringArray::from(vec![ + Some("foo"), + Some("bar"), + Some("bax"), + Some("baz"), + ])) as ArrayRef; + + assert_eq!(expected.as_ref(), result.as_ref()); + Ok(()) + } + + #[test] + fn nvl_literal_first() -> Result<()> { + let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), Some(4)]); + let a = ColumnarValue::Array(Arc::new(a)); + + let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); + + let result = nvl_func(&[lit_array, a])?; + let result = result.into_array(0).expect("Failed to convert to array"); + + let expected = Arc::new(Int32Array::from(vec![ + Some(2), + Some(2), + Some(2), + Some(2), + Some(2), + Some(2), + ])) as ArrayRef; + assert_eq!(expected.as_ref(), result.as_ref()); + Ok(()) + } + + #[test] + fn nvl_scalar() -> Result<()> { + let a_null = ColumnarValue::Scalar(ScalarValue::Int32(None)); + let b_null = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); + + let result_null = nvl_func(&[a_null, b_null])?; + let result_null = result_null.into_array(1).expect("Failed to convert to array"); + + let expected_null = Arc::new(Int32Array::from(vec![Some(2i32)])) as ArrayRef; + + assert_eq!(expected_null.as_ref(), result_null.as_ref()); + + let a_nnull = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); + let b_nnull = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))); + + let result_nnull = nvl_func(&[a_nnull, b_nnull])?; + let result_nnull = result_nnull + .into_array(1) + .expect("Failed to convert to array"); + + let expected_nnull = Arc::new(Int32Array::from(vec![Some(2i32)])) as ArrayRef; + assert_eq!(expected_nnull.as_ref(), result_nnull.as_ref()); + + Ok(()) + } +} diff --git a/datafusion/sqllogictest/test_files/nvl.slt b/datafusion/sqllogictest/test_files/nvl.slt new file mode 100644 index 000000000000..81e79e1eb5b0 --- /dev/null +++ b/datafusion/sqllogictest/test_files/nvl.slt @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE TABLE test( + int_field INT, + bool_field BOOLEAN, + text_field TEXT, + more_ints INT +) as VALUES + (1, true, 'abc', 2), + (2, false, 'def', 2), + (3, NULL, 'ghij', 3), + (NULL, NULL, NULL, 4), + (4, false, 'zxc', 5), + (NULL, true, NULL, 6) +; + +# Arrays tests +query I +SELECT NVL(int_field, 2) FROM test; +---- +1 +2 +3 +2 +4 +2 + + +query B +SELECT NVL(bool_field, false) FROM test; +---- +true +false +false +false +false +true + + +query T +SELECT NVL(text_field, 'zxb') FROM test; +---- +abc +def +ghij +zxb +zxc +zxb + + +query I +SELECT IFNULL(int_field, more_ints) FROM test; +---- +1 +2 +3 +4 +4 +6 + + +query I +SELECT NVL(3, int_field) FROM test; +---- +3 +3 +3 +3 +3 +3 + + +# Scalar values tests +query I +SELECT NVL(1, 1); +---- +1 + +query I +SELECT NVL(1, 3); +---- +1 + +query I +SELECT NVL(NULL, NULL); +---- +NULL diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 707e8c24b326..d4eb5944ad09 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -569,6 +569,8 @@ trunc(numeric_expression[, decimal_places]) - [coalesce](#coalesce) - [nullif](#nullif) +- [nvl](#nvl) +- [ifnull](#ifnull) ### `coalesce` @@ -603,6 +605,25 @@ nullif(expression1, expression2) - **expression2**: Expression to compare to expression1. Can be a constant, column, or function, and any combination of arithmetic operators. +### `nvl` + +Returns _expression2_ if _expression1_ is NULL; otherwise it returns _expression1_. + +``` +nvl(expression1, expression2) +``` + +#### Arguments + +- **expression1**: return if expression1 not is NULL. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression2**: return if expression1 is NULL. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `ifnull` + +_Alias of [nvl](#nvl)._ + ## String Functions - [ascii](#ascii)