Skip to content

Commit

Permalink
Update concat_ws scalar function to support Utf8View
Browse files Browse the repository at this point in the history
Signed-off-by: Devan <[email protected]>
  • Loading branch information
devanbenz committed Sep 4, 2024
1 parent 0cd7c25 commit 22fd1b5
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 48 deletions.
10 changes: 3 additions & 7 deletions datafusion/functions/src/string/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,9 @@ impl ScalarUDFImpl for ConcatFunc {

for arg in args {
match arg {
ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => {
if let Some(s) = maybe_value {
data_size += s.len() * len;
columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
}
}
ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value))
| ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
if let Some(s) = maybe_value {
data_size += s.len() * len;
columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
Expand Down
238 changes: 198 additions & 40 deletions datafusion/functions/src/string/concat_ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::StringArray;
use arrow::array::{as_largestring_array, Array, StringArray};
use std::any::Any;
use std::sync::Arc;

use arrow::datatypes::DataType;
use arrow::datatypes::DataType::Utf8;

use datafusion_common::cast::as_string_array;
use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
use datafusion_common::cast::{as_string_array, as_string_view_array};
use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{lit, ColumnarValue, Expr, Volatility};
Expand All @@ -46,9 +45,10 @@ impl Default for ConcatWsFunc {

impl ConcatWsFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::variadic(vec![Utf8], Volatility::Immutable),
signature: Signature::variadic_any(
Volatility::Immutable,
),
}
}
}
Expand All @@ -66,8 +66,19 @@ impl ScalarUDFImpl for ConcatWsFunc {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(Utf8)
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
use DataType::*;
let mut dt = &Utf8;
arg_types.iter().for_each(|data_type| {
if data_type == &Utf8View {
dt = data_type;
}
if data_type == &LargeUtf8 && dt != &Utf8View {
dt = data_type;
}
});

Ok(dt.to_owned())
}

/// Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored.
Expand All @@ -89,6 +100,18 @@ impl ScalarUDFImpl for ConcatWsFunc {
})
.next();

let mut return_datatype = DataType::Utf8;
args.iter().for_each(|col| {
if col.data_type() == DataType::Utf8View {
return_datatype = col.data_type();
}
if col.data_type() == DataType::LargeUtf8
&& return_datatype != DataType::Utf8View
{
return_datatype = col.data_type();
}
});

// Scalar
if array_len.is_none() {
let sep = match &args[0] {
Expand All @@ -104,27 +127,43 @@ impl ScalarUDFImpl for ConcatWsFunc {

for arg in iter.by_ref() {
match arg {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s)))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => {
result.push_str(s);
break;
}
ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {}
ColumnarValue::Scalar(ScalarValue::Utf8(None))
| ColumnarValue::Scalar(ScalarValue::Utf8View(None))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {}
_ => unreachable!(),
}
}

for arg in iter.by_ref() {
match arg {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s)))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => {
result.push_str(sep);
result.push_str(s);
}
ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {}
ColumnarValue::Scalar(ScalarValue::Utf8(None))
| ColumnarValue::Scalar(ScalarValue::Utf8View(None))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {}
_ => unreachable!(),
}
}

return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))));
return match return_datatype {
DataType::Utf8View => {
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))))
}
DataType::LargeUtf8 => {
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))))
}
_ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))),
};
}

