Skip to content

Commit

Permalink
cleanup to trigger rerun
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed Mar 9, 2024
1 parent cda764e commit 0ba624a
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 55 deletions.
32 changes: 0 additions & 32 deletions datafusion/functions-array/src/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ use datafusion_common::cast::{
use datafusion_common::utils::{array_into_list_array, list_ndims};
use datafusion_common::{
exec_err, not_impl_datafusion_err, not_impl_err, plan_err, DataFusionError, Result,
ScalarValue,
};
use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation};
use std::any::type_name;
use std::sync::Arc;
macro_rules! downcast_arg {
Expand Down Expand Up @@ -903,36 +901,6 @@ fn align_array_dimensions<O: OffsetSizeTrait>(
aligned_args
}

/// Wraps an array-only kernel so that it can be invoked with [`ColumnarValue`] arguments.
///
/// Every argument is turned into an array with `ColumnarValue::values_to_arrays`
/// before `inner` runs. When none of the original arguments was an array, element 0
/// of the kernel output is handed back as a scalar; otherwise the output array is
/// returned unchanged. Errors from the conversion, the kernel, or the scalar
/// extraction are propagated to the caller.
pub(crate) fn make_scalar_function_with_hints<F>(inner: F) -> ScalarFunctionImplementation
where
    F: Fn(&[ArrayRef]) -> Result<ArrayRef> + Sync + Send + 'static,
{
    Arc::new(move |args: &[ColumnarValue]| {
        // Record whether the call had any array input before the arguments are converted.
        let has_array_input = args
            .iter()
            .any(|arg| matches!(arg, ColumnarValue::Array(_)));

        let arrays = ColumnarValue::values_to_arrays(args)?;

        match inner(&arrays) {
            Ok(output) if has_array_input => Ok(ColumnarValue::Array(output)),
            // Scalar-only call: keep the output scalar by taking its first element.
            Ok(output) => {
                ScalarValue::try_from_array(&output, 0).map(ColumnarValue::Scalar)
            }
            Err(err) => Err(err),
        }
    })
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
44 changes: 22 additions & 22 deletions datafusion/functions-array/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ use datafusion_common::utils::list_ndims;
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::type_coercion::binary::get_wider_type;
use datafusion_expr::Expr;
use datafusion_expr::TypeSignature::{Any as expr_Any, Exact, VariadicEqual};
use datafusion_expr::TypeSignature;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::cmp::Ordering;
use std::sync::Arc;

use crate::kernels::make_scalar_function_with_hints;
use crate::utils::make_scalar_function;

// Create static instances of ScalarUDFs for each function
make_udf_function!(ArrayToString,
Expand Down Expand Up @@ -111,10 +111,10 @@ impl Range {
Self {
signature: Signature::one_of(
vec![
Exact(vec![Int64]),
Exact(vec![Int64, Int64]),
Exact(vec![Int64, Int64, Int64]),
Exact(vec![Date32, Date32, Interval(MonthDayNano)]),
TypeSignature::Exact(vec![Int64]),
TypeSignature::Exact(vec![Int64, Int64]),
TypeSignature::Exact(vec![Int64, Int64, Int64]),
TypeSignature::Exact(vec![Date32, Date32, Interval(MonthDayNano)]),
],
Volatility::Immutable,
),
Expand Down Expand Up @@ -181,10 +181,10 @@ impl GenSeries {
Self {
signature: Signature::one_of(
vec![
Exact(vec![Int64]),
Exact(vec![Int64, Int64]),
Exact(vec![Int64, Int64, Int64]),
Exact(vec![Date32, Date32, Interval(MonthDayNano)]),
TypeSignature::Exact(vec![Int64]),
TypeSignature::Exact(vec![Int64, Int64]),
TypeSignature::Exact(vec![Int64, Int64, Int64]),
TypeSignature::Exact(vec![Date32, Date32, Interval(MonthDayNano)]),
],
Volatility::Immutable,
),
Expand Down Expand Up @@ -444,7 +444,7 @@ impl ScalarUDFImpl for ArrayAppend {
}

fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
make_scalar_function_with_hints(crate::kernels::array_append)(args)
make_scalar_function(crate::kernels::array_append)(args)
}

fn aliases(&self) -> &[String] {
Expand All @@ -455,9 +455,9 @@ impl ScalarUDFImpl for ArrayAppend {
make_udf_function!(
ArrayPrepend,
array_prepend,
element array, // arg name
"Prepends an element to the beginning of an array.", // doc
array_prepend_udf // internal function name
element array,
"Prepends an element to the beginning of an array.",
array_prepend_udf
);

#[derive(Debug)]
Expand Down Expand Up @@ -498,7 +498,7 @@ impl ScalarUDFImpl for ArrayPrepend {
}

fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
make_scalar_function_with_hints(crate::kernels::array_prepend)(args)
make_scalar_function(crate::kernels::array_prepend)(args)
}

fn aliases(&self) -> &[String] {
Expand All @@ -509,8 +509,8 @@ impl ScalarUDFImpl for ArrayPrepend {
make_udf_function!(
ArrayConcat,
array_concat,
"Concatenates arrays.", // doc
array_concat_udf // internal function name
"Concatenates arrays.",
array_concat_udf
);

#[derive(Debug)]
Expand Down Expand Up @@ -576,7 +576,7 @@ impl ScalarUDFImpl for ArrayConcat {
}

fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
make_scalar_function_with_hints(crate::kernels::array_concat)(args)
make_scalar_function(crate::kernels::array_concat)(args)
}

fn aliases(&self) -> &[String] {
Expand All @@ -587,8 +587,8 @@ impl ScalarUDFImpl for ArrayConcat {
make_udf_function!(
MakeArray,
make_array,
"Returns an Arrow array using the specified input expressions.", // doc
make_array_udf // internal function name
"Returns an Arrow array using the specified input expressions.",
make_array_udf
);

#[derive(Debug)]
Expand All @@ -601,7 +601,7 @@ impl MakeArray {
pub fn new() -> Self {
Self {
signature: Signature::one_of(
vec![VariadicEqual, expr_Any(0)],
vec![TypeSignature::VariadicEqual, TypeSignature::Any(0)],
Volatility::Immutable,
),
aliases: vec![String::from("make_array"), String::from("make_list")],
Expand Down Expand Up @@ -646,7 +646,7 @@ impl ScalarUDFImpl for MakeArray {
}

fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
make_scalar_function_with_hints(crate::kernels::make_array)(args)
make_scalar_function(crate::kernels::make_array)(args)
}

fn aliases(&self) -> &[String] {
Expand Down
35 changes: 34 additions & 1 deletion datafusion/functions-array/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@

//! array function utils
use std::sync::Arc;

use arrow::{array::ArrayRef, datatypes::DataType};
use datafusion_common::{plan_err, Result};
use datafusion_common::{plan_err, Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation};

pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> {
let data_type = args[0].data_type();
Expand All @@ -32,3 +35,33 @@ pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> {

Ok(())
}

/// Adapts an array kernel (`&[ArrayRef] -> Result<ArrayRef>`) into a
/// [`ScalarFunctionImplementation`] that accepts [`ColumnarValue`] arguments.
///
/// The returned closure:
/// 1. checks whether any argument is a `ColumnarValue::Array`;
/// 2. converts all arguments to arrays via `ColumnarValue::values_to_arrays`;
/// 3. calls `inner` on those arrays;
/// 4. if no argument was an array, returns element 0 of the result as a
///    `ColumnarValue::Scalar`; otherwise returns the result as a
///    `ColumnarValue::Array`.
///
/// # Errors
///
/// Propagates any error from the argument conversion, from `inner`, or from
/// `ScalarValue::try_from_array`.
pub(crate) fn make_scalar_function<F>(inner: F) -> ScalarFunctionImplementation
where
    F: Fn(&[ArrayRef]) -> Result<ArrayRef> + Sync + Send + 'static,
{
    Arc::new(move |args: &[ColumnarValue]| -> Result<ColumnarValue> {
        // Only the presence of an array argument matters here; the array length
        // itself is not needed because `values_to_arrays` does the conversion.
        let all_scalar = !args
            .iter()
            .any(|arg| matches!(arg, ColumnarValue::Array(_)));

        let arrays = ColumnarValue::values_to_arrays(args)?;
        let output = inner(&arrays)?;

        if all_scalar {
            // If all inputs are scalar, keep the output as a scalar.
            ScalarValue::try_from_array(&output, 0).map(ColumnarValue::Scalar)
        } else {
            Ok(ColumnarValue::Array(output))
        }
    })
}

0 comments on commit 0ba624a

Please sign in to comment.