diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 68785b7a5a45..1e6ff8088d0a 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -483,10 +483,12 @@ fn expr_test_schema() -> DFSchemaRef { Field::new("c2", DataType::Boolean, true), Field::new("c3", DataType::Int64, true), Field::new("c4", DataType::UInt32, true), + Field::new("c5", DataType::Utf8View, true), Field::new("c1_non_null", DataType::Utf8, false), Field::new("c2_non_null", DataType::Boolean, false), Field::new("c3_non_null", DataType::Int64, false), Field::new("c4_non_null", DataType::UInt32, false), + Field::new("c5_non_null", DataType::Utf8View, false), ]) .to_dfschema_ref() .unwrap() @@ -665,20 +667,32 @@ fn test_simplify_concat_ws_with_null() { } #[test] -fn test_simplify_concat() { +fn test_simplify_concat() -> Result<()> { + let schema = expr_test_schema(); let null = lit(ScalarValue::Utf8(None)); let expr = concat(vec![ null.clone(), - col("c0"), + col("c1"), lit("hello "), null.clone(), lit("rust"), - col("c1"), + lit(ScalarValue::Utf8View(Some("!".to_string()))), + col("c2"), lit(""), null, + col("c5"), ]); - let expected = concat(vec![col("c0"), lit("hello rust"), col("c1")]); - test_simplify(expr, expected) + let expr_datatype = expr.get_type(schema.as_ref())?; + let expected = concat(vec![ + col("c1"), + lit(ScalarValue::Utf8View(Some("hello rust!".to_string()))), + col("c2"), + col("c5"), + ]); + let expected_datatype = expected.get_type(schema.as_ref())?; + assert_eq!(expr_datatype, expected_datatype); + test_simplify(expr, expected); + Ok(()) } #[test] fn test_simplify_cycles() { diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index e429a938b27d..c76a08653f53 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -48,7 +48,7 @@ impl ConcatFunc { use DataType::*; Self { signature: Signature::variadic( - vec![Utf8, Utf8View, LargeUtf8], + vec![Utf8View, Utf8, LargeUtf8], Volatility::Immutable, ), } @@ -110,8 +110,19 @@ impl ScalarUDFImpl for ConcatFunc { if array_len.is_none() { let mut result = String::new(); for arg in args { - if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = arg { - result.push_str(v); + match arg { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) + | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v))) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(v))) => { + result.push_str(v); + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) + | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {} + other => plan_err!( + "Concat function does not support scalar type {:?}", + other + )?, } } @@ -282,15 +293,37 @@ pub fn simplify_concat(args: Vec) -> Result { let mut new_args = Vec::with_capacity(args.len()); let mut contiguous_scalar = "".to_string(); + let return_type = { + let data_types: Vec<_> = args + .iter() + .filter_map(|expr| match expr { + Expr::Literal(l) => Some(l.data_type()), + _ => None, + }) + .collect(); + ConcatFunc::new().return_type(&data_types) + }?; + for arg in args.clone() { match arg { + Expr::Literal(ScalarValue::Utf8(None)) => {} + Expr::Literal(ScalarValue::LargeUtf8(None)) => { + } + Expr::Literal(ScalarValue::Utf8View(None)) => { } + // filter out `null` args - Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None)) => {} // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. // Concatenate it with the `contiguous_scalar`. - Expr::Literal( - ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)), - ) => contiguous_scalar += &v, + Expr::Literal(ScalarValue::Utf8(Some(v))) => { + contiguous_scalar += &v; + } + Expr::Literal(ScalarValue::LargeUtf8(Some(v))) => { + contiguous_scalar += &v; + } + Expr::Literal(ScalarValue::Utf8View(Some(v))) => { + contiguous_scalar += &v; + } + Expr::Literal(x) => { return internal_err!( "The scalar {x} should be casted to string type during the type coercion." @@ -301,7 +334,12 @@ pub fn simplify_concat(args: Vec) -> Result { // Then pushing this arg to the `new_args`. arg => { if !contiguous_scalar.is_empty() { - new_args.push(lit(contiguous_scalar)); + match return_type { + DataType::Utf8 => new_args.push(lit(contiguous_scalar)), + DataType::LargeUtf8 => new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))), + DataType::Utf8View => new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))), + _ => unreachable!(), + } contiguous_scalar = "".to_string(); } new_args.push(arg); @@ -310,7 +348,16 @@ pub fn simplify_concat(args: Vec) -> Result { } if !contiguous_scalar.is_empty() { - new_args.push(lit(contiguous_scalar)); + match return_type { + DataType::Utf8 => new_args.push(lit(contiguous_scalar)), + DataType::LargeUtf8 => { + new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))) + } + DataType::Utf8View => { + new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))) + } + _ => unreachable!(), + } } if !args.eq(&new_args) { @@ -392,6 +439,17 @@ mod tests { LargeUtf8, LargeStringArray ); + test_function!( + ConcatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(Some("aacc")), + &str, + Utf8View, + StringViewArray + ); Ok(()) } @@ -406,12 +464,19 @@ mod tests { None, Some("z"), ]))); - let args = &[c0, c1, c2]; + let c3 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string()))); + let c4 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![ + Some("a"), + None, + Some("b"), + ]))); + let args = &[c0, c1, c2, c3, c4]; #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = ConcatFunc::new().invoke(args)?; let expected = - Arc::new(StringArray::from(vec!["foo,x", "bar,", "baz,z"])) as ArrayRef; + Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", "baz,z,b"])) + as ArrayRef; match &result { ColumnarValue::Array(array) => { assert_eq!(&expected, array);