Skip to content

Commit

Permalink
feat: support nvl2 function (#9364)
Browse files Browse the repository at this point in the history
* feat: support nvl2 function

* fix signature && test case
  • Loading branch information
guojidan authored Feb 28, 2024
1 parent ae4b3a0 commit 5f90ead
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 1 deletion.
5 changes: 4 additions & 1 deletion datafusion/functions/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@
mod nullif;
mod nvl;
mod nvl2;

// create UDFs
make_udf_function!(nullif::NullIfFunc, NULLIF, nullif);
make_udf_function!(nvl::NVLFunc, NVL, nvl);
make_udf_function!(nvl2::NVL2Func, NVL2, nvl2);

// Export the functions out of this package, both as expr_fn as well as a list of functions
export_functions!(
(nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression."),
(nvl, arg_1 arg_2, "returns value2 if value1 is NULL; otherwise it returns value1")
(nvl, arg_1 arg_2, "returns value2 if value1 is NULL; otherwise it returns value1"),
(nvl2, arg_1 arg_2 arg_3, "Returns value2 if value1 is not NULL; otherwise, it returns value3.")
);

115 changes: 115 additions & 0 deletions datafusion/functions/src/core/nvl2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::DataType;
use datafusion_common::{internal_err, plan_datafusion_err, DataFusionError, Result};
use datafusion_expr::{utils, ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use arrow::compute::kernels::zip::zip;
use arrow::compute::is_not_null;
use arrow::array::Array;

#[derive(Debug)]
pub(super) struct NVL2Func {
signature: Signature,
}

impl NVL2Func {
pub fn new() -> Self {
Self {
signature:
Signature::variadic_equal(
Volatility::Immutable,
),
}
}
}

impl ScalarUDFImpl for NVL2Func {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"nvl2"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if arg_types.len() != 3 {
return Err(plan_datafusion_err!(
"{}",
utils::generate_signature_error_msg(
self.name(),
self.signature().clone(),
arg_types,
)
));
}
Ok(arg_types[1].clone())
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
nvl2_func(args)
}
}

fn nvl2_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 3 {
return internal_err!(
"{:?} args were supplied but NVL2 takes exactly three args",
args.len()
);
}
let mut len = 1;
let mut is_array = false;
for arg in args {
if let ColumnarValue::Array(array) = arg {
len = array.len();
is_array = true;
break;
}
}
if is_array {
let args = args.iter().map(|arg| match arg {
ColumnarValue::Scalar(scalar) => {
scalar.to_array_of_size(len)
}
ColumnarValue::Array(array) => {
Ok(array.clone())
}
}).collect::<Result<Vec<_>>>()?;
let to_apply = is_not_null(&args[0])?;
let value = zip(&to_apply, &args[1], &args[2])?;
Ok(ColumnarValue::Array(value))
} else {
let mut current_value = &args[1];
match &args[0] {
ColumnarValue::Array(_) => {
internal_err!("except Scalar value, but got Array")
}
ColumnarValue::Scalar(scalar) => {
if scalar.is_null() {
current_value = &args[2];
}
Ok(current_value.clone())
}
}
}
}
120 changes: 120 additions & 0 deletions datafusion/sqllogictest/test_files/nvl2.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

statement ok
CREATE TABLE test(
int_field INT,
bool_field BOOLEAN,
text_field TEXT,
more_ints INT
) as VALUES
(1, true, 'abc', 2),
(2, false, 'def', 2),
(3, NULL, 'ghij', 3),
(NULL, NULL, NULL, 4),
(4, false, 'zxc', 5),
(NULL, true, NULL, 6)
;

# Arrays tests
query I
SELECT NVL2(int_field, 2, 3) FROM test ORDER BY more_ints;;
----
2
2
2
3
2
3


query B
SELECT NVL2(bool_field, false, true) FROM test ORDER BY more_ints;;
----
false
false
true
true
false
false


query T
SELECT NVL2(text_field, 'zxb', 'xyz') FROM test ORDER BY more_ints;;
----
zxb
zxb
zxb
xyz
zxb
xyz


query I
SELECT NVL2(int_field, more_ints, 10) FROM test ORDER BY more_ints;;
----
2
2
3
10
5
10


query I
SELECT NVL2(3, int_field, more_ints) FROM test ORDER BY more_ints;;
----
1
2
3
NULL
4
NULL


# Scalar values tests
query I
SELECT NVL2(1, 2, 3);
----
2

query I
SELECT NVL2(NULL, 2, 3);
----
3

query ?
SELECT NVL2(NULL, NULL, NULL);
----
NULL
18 changes: 18 additions & 0 deletions docs/source/user-guide/sql/scalar_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,7 @@ trunc(numeric_expression[, decimal_places])
- [coalesce](#coalesce)
- [nullif](#nullif)
- [nvl](#nvl)
- [nvl2](#nvl2)
- [ifnull](#ifnull)

### `coalesce`
Expand Down Expand Up @@ -620,6 +621,23 @@ nvl(expression1, expression2)
- **expression2**: return if expression1 is NULL.
Can be a constant, column, or function, and any combination of arithmetic operators.

### `nvl2`

Returns _expression2_ if _expression1_ is not NULL; otherwise it returns _expression3_.

```
nvl2(expression1, expression2, expression3)
```

#### Arguments

- **expression1**: conditional expression.
Can be a constant, column, or function, and any combination of arithmetic operators.
- **expression2**: return if expression1 is not NULL.
Can be a constant, column, or function, and any combination of arithmetic operators.
- **expression3**: return if expression1 is NULL.
Can be a constant, column, or function, and any combination of arithmetic operators.

### `ifnull`

_Alias of [nvl](#nvl)._
Expand Down

0 comments on commit 5f90ead

Please sign in to comment.