Skip to content

Commit

Permalink
Improve split_part udf by using a GenericStringBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
Lordworms committed Aug 21, 2024
1 parent 37e54ee commit dcd223f
Showing 1 changed file with 86 additions and 66 deletions.
152 changes: 86 additions & 66 deletions datafusion/functions/src/string/split_part.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,25 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{
ArrayRef, GenericStringArray, Int64Array, OffsetSizeTrait, StringViewArray,
};
use arrow::array::{AsArray, GenericStringBuilder};
use arrow::datatypes::DataType;
use datafusion_common::cast::as_int64_array;
use std::any::Any;
use std::fmt::Write;
use std::sync::Arc;

use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
use arrow::datatypes::DataType;

use datafusion_common::cast::{
as_generic_string_array, as_int64_array, as_string_view_array,
};
use datafusion_common::{exec_err, Result};
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ColumnarValue, Volatility};
use datafusion_expr::{ScalarUDFImpl, Signature};

use crate::utils::{make_scalar_function, utf8_to_str_type};

use super::common::StringArrayType;

#[derive(Debug)]
pub struct SplitPartFunc {
signature: Signature,
Expand Down Expand Up @@ -90,12 +93,12 @@ impl ScalarUDFImpl for SplitPartFunc {
(DataType::LargeUtf8, DataType::LargeUtf8) => {
make_scalar_function(split_part::<i64, i64>, vec![])(args)
}
(_, DataType::LargeUtf8) => {
make_scalar_function(split_part::<i32, i64>, vec![])(args)
}
(DataType::LargeUtf8, _) => {
(DataType::LargeUtf8, DataType::Utf8 | DataType::Utf8View) => {
make_scalar_function(split_part::<i64, i32>, vec![])(args)
}
(DataType::Utf8 | DataType::Utf8View, DataType::LargeUtf8) => {
make_scalar_function(split_part::<i32, i64>, vec![])(args)
}
(first_type, second_type) => exec_err!(
"unsupported first type {} and second type {} for split_part function",
first_type,
Expand All @@ -105,13 +108,70 @@ impl ScalarUDFImpl for SplitPartFunc {
}
}

macro_rules! process_split_part {
($string_array: expr, $delimiter_array: expr, $n_array: expr) => {{
let result = $string_array
.iter()
.zip($delimiter_array.iter())
.zip($n_array.iter())
.map(|((string, delimiter), n)| match (string, delimiter, n) {
/// Splits string at occurrences of delimiter and returns the n'th field (counting from one).
/// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def'
pub fn split_part<StringArrayLen: OffsetSizeTrait, DelimiterArrayLen: OffsetSizeTrait>(
args: &[ArrayRef],
) -> Result<ArrayRef> {
let n_array = as_int64_array(&args[2])?;

match (args[0].data_type(), args[1].data_type()) {
(DataType::Utf8View, DataType::Utf8View) => {
split_part_impl::<&StringViewArray, &StringViewArray, StringArrayLen>(
args[0].as_string_view(),
args[1].as_string_view(),
n_array,
)
}
(_, DataType::Utf8View) => split_part_impl::<
&GenericStringArray<StringArrayLen>,
&StringViewArray,
StringArrayLen,
>(
args[0].as_string::<StringArrayLen>(),
args[1].as_string_view(),
n_array,
),
(DataType::Utf8View, _) => split_part_impl::<
&StringViewArray,
&GenericStringArray<DelimiterArrayLen>,
StringArrayLen,
>(
args[0].as_string_view(),
args[1].as_string::<DelimiterArrayLen>(),
n_array,
),
(_, _) => split_part_impl::<
&GenericStringArray<StringArrayLen>,
&GenericStringArray<DelimiterArrayLen>,
StringArrayLen,
>(
args[0].as_string::<StringArrayLen>(),
args[1].as_string::<DelimiterArrayLen>(),
n_array,
),
}
}

/// impl
pub fn split_part_impl<'a, StringArrType, DelimiterArrType, StringArrayLen>(
string_array: StringArrType,
delimiter_array: DelimiterArrType,
n_array: &Int64Array,
) -> Result<ArrayRef>
where
StringArrType: StringArrayType<'a>,
DelimiterArrType: StringArrayType<'a>,
StringArrayLen: OffsetSizeTrait,
{
let mut builder: GenericStringBuilder<StringArrayLen> = GenericStringBuilder::new();

string_array
.iter()
.zip(delimiter_array.iter())
.zip(n_array.iter())
.try_for_each(|((string, delimiter), n)| -> Result<(), DataFusionError> {
match (string, delimiter, n) {
(Some(string), Some(delimiter), Some(n)) => {
let split_string: Vec<&str> = string.split(delimiter).collect();
let len = split_string.len();
Expand All @@ -125,58 +185,18 @@ macro_rules! process_split_part {
} as usize;

if index < len {
Ok(Some(split_string[index]))
builder.write_str(split_string[index])?;
builder.append_value("");
} else {
Ok(Some(""))
builder.write_str("")?;
}
}
_ => Ok(None),
})
.collect::<Result<GenericStringArray<StringLen>>>()?;
Ok(Arc::new(result) as ArrayRef)
}};
}

/// Splits string at occurrences of delimiter and returns the n'th field (counting from one).
/// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def'
fn split_part<StringLen: OffsetSizeTrait, DelimiterLen: OffsetSizeTrait>(
args: &[ArrayRef],
) -> Result<ArrayRef> {
let n_array = as_int64_array(&args[2])?;
match (args[0].data_type(), args[1].data_type()) {
(DataType::Utf8View, _) => {
let string_array = as_string_view_array(&args[0])?;
match args[1].data_type() {
DataType::Utf8View => {
let delimiter_array = as_string_view_array(&args[1])?;
process_split_part!(string_array, delimiter_array, n_array)
}
_ => {
let delimiter_array =
as_generic_string_array::<DelimiterLen>(&args[1])?;
process_split_part!(string_array, delimiter_array, n_array)
}
}
}
(_, DataType::Utf8View) => {
let delimiter_array = as_string_view_array(&args[1])?;
match args[0].data_type() {
DataType::Utf8View => {
let string_array = as_string_view_array(&args[0])?;
process_split_part!(string_array, delimiter_array, n_array)
}
_ => {
let string_array = as_generic_string_array::<StringLen>(&args[0])?;
process_split_part!(string_array, delimiter_array, n_array)
}
_ => builder.append_null(),
}
}
(_, _) => {
let string_array = as_generic_string_array::<StringLen>(&args[0])?;
let delimiter_array = as_generic_string_array::<DelimiterLen>(&args[1])?;
process_split_part!(string_array, delimiter_array, n_array)
}
}
Ok(())
})?;

Ok(Arc::new(builder.finish()) as ArrayRef)
}

#[cfg(test)]
Expand Down

0 comments on commit dcd223f

Please sign in to comment.