Skip to content

Commit

Permalink
Validate and unpack function arguments tersely (apache#14513)
Browse files Browse the repository at this point in the history
* Validate and unpack function arguments tersely

Add a `take_function_args` helper that provides convenient unpacking of
function arguments along with validation that the provided argument
count matches the expected.  A few functions are updated to leverage the
new pattern to demonstrate its usefulness.

* Add example in rust doc

Co-authored-by: Andrew Lamb <[email protected]>

* fix fmt

* Export function utils publicly

this exports only the newly added take_function_args function. all other
utils members are pub(crate)

* use compact format pattern

Co-authored-by: Matthijs Brobbel <[email protected]>

* fix example

* fixup! fix example

* fix license header

Co-authored-by: Oleks V <[email protected]>

* Name args in nvl2 and use take_function_args in execution too

---------

Co-authored-by: Andrew Lamb <[email protected]>
Co-authored-by: Matthijs Brobbel <[email protected]>
Co-authored-by: Oleks V <[email protected]>
  • Loading branch information
4 people authored Feb 5, 2025
1 parent 304488d commit 5239d1a
Show file tree
Hide file tree
Showing 14 changed files with 132 additions and 146 deletions.
6 changes: 3 additions & 3 deletions datafusion/functions/src/core/arrow_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use datafusion_common::{
use datafusion_common::{exec_datafusion_err, DataFusionError};
use std::any::Any;

use crate::utils::take_function_args;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{
ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl,
Expand Down Expand Up @@ -117,10 +118,9 @@ impl ScalarUDFImpl for ArrowCastFunc {
fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
let nullable = args.nullables.iter().any(|&nullable| nullable);

// Length check handled in the signature
debug_assert_eq!(args.scalar_arguments.len(), 2);
let [_, type_arg] = take_function_args(self.name(), args.scalar_arguments)?;

args.scalar_arguments[1]
type_arg
.and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty()))
.map_or_else(
|| {
Expand Down
13 changes: 4 additions & 9 deletions datafusion/functions/src/core/arrowtypeof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
// specific language governing permissions and limitations
// under the License.

use crate::utils::take_function_args;
use arrow::datatypes::DataType;
use datafusion_common::{exec_err, Result, ScalarValue};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, Documentation};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use datafusion_macros::user_doc;
Expand Down Expand Up @@ -80,14 +81,8 @@ impl ScalarUDFImpl for ArrowTypeOfFunc {
args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
if args.len() != 1 {
return exec_err!(
"arrow_typeof function requires 1 arguments, got {}",
args.len()
);
}

let input_data_type = args[0].data_type();
let [arg] = take_function_args(self.name(), args)?;
let input_data_type = arg.data_type();
Ok(ColumnarValue::Scalar(ScalarValue::from(format!(
"{input_data_type}"
))))
Expand Down
12 changes: 4 additions & 8 deletions datafusion/functions/src/core/getfield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use crate::utils::take_function_args;
use arrow::array::{
make_array, Array, Capacities, MutableArrayData, Scalar, StringArray,
};
Expand Down Expand Up @@ -99,14 +100,9 @@ impl ScalarUDFImpl for GetFieldFunc {
}

fn display_name(&self, args: &[Expr]) -> Result<String> {
if args.len() != 2 {
return exec_err!(
"get_field function requires 2 arguments, got {}",
args.len()
);
}
let [base, field_name] = take_function_args(self.name(), args)?;

let name = match &args[1] {
let name = match field_name {
Expr::Literal(name) => name,
_ => {
return exec_err!(
Expand All @@ -115,7 +111,7 @@ impl ScalarUDFImpl for GetFieldFunc {
}
};

Ok(format!("{}[{}]", args[0], name))
Ok(format!("{base}[{name}]"))
}

fn schema_name(&self, args: &[Expr]) -> Result<String> {
Expand Down
13 changes: 4 additions & 9 deletions datafusion/functions/src/core/nullif.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@
// under the License.

use arrow::datatypes::DataType;
use datafusion_common::{exec_err, Result};
use datafusion_common::Result;
use datafusion_expr::{ColumnarValue, Documentation};

use crate::utils::take_function_args;
use arrow::compute::kernels::cmp::eq;
use arrow::compute::kernels::nullif::nullif;
use datafusion_common::ScalarValue;
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use datafusion_macros::user_doc;
use std::any::Any;

#[user_doc(
doc_section(label = "Conditional Functions"),
description = "Returns _null_ if _expression1_ equals _expression2_; otherwise it returns _expression1_.
Expand Down Expand Up @@ -119,14 +121,7 @@ impl ScalarUDFImpl for NullIfFunc {
/// 1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed.
///
fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 2 {
return exec_err!(
"{:?} args were supplied but NULLIF takes exactly two args",
args.len()
);
}

let (lhs, rhs) = (&args[0], &args[1]);
let [lhs, rhs] = take_function_args("nullif", args)?;

match (lhs, rhs) {
(ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => {
Expand Down
13 changes: 5 additions & 8 deletions datafusion/functions/src/core/nvl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,18 @@
// specific language governing permissions and limitations
// under the License.

use crate::utils::take_function_args;
use arrow::array::Array;
use arrow::compute::is_not_null;
use arrow::compute::kernels::zip::zip;
use arrow::datatypes::DataType;
use datafusion_common::{internal_err, Result};
use datafusion_common::Result;
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_macros::user_doc;
use std::sync::Arc;

#[user_doc(
doc_section(label = "Conditional Functions"),
description = "Returns _expression2_ if _expression1_ is NULL otherwise it returns _expression1_.",
Expand Down Expand Up @@ -133,13 +135,8 @@ impl ScalarUDFImpl for NVLFunc {
}

fn nvl_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 2 {
return internal_err!(
"{:?} args were supplied but NVL/IFNULL takes exactly two args",
args.len()
);
}
let (lhs_array, rhs_array) = match (&args[0], &args[1]) {
let [lhs, rhs] = take_function_args("nvl/ifnull", args)?;
let (lhs_array, rhs_array) = match (lhs, rhs) {
(ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => {
(Arc::clone(lhs), rhs.to_array_of_size(lhs.len())?)
}
Expand Down
60 changes: 26 additions & 34 deletions datafusion/functions/src/core/nvl2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
// specific language governing permissions and limitations
// under the License.

use crate::utils::take_function_args;
use arrow::array::Array;
use arrow::compute::is_not_null;
use arrow::compute::kernels::zip::zip;
use arrow::datatypes::DataType;
use datafusion_common::{exec_err, internal_err, Result};
use datafusion_common::{internal_err, Result};
use datafusion_expr::{
type_coercion::binary::comparison_coercion, ColumnarValue, Documentation,
ScalarUDFImpl, Signature, Volatility,
Expand Down Expand Up @@ -104,27 +105,22 @@ impl ScalarUDFImpl for NVL2Func {
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 3 {
return exec_err!(
"NVL2 takes exactly three arguments, but got {}",
arg_types.len()
);
}
let new_type = arg_types.iter().skip(1).try_fold(
arg_types.first().unwrap().clone(),
|acc, x| {
// The coerced types found by `comparison_coercion` are not guaranteed to be
// coercible for the arguments. `comparison_coercion` returns more loose
// types that can be coerced to both `acc` and `x` for comparison purpose.
// See `maybe_data_types` for the actual coercion.
let coerced_type = comparison_coercion(&acc, x);
if let Some(coerced_type) = coerced_type {
Ok(coerced_type)
} else {
internal_err!("Coercion from {acc:?} to {x:?} failed.")
}
},
)?;
let [tested, if_non_null, if_null] = take_function_args(self.name(), arg_types)?;
let new_type =
[if_non_null, if_null]
.iter()
.try_fold(tested.clone(), |acc, x| {
// The coerced types found by `comparison_coercion` are not guaranteed to be
// coercible for the arguments. `comparison_coercion` returns more loose
// types that can be coerced to both `acc` and `x` for comparison purpose.
// See `maybe_data_types` for the actual coercion.
let coerced_type = comparison_coercion(&acc, x);
if let Some(coerced_type) = coerced_type {
Ok(coerced_type)
} else {
internal_err!("Coercion from {acc:?} to {x:?} failed.")
}
})?;
Ok(vec![new_type; arg_types.len()])
}

Expand All @@ -134,12 +130,6 @@ impl ScalarUDFImpl for NVL2Func {
}

fn nvl2_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 3 {
return internal_err!(
"{:?} args were supplied but NVL2 takes exactly three args",
args.len()
);
}
let mut len = 1;
let mut is_array = false;
for arg in args {
Expand All @@ -157,20 +147,22 @@ fn nvl2_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
ColumnarValue::Array(array) => Ok(Arc::clone(array)),
})
.collect::<Result<Vec<_>>>()?;
let to_apply = is_not_null(&args[0])?;
let value = zip(&to_apply, &args[1], &args[2])?;
let [tested, if_non_null, if_null] = take_function_args("nvl2", args)?;
let to_apply = is_not_null(&tested)?;
let value = zip(&to_apply, &if_non_null, &if_null)?;
Ok(ColumnarValue::Array(value))
} else {
let mut current_value = &args[1];
match &args[0] {
let [tested, if_non_null, if_null] = take_function_args("nvl2", args)?;
match &tested {
ColumnarValue::Array(_) => {
internal_err!("except Scalar value, but got Array")
}
ColumnarValue::Scalar(scalar) => {
if scalar.is_null() {
current_value = &args[2];
Ok(if_null.clone())
} else {
Ok(if_non_null.clone())
}
Ok(current_value.clone())
}
}
}
Expand Down
15 changes: 6 additions & 9 deletions datafusion/functions/src/core/version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@

//! [`VersionFunc`]: Implementation of the `version` function.
use crate::utils::take_function_args;
use arrow::datatypes::DataType;
use datafusion_common::{internal_err, plan_err, Result, ScalarValue};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_macros::user_doc;
use std::any::Any;

#[user_doc(
doc_section(label = "Other Functions"),
description = "Returns the version of DataFusion.",
Expand Down Expand Up @@ -70,21 +72,16 @@ impl ScalarUDFImpl for VersionFunc {
}

fn return_type(&self, args: &[DataType]) -> Result<DataType> {
if args.is_empty() {
Ok(DataType::Utf8)
} else {
plan_err!("version expects no arguments")
}
let [] = take_function_args(self.name(), args)?;
Ok(DataType::Utf8)
}

fn invoke_batch(
&self,
args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
if !args.is_empty() {
return internal_err!("{} function does not accept arguments", self.name());
}
let [] = take_function_args(self.name(), args)?;
// TODO it would be great to add rust version and arrow version,
// but that requires a `build.rs` script and/or adding a version const to arrow-rs
let version = format!(
Expand Down
33 changes: 8 additions & 25 deletions datafusion/functions/src/crypto/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use blake2::{Blake2b512, Blake2s256, Digest};
use blake3::Hasher as Blake3;
use datafusion_common::cast::as_binary_array;

use crate::utils::take_function_args;
use arrow::compute::StringArrayType;
use datafusion_common::plan_err;
use datafusion_common::{
Expand All @@ -41,14 +42,8 @@ macro_rules! define_digest_function {
($NAME: ident, $METHOD: ident, $DOC: expr) => {
#[doc = $DOC]
pub fn $NAME(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 1 {
return exec_err!(
"{:?} args were supplied but {} takes exactly one argument",
args.len(),
DigestAlgorithm::$METHOD.to_string()
);
}
digest_process(&args[0], DigestAlgorithm::$METHOD)
let [data] = take_function_args(&DigestAlgorithm::$METHOD.to_string(), args)?;
digest_process(data, DigestAlgorithm::$METHOD)
}
};
}
Expand Down Expand Up @@ -114,13 +109,8 @@ pub enum DigestAlgorithm {
/// Second argument is the algorithm to use.
/// Standard algorithms are md5, sha1, sha224, sha256, sha384 and sha512.
pub fn digest(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 2 {
return exec_err!(
"{:?} args were supplied but digest takes exactly two arguments",
args.len()
);
}
let digest_algorithm = match &args[1] {
let [data, digest_algorithm] = take_function_args("digest", args)?;
let digest_algorithm = match digest_algorithm {
ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
Some(Some(method)) => method.parse::<DigestAlgorithm>(),
_ => exec_err!("Unsupported data type {scalar:?} for function digest"),
Expand All @@ -129,7 +119,7 @@ pub fn digest(args: &[ColumnarValue]) -> Result<ColumnarValue> {
internal_err!("Digest using dynamically decided method is not yet supported")
}
}?;
digest_process(&args[0], digest_algorithm)
digest_process(data, digest_algorithm)
}

impl FromStr for DigestAlgorithm {
Expand Down Expand Up @@ -175,15 +165,8 @@ impl fmt::Display for DigestAlgorithm {

/// computes md5 hash digest of the given input
pub fn md5(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 1 {
return exec_err!(
"{:?} args were supplied but {} takes exactly one argument",
args.len(),
DigestAlgorithm::Md5
);
}

let value = digest_process(&args[0], DigestAlgorithm::Md5)?;
let [data] = take_function_args("md5", args)?;
let value = digest_process(data, DigestAlgorithm::Md5)?;

// md5 requires special handling because of its unique utf8 return type
Ok(match value {
Expand Down
6 changes: 3 additions & 3 deletions datafusion/functions/src/datetime/date_part.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use arrow::datatypes::DataType::{
use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second};
use arrow::datatypes::{DataType, TimeUnit};

use crate::utils::take_function_args;
use datafusion_common::not_impl_err;
use datafusion_common::{
cast::{
Expand Down Expand Up @@ -140,10 +141,9 @@ impl ScalarUDFImpl for DatePartFunc {
}

fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
// Length check handled in the signature
debug_assert_eq!(args.scalar_arguments.len(), 2);
let [field, _] = take_function_args(self.name(), args.scalar_arguments)?;

args.scalar_arguments[0]
field
.and_then(|sv| {
sv.try_as_str()
.flatten()
Expand Down
Loading

0 comments on commit 5239d1a

Please sign in to comment.