diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 00fe69b0bd33..98f57efef90d 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -136,13 +136,9 @@ impl ScalarUDFImpl for ConcatFunc { for arg in args { match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => { - if let Some(s) = maybe_value { - data_size += s.len() * len; - columns.push(ColumnarValueRef::Scalar(s.as_bytes())); - } - } - ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => { + ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value)) + | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => { if let Some(s) = maybe_value { data_size += s.len() * len; columns.push(ColumnarValueRef::Scalar(s.as_bytes())); diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 4d05f4e707b1..1134c525cfca 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -15,24 +15,22 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::array::StringArray; +use arrow::array::{as_largestring_array, Array, StringArray}; use std::any::Any; use std::sync::Arc; use arrow::datatypes::DataType; -use arrow::datatypes::DataType::Utf8; -use datafusion_common::cast::as_string_array; -use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use crate::string::common::*; +use crate::string::concat::simplify_concat; +use crate::string::concat_ws; +use datafusion_common::cast::{as_string_array, as_string_view_array}; +use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{lit, ColumnarValue, Expr, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; -use crate::string::concat::simplify_concat; -use crate::string::concat_ws; - #[derive(Debug)] pub struct ConcatWsFunc { signature: Signature, @@ -48,7 +46,10 @@ impl ConcatWsFunc { pub fn new() -> Self { use DataType::*; Self { - signature: Signature::variadic(vec![Utf8], Volatility::Immutable), + signature: Signature::variadic( + vec![Utf8View, Utf8, LargeUtf8], + Volatility::Immutable, + ), } } } @@ -67,13 +68,14 @@ impl ScalarUDFImpl for ConcatWsFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { + use DataType::*; Ok(Utf8) } /// Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored. /// concat_ws(',', 'abcde', 2, NULL, 22) = 'abcde,2,22' fn invoke(&self, args: &[ColumnarValue]) -> Result { - // do not accept 0 or 1 arguments. + // do not accept 0 or 1 arguments: a separator and at least one value are required. if args.len() < 2 { return exec_err!( "concat_ws was called with {} arguments. 
It requires at least 2.", @@ -92,8 +94,12 @@ impl ScalarUDFImpl for ConcatWsFunc { // Scalar if array_len.is_none() { let sep = match &args[0] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s, - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) + | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => s, + ColumnarValue::Scalar(ScalarValue::Utf8(None)) + | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => { return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); } _ => unreachable!(), @@ -104,22 +110,30 @@ impl ScalarUDFImpl for ConcatWsFunc { for arg in iter.by_ref() { match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) + | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => { result.push_str(s); break; } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {} + ColumnarValue::Scalar(ScalarValue::Utf8(None)) + | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {} _ => unreachable!(), } } for arg in iter.by_ref() { match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) + | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => { result.push_str(sep); result.push_str(s); } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {} + ColumnarValue::Scalar(ScalarValue::Utf8(None)) + | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {} _ => unreachable!(), } } @@ -155,21 +169,53 @@ impl ScalarUDFImpl for ConcatWsFunc { let mut columns = Vec::with_capacity(args.len() - 1); for arg in &args[1..] 
{ match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => { + ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value)) + | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => { if let Some(s) = maybe_value { data_size += s.len() * len; columns.push(ColumnarValueRef::Scalar(s.as_bytes())); } } ColumnarValue::Array(array) => { - let string_array = as_string_array(array)?; - data_size += string_array.values().len(); - let column = if array.is_nullable() { - ColumnarValueRef::NullableArray(string_array) - } else { - ColumnarValueRef::NonNullableArray(string_array) + match array.data_type() { + DataType::Utf8 => { + let string_array = as_string_array(array)?; + + data_size += string_array.values().len(); + let column = if array.is_nullable() { + ColumnarValueRef::NullableArray(string_array) + } else { + ColumnarValueRef::NonNullableArray(string_array) + }; + columns.push(column); + }, + DataType::LargeUtf8 => { + let string_array = as_largestring_array(array); + + data_size += string_array.values().len(); + let column = if array.is_nullable() { + ColumnarValueRef::NullableLargeStringArray(string_array) + } else { + ColumnarValueRef::NonNullableLargeStringArray(string_array) + }; + columns.push(column); + }, + DataType::Utf8View => { + let string_array = as_string_view_array(array)?; + + data_size += string_array.data_buffers().iter().map(|buf| buf.len()).sum::(); + let column = if array.is_nullable() { + ColumnarValueRef::NullableStringViewArray(string_array) + } else { + ColumnarValueRef::NonNullableStringViewArray(string_array) + }; + columns.push(column); + }, + other => { + return plan_err!("Input was {other} which is not a supported datatype for concat_ws function.") + } }; - columns.push(column); } _ => unreachable!(), } @@ -223,7 +269,9 @@ impl ScalarUDFImpl for ConcatWsFunc { fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { match delimiter { Expr::Literal( - 
ScalarValue::Utf8(delimiter) | ScalarValue::LargeUtf8(delimiter), + ScalarValue::Utf8(delimiter) + | ScalarValue::LargeUtf8(delimiter) + | ScalarValue::Utf8View(delimiter), ) => { match delimiter { // when the delimiter is an empty string, @@ -236,8 +284,8 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result {} - Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v))) => { + Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None)) => {} + Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v))) => { match contiguous_scalar { None => contiguous_scalar = Some(v.to_string()), Some(mut pre) => { diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index eb625e530b66..2ff935351828 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -796,7 +796,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: concat_ws(Utf8(", "), CAST(test.column1_utf8view AS Utf8), CAST(test.column2_utf8view AS Utf8)) AS c +01)Projection: concat_ws(Utf8(", "), test.column1_utf8view, test.column2_utf8view) AS c 02)--TableScan: test projection=[column1_utf8view, column2_utf8view] ## Ensure no casts for CONTAINS @@ -927,6 +927,83 @@ XiangpengXiangpeng RaphaelR R +## Should run CONCAT successfully with utf8view +query T +SELECT + concat(column1_utf8view, column2_utf8view) as c +FROM test; +---- +AndrewX +XiangpengXiangpeng +RaphaelR +R + +## Should run CONCAT_WS successfully with utf8 +query T +SELECT + concat_ws(',', column1_utf8, column2_utf8) as c +FROM test; +---- +Andrew,X +Xiangpeng,Xiangpeng +Raphael,R +R + +## Should run CONCAT_WS successfully with utf8view +query T +SELECT + concat_ws(',', column1_utf8view, column2_utf8view) as c +FROM test; +---- +Andrew,X +Xiangpeng,Xiangpeng +Raphael,R +R + +## Should run CONCAT_WS 
successfully with largeutf8 +query T +SELECT + concat_ws(',', column1_large_utf8, column2_large_utf8) as c +FROM test; +---- +Andrew,X +Xiangpeng,Xiangpeng +Raphael,R +R + +## Should run CONCAT_WS successfully with utf8 and largeutf8 +query T +SELECT + concat_ws(',', column1_utf8, column2_large_utf8) as c +FROM test; +---- +Andrew,X +Xiangpeng,Xiangpeng +Raphael,R +R + +## Should run CONCAT_WS successfully with utf8 and utf8view +query T +SELECT + concat_ws(',', column1_utf8view, column2_utf8) as c +FROM test; +---- +Andrew,X +Xiangpeng,Xiangpeng +Raphael,R +R + +## Should run CONCAT_WS successfully with largeutf8 and utf8view +query T +SELECT + concat_ws(',', column1_utf8view, column2_large_utf8) as c +FROM test; +---- +Andrew,X +Xiangpeng,Xiangpeng +Raphael,R +R + ## Ensure no casts for LPAD query TT EXPLAIN SELECT