// Array
Expand Down Expand Up @@ -155,52 +194,145 @@ impl ScalarUDFImpl for ConcatWsFunc {
let mut columns = Vec::with_capacity(args.len() - 1);
for arg in &args[1..] {
match arg {
ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => {
ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value))
| ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
if let Some(s) = maybe_value {
data_size += s.len() * len;
columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
}
}
ColumnarValue::Array(array) => {
let string_array = as_string_array(array)?;
data_size += string_array.values().len();
let column = if array.is_nullable() {
ColumnarValueRef::NullableArray(string_array)
} else {
ColumnarValueRef::NonNullableArray(string_array)
match array.data_type() {
DataType::Utf8 => {
let string_array = as_string_array(array)?;

data_size += string_array.values().len();
let column = if array.is_nullable() {
ColumnarValueRef::NullableArray(string_array)
} else {
ColumnarValueRef::NonNullableArray(string_array)
};
columns.push(column);
},
DataType::LargeUtf8 => {
let string_array = as_largestring_array(array);

data_size += string_array.values().len();
let column = if array.is_nullable() {
ColumnarValueRef::NullableLargeStringArray(string_array)
} else {
ColumnarValueRef::NonNullableLargeStringArray(string_array)
};
columns.push(column);
},
DataType::Utf8View => {
let string_array = as_string_view_array(array)?;

data_size += string_array.len();
let column = if array.is_nullable() {
ColumnarValueRef::NullableStringViewArray(string_array)
} else {
ColumnarValueRef::NonNullableStringViewArray(string_array)
};
columns.push(column);
},
other => {
return plan_err!("Input was {other} which is not a supported datatype for concat function")
}
};
columns.push(column);
}
_ => unreachable!(),
}
}

let mut builder = StringArrayBuilder::with_capacity(len, data_size);
for i in 0..len {
if !sep.is_valid(i) {
builder.append_offset();
continue;
}
match return_datatype {
DataType::Utf8 => {
let mut builder = StringArrayBuilder::with_capacity(len, data_size);
for i in 0..len {
if !sep.is_valid(i) {
builder.append_offset();
continue;
}

let mut iter = columns.iter();
for column in iter.by_ref() {
if column.is_valid(i) {
builder.write::<false>(column, i);
break;
let mut iter = columns.iter();
for column in iter.by_ref() {
if column.is_valid(i) {
builder.write::<false>(column, i);
break;
}
}

for column in iter {
if column.is_valid(i) {
builder.write::<false>(&sep, i);
builder.write::<false>(column, i);
}
}

builder.append_offset();
}

Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls()))))
}
DataType::Utf8View => {
let mut builder = StringViewArrayBuilder::with_capacity(len, data_size);
for i in 0..len {
if !sep.is_valid(i) {
builder.append_offset();
continue;
}

for column in iter {
if column.is_valid(i) {
builder.write::<false>(&sep, i);
builder.write::<false>(column, i);
let mut iter = columns.iter();
for column in iter.by_ref() {
if column.is_valid(i) {
builder.write::<false>(column, i);
break;
}
}

for column in iter {
if column.is_valid(i) {
builder.write::<false>(&sep, i);
builder.write::<false>(column, i);
}
}

builder.append_offset();
}

Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
DataType::LargeUtf8 => {
let mut builder = LargeStringArrayBuilder::with_capacity(len, data_size);
for i in 0..len {
if !sep.is_valid(i) {
builder.append_offset();
continue;
}

builder.append_offset();
}
let mut iter = columns.iter();
for column in iter.by_ref() {
if column.is_valid(i) {
builder.write::<false>(column, i);
break;
}
}

Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls()))))
for column in iter {
if column.is_valid(i) {
builder.write::<false>(&sep, i);
builder.write::<false>(column, i);
}
}

builder.append_offset();
}

Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls()))))
}
_ => unreachable!(),
}
}

/// Simply the `concat_ws` function by
Expand Down Expand Up @@ -304,7 +436,7 @@ mod tests {
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, StringArray};
use arrow::datatypes::DataType::Utf8;
use arrow::datatypes::DataType::{Utf8, LargeUtf8, Utf8View};

use crate::string::concat_ws::ConcatWsFunc;
use datafusion_common::Result;
Expand Down Expand Up @@ -365,6 +497,32 @@ mod tests {
Utf8,
StringArray
);
test_function!(
ConcatWsFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::from("|")),
ColumnarValue::Scalar(ScalarValue::from("aa")),
ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
ColumnarValue::Scalar(ScalarValue::from("cc")),
],
Ok(Some("aa|cc")),
&str,
LargeUtf8,
StringArray
);
test_function!(
ConcatWsFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::from("|")),
ColumnarValue::Scalar(ScalarValue::from("aa")),
ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
ColumnarValue::Scalar(ScalarValue::from("cc")),
],
Ok(Some("aa|cc")),
&str,
Utf8View,
StringArray
);

Ok(())
}
Expand Down
Loading

0 comments on commit 22fd1b5

Please sign in to comment.