diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 6c03e8698e80..c891e85aa59b 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -2678,7 +2678,10 @@ impl ScalarValue { DataType::Duration(TimeUnit::Nanosecond) => { typed_cast!(array, index, DurationNanosecondArray, DurationNanosecond)? } - + DataType::Map(_, _) => { + let a = array.slice(index, 1); + Self::Map(Arc::new(a.as_map().to_owned())) + } other => { return _not_impl_err!( "Can't create a scalar from array of type \"{other:?}\"" diff --git a/datafusion/functions/src/core/map.rs b/datafusion/functions/src/core/map.rs index 8a8a19d7af52..6626831c8034 100644 --- a/datafusion/functions/src/core/map.rs +++ b/datafusion/functions/src/core/map.rs @@ -28,7 +28,21 @@ use datafusion_common::{exec_err, internal_err, ScalarValue}; use datafusion_common::{not_impl_err, Result}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +/// Check if we can evaluate the expr to constant directly. +/// +/// # Example +/// ```sql +/// SELECT make_map('type', 'test') from test +/// ``` +/// We can evaluate the result of `make_map` directly. +fn can_evaluate_to_const(args: &[ColumnarValue]) -> bool { + args.iter() + .all(|arg| matches!(arg, ColumnarValue::Scalar(_))) +} + fn make_map(args: &[ColumnarValue]) -> Result { + let can_evaluate_to_const = can_evaluate_to_const(args); + let (key, value): (Vec<_>, Vec<_>) = args .chunks_exact(2) .map(|chunk| { @@ -58,7 +72,7 @@ fn make_map(args: &[ColumnarValue]) -> Result { Ok(value) => value, Err(e) => return internal_err!("Error concatenating values: {}", e), }; - make_map_batch_internal(key, value) + make_map_batch_internal(key, value, can_evaluate_to_const) } fn make_map_batch(args: &[ColumnarValue]) -> Result { @@ -68,9 +82,12 @@ fn make_map_batch(args: &[ColumnarValue]) -> Result { args.len() ); } + + let can_evaluate_to_const = can_evaluate_to_const(args); + let key = get_first_array_ref(&args[0])?; let value = get_first_array_ref(&args[1])?; - make_map_batch_internal(key, value) + make_map_batch_internal(key, value, can_evaluate_to_const) } fn get_first_array_ref(columnar_value: &ColumnarValue) -> Result { @@ -85,7 +102,11 @@ fn get_first_array_ref(columnar_value: &ColumnarValue) -> Result { } } -fn make_map_batch_internal(keys: ArrayRef, values: ArrayRef) -> Result { +fn make_map_batch_internal( + keys: ArrayRef, + values: ArrayRef, + can_evaluate_to_const: bool, +) -> Result { if keys.null_count() > 0 { return exec_err!("map key cannot be null"); } @@ -124,8 +145,13 @@ fn make_map_batch_internal(keys: ArrayRef, values: ArrayRef) -> Result ConstEvaluator<'a> { } else { // Non-ListArray match ScalarValue::try_from_array(&a, 0) { - Ok(s) => ConstSimplifyResult::Simplified(s), + Ok(s) => { + // TODO: support the optimization for `Map` type after support impl hash for it + if matches!(&s, ScalarValue::Map(_)) { + ConstSimplifyResult::SimplifyRuntimeError( + DataFusionError::NotImplemented("Const evaluate for Map type is still not supported".to_string()), + expr, + ) + } else { + ConstSimplifyResult::Simplified(s) + } + } Err(err) => ConstSimplifyResult::SimplifyRuntimeError(err, expr), } } } - ColumnarValue::Scalar(s) => ConstSimplifyResult::Simplified(s), + ColumnarValue::Scalar(s) => { + // TODO: support the optimization for `Map` type after support impl hash for it + if matches!(&s, ScalarValue::Map(_)) { + ConstSimplifyResult::SimplifyRuntimeError( + DataFusionError::NotImplemented( + "Const evaluate for Map type is still not supported" + .to_string(), + ), + expr, + ) + } else { + ConstSimplifyResult::Simplified(s) + } + } } } } diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index abf5b2ebbf98..fb8917a5f4fe 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -212,3 +212,87 @@ SELECT map(column5, column6) FROM t; # {k1:1, k2:2} # {k3: 3} # {k5: 5} + +query ? +SELECT MAKE_MAP('POST', 41, 'HEAD', 33, 'PATCH', 30, 'OPTION', 29, 'GET', 27, 'PUT', 25, 'DELETE', 24) AS method_count from t; +---- +{POST: 41, HEAD: 33, PATCH: 30, OPTION: 29, GET: 27, PUT: 25, DELETE: 24} +{POST: 41, HEAD: 33, PATCH: 30, OPTION: 29, GET: 27, PUT: 25, DELETE: 24} +{POST: 41, HEAD: 33, PATCH: 30, OPTION: 29, GET: 27, PUT: 25, DELETE: 24} + +query I +SELECT MAKE_MAP('POST', 41, 'HEAD', 33)['POST'] from t; +---- +41 +41 +41 + +query ? +SELECT MAKE_MAP('POST', 41, 'HEAD', 33, 'PATCH', null) from t; +---- +{POST: 41, HEAD: 33, PATCH: } +{POST: 41, HEAD: 33, PATCH: } +{POST: 41, HEAD: 33, PATCH: } + +query ? +SELECT MAKE_MAP('POST', null, 'HEAD', 33, 'PATCH', null) from t; +---- +{POST: , HEAD: 33, PATCH: } +{POST: , HEAD: 33, PATCH: } +{POST: , HEAD: 33, PATCH: } + +query ? +SELECT MAKE_MAP(1, null, 2, 33, 3, null) from t; +---- +{1: , 2: 33, 3: } +{1: , 2: 33, 3: } +{1: , 2: 33, 3: } + +query ? +SELECT MAKE_MAP([1,2], ['a', 'b'], [3,4], ['b']) from t; +---- +{[1, 2]: [a, b], [3, 4]: [b]} +{[1, 2]: [a, b], [3, 4]: [b]} +{[1, 2]: [a, b], [3, 4]: [b]} + +query ? +SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, 30]) from t; +---- +{POST: 41, HEAD: 33, PATCH: 30} +{POST: 41, HEAD: 33, PATCH: 30} +{POST: 41, HEAD: 33, PATCH: 30} + +query ? +SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, null]) from t; +---- +{POST: 41, HEAD: 33, PATCH: } +{POST: 41, HEAD: 33, PATCH: } +{POST: 41, HEAD: 33, PATCH: } + +query ? +SELECT MAP([[1,2], [3,4]], ['a', 'b']) from t; +---- +{[1, 2]: a, [3, 4]: b} +{[1, 2]: a, [3, 4]: b} +{[1, 2]: a, [3, 4]: b} + +query ? +SELECT MAP(make_array('POST', 'HEAD', 'PATCH'), make_array(41, 33, 30)) from t; +---- +{POST: 41, HEAD: 33, PATCH: 30} +{POST: 41, HEAD: 33, PATCH: 30} +{POST: 41, HEAD: 33, PATCH: 30} + +query ? +SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'), 'FixedSizeList(3, Utf8)'), arrow_cast(make_array(41, 33, 30), 'FixedSizeList(3, Int64)')) from t; +---- +{POST: 41, HEAD: 33, PATCH: 30} +{POST: 41, HEAD: 33, PATCH: 30} +{POST: 41, HEAD: 33, PATCH: 30} + +query ? +SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'), 'LargeList(Utf8)'), arrow_cast(make_array(41, 33, 30), 'LargeList(Int64)')) from t; +---- +{POST: 41, HEAD: 33, PATCH: 30} +{POST: 41, HEAD: 33, PATCH: 30} +{POST: 41, HEAD: 33, PATCH: 30}