diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index cbf5d400bab5..d92067501657 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -302,6 +302,8 @@ pub enum BuiltinScalarFunction { OverLay, /// levenshtein Levenshtein, + /// substr_index + SubstrIndex, } /// Maps the sql function name to `BuiltinScalarFunction` @@ -470,6 +472,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable, BuiltinScalarFunction::OverLay => Volatility::Immutable, BuiltinScalarFunction::Levenshtein => Volatility::Immutable, + BuiltinScalarFunction::SubstrIndex => Volatility::Immutable, // Stable builtin functions BuiltinScalarFunction::Now => Volatility::Stable, @@ -773,6 +776,9 @@ impl BuiltinScalarFunction { return plan_err!("The to_hex function can only accept integers."); } }), + BuiltinScalarFunction::SubstrIndex => { + utf8_to_str_type(&input_expr_types[0], "substr_index") + } BuiltinScalarFunction::ToTimestamp => Ok(match &input_expr_types[0] { Int64 => Timestamp(Second, None), _ => Timestamp(Nanosecond, None), @@ -1235,6 +1241,14 @@ impl BuiltinScalarFunction { self.volatility(), ), + BuiltinScalarFunction::SubstrIndex => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + self.volatility(), + ), + BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => { Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], self.volatility()) } @@ -1486,6 +1500,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::Upper => &["upper"], BuiltinScalarFunction::Uuid => &["uuid"], BuiltinScalarFunction::Levenshtein => &["levenshtein"], + BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"], // regex functions BuiltinScalarFunction::RegexpMatch => &["regexp_match"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 4da68575946a..d2c5e5cddbf3 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -916,6 +916,7 @@ scalar_expr!( scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type"); scalar_expr!(Levenshtein, levenshtein, string1 string2, "Returns the Levenshtein distance between the two given strings"); +scalar_expr!(SubstrIndex, substr_index, string delimiter count, "Returns the substring from str before count occurrences of the delimiter"); scalar_expr!( Struct, @@ -1205,6 +1206,7 @@ mod test { test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len); test_nary_scalar_expr!(OverLay, overlay, string, characters, position); test_scalar_expr!(Levenshtein, levenshtein, string1, string2); + test_scalar_expr!(SubstrIndex, substr_index, string, delimiter, count); } #[test] diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 5a1a68dd2127..40b21347edf5 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -862,6 +862,29 @@ pub fn create_physical_fun( ))), }) } + BuiltinScalarFunction::SubstrIndex => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + substr_index, + i32, + "substr_index" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + substr_index, + i64, + "substr_index" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function substr_index", + ))), + }) + } }) } diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index e28700a25ce4..f27b3c157741 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -455,3 +455,68 @@ pub fn translate(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } + +/// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned. +/// SUBSTRING_INDEX('www.apache.org', '.', 1) = www +/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache +/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org +/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org +pub fn substr_index(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return internal_err!( + "substr_index was called with {} arguments. It requires 3.", + args.len() + ); + } + + let string_array = as_generic_string_array::(&args[0])?; + let delimiter_array = as_generic_string_array::(&args[1])?; + let count_array = as_int64_array(&args[2])?; + + let result = string_array + .iter() + .zip(delimiter_array.iter()) + .zip(count_array.iter()) + .map(|((string, delimiter), n)| match (string, delimiter, n) { + (Some(string), Some(delimiter), Some(n)) => { + let mut res = String::new(); + match n { + 0 => { + "".to_string(); + } + _other => { + if n > 0 { + let idx = string + .split(delimiter) + .take(n as usize) + .fold(0, |len, x| len + x.len() + delimiter.len()) + - delimiter.len(); + res.push_str(if idx >= string.len() { + string + } else { + &string[..idx] + }); + } else { + let idx = (string.split(delimiter).take((-n) as usize).fold( + string.len() as isize, + |len, x| { + len - x.len() as isize - delimiter.len() as isize + }, + ) + delimiter.len() as isize) + as usize; + res.push_str(if idx >= string.len() { + string + } else { + &string[idx..] + }); + } + } + } + Some(res) + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d43d19f85842..5c33b10f1395 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -641,6 +641,7 @@ enum ScalarFunction { ArrayExcept = 123; ArrayPopFront = 124; Levenshtein = 125; + SubstrIndex = 126; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 133bbbee8920..598719dc8ac6 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20863,6 +20863,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayExcept => "ArrayExcept", Self::ArrayPopFront => "ArrayPopFront", Self::Levenshtein => "Levenshtein", + Self::SubstrIndex => "SubstrIndex", }; serializer.serialize_str(variant) } @@ -21000,6 +21001,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayExcept", "ArrayPopFront", "Levenshtein", + "SubstrIndex", ]; struct GeneratedVisitor; @@ -21166,6 +21168,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayExcept" => Ok(ScalarFunction::ArrayExcept), "ArrayPopFront" => Ok(ScalarFunction::ArrayPopFront), "Levenshtein" => Ok(ScalarFunction::Levenshtein), + "SubstrIndex" => Ok(ScalarFunction::SubstrIndex), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 503c4b6c73f1..e79a17fc5c9c 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2594,6 +2594,7 @@ pub enum ScalarFunction { ArrayExcept = 123, ArrayPopFront = 124, Levenshtein = 125, + SubstrIndex = 126, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2728,6 +2729,7 @@ impl ScalarFunction { ScalarFunction::ArrayExcept => "ArrayExcept", ScalarFunction::ArrayPopFront => "ArrayPopFront", ScalarFunction::Levenshtein => "Levenshtein", + ScalarFunction::SubstrIndex => "SubstrIndex", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2859,6 +2861,7 @@ impl ScalarFunction { "ArrayExcept" => Some(Self::ArrayExcept), "ArrayPopFront" => Some(Self::ArrayPopFront), "Levenshtein" => Some(Self::Levenshtein), + "SubstrIndex" => Some(Self::SubstrIndex), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index d4a64287b07e..b2455d5a0d13 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -55,9 +55,9 @@ use datafusion_expr::{ lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power, radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, - sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substring, tan, tanh, - to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos, - to_timestamp_seconds, translate, trim, trunc, upper, uuid, + sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substr_index, + substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis, + to_timestamp_nanos, to_timestamp_seconds, translate, trim, trunc, upper, uuid, window_frame::regularize, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, @@ -551,6 +551,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrowTypeof => Self::ArrowTypeof, ScalarFunction::OverLay => Self::OverLay, ScalarFunction::Levenshtein => Self::Levenshtein, + ScalarFunction::SubstrIndex => Self::SubstrIndex, } } } @@ -1716,6 +1717,11 @@ pub fn parse_expr( .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), + ScalarFunction::SubstrIndex => Ok(substr_index( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + )), ScalarFunction::StructFun => { Ok(struct_fun(parse_expr(&args[0], registry)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 508cde98ae2a..9be4a532bb5b 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1583,6 +1583,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof, BuiltinScalarFunction::OverLay => Self::OverLay, BuiltinScalarFunction::Levenshtein => Self::Levenshtein, + BuiltinScalarFunction::SubstrIndex => Self::SubstrIndex, }; Ok(scalar_function) diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 9c8bb2c5f844..91072a49cd46 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -877,3 +877,78 @@ query ? SELECT levenshtein(NULL, NULL) ---- NULL + +query T +SELECT substr_index('www.apache.org', '.', 1) +---- +www + +query T +SELECT substr_index('www.apache.org', '.', 2) +---- +www.apache + +query T +SELECT substr_index('www.apache.org', '.', -1) +---- +org + +query T +SELECT substr_index('www.apache.org', '.', -2) +---- +apache.org + +query T +SELECT substr_index('www.apache.org', 'ac', 1) +---- +www.ap + +query T +SELECT substr_index('www.apache.org', 'ac', -1) +---- +he.org + +query T +SELECT substr_index('www.apache.org', 'ac', 2) +---- +www.apache.org + +query T +SELECT substr_index('www.apache.org', 'ac', -2) +---- +www.apache.org + +query ? +SELECT substr_index(NULL, 'ac', 1) +---- +NULL + +query T +SELECT substr_index('www.apache.org', NULL, 1) +---- +NULL + +query T +SELECT substr_index('www.apache.org', 'ac', NULL) +---- +NULL + +query T +SELECT substr_index('', 'ac', 1) +---- +(empty) + +query T +SELECT substr_index('www.apache.org', '', 1) +---- +(empty) + +query T +SELECT substr_index('www.apache.org', 'ac', 0) +---- +(empty) + +query ? +SELECT substr_index(NULL, NULL, NULL) +---- +NULL diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index eda46ef8a73b..e7ebbc9f1fe7 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -637,6 +637,7 @@ nullif(expression1, expression2) - [uuid](#uuid) - [overlay](#overlay) - [levenshtein](#levenshtein) +- [substr_index](#substr_index) ### `ascii` @@ -1152,6 +1153,23 @@ levenshtein(str1, str2) - **str1**: String expression to compute Levenshtein distance with str2. - **str2**: String expression to compute Levenshtein distance with str1. +### `substr_index` + +Returns the substring from str before count occurrences of the delimiter delim. +If count is positive, everything to the left of the final delimiter (counting from the left) is returned. +If count is negative, everything to the right of the final delimiter (counting from the right) is returned. +For example, `substr_index('www.apache.org', '.', 1) = www`, `substr_index('www.apache.org', '.', -1) = org` + +``` +substr_index(str, delim, count) +``` + +#### Arguments + +- **str**: String expression to operate on. +- **delim**: the string to find in str to split str. +- **count**: The number of times to search for the delimiter. Can be both a positive or negative number. + ## Binary String Functions - [decode](#decode)