Skip to content

Commit

Permalink
feat:implement sql style 'find_in_set' string function (#8328)
Browse files Browse the repository at this point in the history
* feat:implement sql style 'find_in_set' string function

* format code

* modify test case
  • Loading branch information
Syleechan authored Nov 30, 2023
1 parent 06bbe12 commit c079a92
Show file tree
Hide file tree
Showing 11 changed files with 146 additions and 2 deletions.
11 changes: 11 additions & 0 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@ pub enum BuiltinScalarFunction {
Levenshtein,
/// substr_index
SubstrIndex,
/// find_in_set
FindInSet,
}

/// Maps the sql function name to `BuiltinScalarFunction`
Expand Down Expand Up @@ -472,6 +474,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::OverLay => Volatility::Immutable,
BuiltinScalarFunction::Levenshtein => Volatility::Immutable,
BuiltinScalarFunction::SubstrIndex => Volatility::Immutable,
BuiltinScalarFunction::FindInSet => Volatility::Immutable,

// Stable builtin functions
BuiltinScalarFunction::Now => Volatility::Stable,
Expand Down Expand Up @@ -778,6 +781,9 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::SubstrIndex => {
utf8_to_str_type(&input_expr_types[0], "substr_index")
}
BuiltinScalarFunction::FindInSet => {
utf8_to_int_type(&input_expr_types[0], "find_in_set")
}
BuiltinScalarFunction::ToTimestamp
| BuiltinScalarFunction::ToTimestampNanos => Ok(Timestamp(Nanosecond, None)),
BuiltinScalarFunction::ToTimestampMillis => Ok(Timestamp(Millisecond, None)),
Expand Down Expand Up @@ -1244,6 +1250,10 @@ impl BuiltinScalarFunction {
],
self.volatility(),
),
BuiltinScalarFunction::FindInSet => Signature::one_of(
vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])],
self.volatility(),
),

BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => {
Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], self.volatility())
Expand Down Expand Up @@ -1499,6 +1509,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Uuid => &["uuid"],
BuiltinScalarFunction::Levenshtein => &["levenshtein"],
BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"],
BuiltinScalarFunction::FindInSet => &["find_in_set"],

// regex functions
BuiltinScalarFunction::RegexpMatch => &["regexp_match"],
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,7 @@ scalar_expr!(
scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type");
scalar_expr!(Levenshtein, levenshtein, string1 string2, "Returns the Levenshtein distance between the two given strings");
scalar_expr!(SubstrIndex, substr_index, string delimiter count, "Returns the substring from str before count occurrences of the delimiter");
scalar_expr!(FindInSet, find_in_set, str strlist, "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings");

scalar_expr!(
Struct,
Expand Down Expand Up @@ -1207,6 +1208,7 @@ mod test {
test_nary_scalar_expr!(OverLay, overlay, string, characters, position);
test_scalar_expr!(Levenshtein, levenshtein, string1, string2);
test_scalar_expr!(SubstrIndex, substr_index, string, delimiter, count);
test_scalar_expr!(FindInSet, find_in_set, string, stringlist);
}

#[test]
Expand Down
21 changes: 21 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,27 @@ pub fn create_physical_fun(
))),
})
}
BuiltinScalarFunction::FindInSet => Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
let func = invoke_if_unicode_expressions_feature_flag!(
find_in_set,
Int32Type,
"find_in_set"
);
make_scalar_function(func)(args)
}
DataType::LargeUtf8 => {
let func = invoke_if_unicode_expressions_feature_flag!(
find_in_set,
Int64Type,
"find_in_set"
);
make_scalar_function(func)(args)
}
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {other:?} for function find_in_set",
))),
}),
})
}

Expand Down
39 changes: 39 additions & 0 deletions datafusion/physical-expr/src/unicode_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -520,3 +520,42 @@ pub fn substr_index<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {

Ok(Arc::new(result) as ArrayRef)
}

///Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings
///A string list is a string composed of substrings separated by , characters.
pub fn find_in_set<T: ArrowPrimitiveType>(args: &[ArrayRef]) -> Result<ArrayRef>
where
T::Native: OffsetSizeTrait,
{
if args.len() != 2 {
return internal_err!(
"find_in_set was called with {} arguments. It requires 2.",
args.len()
);
}

let str_array: &GenericStringArray<T::Native> =
as_generic_string_array::<T::Native>(&args[0])?;
let str_list_array: &GenericStringArray<T::Native> =
as_generic_string_array::<T::Native>(&args[1])?;

let result = str_array
.iter()
.zip(str_list_array.iter())
.map(|(string, str_list)| match (string, str_list) {
(Some(string), Some(str_list)) => {
let mut res = 0;
let str_set: Vec<&str> = str_list.split(',').collect();
for (idx, str) in str_set.iter().enumerate() {
if str == &string {
res = idx + 1;
break;
}
}
T::Native::from_usize(res)
}
_ => None,
})
.collect::<PrimitiveArray<T>>();
Ok(Arc::new(result) as ArrayRef)
}
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,7 @@ enum ScalarFunction {
ArrayPopFront = 124;
Levenshtein = 125;
SubstrIndex = 126;
FindInSet = 127;
}

message ScalarFunctionNode {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 7 additions & 2 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ use datafusion_expr::{
chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date,
current_time, date_bin, date_part, date_trunc, decode, degrees, digest, encode, exp,
expr::{self, InList, Sort, WindowFunction},
factorial, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, lcm, left,
levenshtein, ln, log, log10, log2,
factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero,
lcm, left, levenshtein, ln, log, log10, log2,
logical_plan::{PlanType, StringifiedPlan},
lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power,
radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right,
Expand Down Expand Up @@ -552,6 +552,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::OverLay => Self::OverLay,
ScalarFunction::Levenshtein => Self::Levenshtein,
ScalarFunction::SubstrIndex => Self::SubstrIndex,
ScalarFunction::FindInSet => Self::FindInSet,
}
}
}
Expand Down Expand Up @@ -1722,6 +1723,10 @@ pub fn parse_expr(
parse_expr(&args[1], registry)?,
parse_expr(&args[2], registry)?,
)),
ScalarFunction::FindInSet => Ok(find_in_set(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
)),
ScalarFunction::StructFun => {
Ok(struct_fun(parse_expr(&args[0], registry)?))
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1584,6 +1584,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::OverLay => Self::OverLay,
BuiltinScalarFunction::Levenshtein => Self::Levenshtein,
BuiltinScalarFunction::SubstrIndex => Self::SubstrIndex,
BuiltinScalarFunction::FindInSet => Self::FindInSet,
};

Ok(scalar_function)
Expand Down
43 changes: 43 additions & 0 deletions datafusion/sqllogictest/test_files/functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -952,3 +952,46 @@ query ?
SELECT substr_index(NULL, NULL, NULL)
----
NULL

query I
SELECT find_in_set('b', 'a,b,c,d')
----
2


query I
SELECT find_in_set('a', 'a,b,c,d,a')
----
1

query I
SELECT find_in_set('', 'a,b,c,d,a')
----
0

query I
SELECT find_in_set('a', '')
----
0


query I
SELECT find_in_set('', '')
----
1

query ?
SELECT find_in_set(NULL, 'a,b,c,d')
----
NULL

query I
SELECT find_in_set('a', NULL)
----
NULL


query ?
SELECT find_in_set(NULL, NULL)
----
NULL
15 changes: 15 additions & 0 deletions docs/source/user-guide/sql/scalar_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,7 @@ nullif(expression1, expression2)
- [overlay](#overlay)
- [levenshtein](#levenshtein)
- [substr_index](#substr_index)
- [find_in_set](#find_in_set)

### `ascii`

Expand Down Expand Up @@ -1170,6 +1171,20 @@ substr_index(str, delim, count)
- **delim**: the string to find in str to split str.
- **count**: The number of times to search for the delimiter. Can be both a positive or negative number.

### `find_in_set`

Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings.
For example, `find_in_set('b', 'a,b,c,d') = 2`

```
find_in_set(str, strlist)
```

#### Arguments

- **str**: String expression to find in strlist.
- **strlist**: A string list is a string composed of substrings separated by , characters.

## Binary String Functions

- [decode](#decode)
Expand Down

0 comments on commit c079a92

Please sign in to comment.