Skip to content

Commit

Permalink
feat: support stride in array_slice, change indexes to be1 based (
Browse files Browse the repository at this point in the history
#8829)

* support array slice

* fix argument

* fix typo

* support from and to is negative

* fix conflict

* modify user doc

* refactor code

* fix clippy

* add 1-index test

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
Weijun-H and alamb authored Jan 21, 2024
1 parent b7e13a0 commit 0116e2a
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 72 deletions.
5 changes: 4 additions & 1 deletion datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,10 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayReplaceAll => {
Signature::any(3, self.volatility())
}
BuiltinScalarFunction::ArraySlice => Signature::any(3, self.volatility()),
BuiltinScalarFunction::ArraySlice => {
Signature::variadic_any(self.volatility())
}

BuiltinScalarFunction::ArrayToString => {
Signature::variadic_any(self.volatility())
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ scalar_expr!(
scalar_expr!(
ArraySlice,
array_slice,
array offset length,
array begin end stride,
"returns a slice of the array."
);
scalar_expr!(
Expand Down
93 changes: 76 additions & 17 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use std::sync::Arc;

use arrow::array::*;
use arrow::buffer::OffsetBuffer;
use arrow::compute;
use arrow::compute::{self};
use arrow::datatypes::{DataType, Field, UInt64Type};
use arrow::row::{RowConverter, SortField};
use arrow_buffer::{ArrowNativeType, NullBuffer};
Expand Down Expand Up @@ -575,23 +575,31 @@ pub fn array_except(args: &[ArrayRef]) -> Result<ArrayRef> {
///
/// See test cases in `array.slt` for more details.
pub fn array_slice(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 3 {
return exec_err!("array_slice needs three arguments");
let args_len = args.len();
if args_len != 3 && args_len != 4 {
return exec_err!("array_slice needs three or four arguments");
}

let stride = if args_len == 4 {
Some(as_int64_array(&args[3])?)
} else {
None
};

let from_array = as_int64_array(&args[1])?;
let to_array = as_int64_array(&args[2])?;

let array_data_type = args[0].data_type();
match array_data_type {
DataType::List(_) => {
let array = as_list_array(&args[0])?;
let from_array = as_int64_array(&args[1])?;
let to_array = as_int64_array(&args[2])?;
general_array_slice::<i32>(array, from_array, to_array)
general_array_slice::<i32>(array, from_array, to_array, stride)
}
DataType::LargeList(_) => {
let array = as_large_list_array(&args[0])?;
let from_array = as_int64_array(&args[1])?;
let to_array = as_int64_array(&args[2])?;
general_array_slice::<i64>(array, from_array, to_array)
general_array_slice::<i64>(array, from_array, to_array, stride)
}
_ => exec_err!("array_slice does not support type: {:?}", array_data_type),
}
Expand All @@ -601,6 +609,7 @@ fn general_array_slice<O: OffsetSizeTrait>(
array: &GenericListArray<O>,
from_array: &Int64Array,
to_array: &Int64Array,
stride: Option<&Int64Array>,
) -> Result<ArrayRef>
where
i64: TryInto<O>,
Expand Down Expand Up @@ -652,7 +661,7 @@ where
let adjusted_zero_index = if index < 0 {
// array_slice in duckdb with negative to_index is python-like, so index itself is exclusive
if let Ok(index) = index.try_into() {
index + len - O::usize_as(1)
index + len
} else {
return exec_err!("array_slice got invalid index: {}", index);
}
Expand Down Expand Up @@ -700,17 +709,67 @@ where
};

if let (Some(from), Some(to)) = (from_index, to_index) {
let stride = stride.map(|s| s.value(row_index));
// array_slice with stride in duckdb, return empty array if stride is not supported and from > to.
if stride.is_none() && from > to {
// return empty array
offsets.push(offsets[row_index]);
continue;
}
let stride = stride.unwrap_or(1);
if stride.is_zero() {
return exec_err!(
"array_slice got invalid stride: {:?}, it cannot be 0",
stride
);
} else if from <= to && stride.is_negative() {
// return empty array
offsets.push(offsets[row_index]);
continue;
}

let stride: O = stride.try_into().map_err(|_| {
internal_datafusion_err!("array_slice got invalid stride: {}", stride)
})?;

if from <= to {
assert!(start + to <= end);
mutable.extend(
0,
(start + from).to_usize().unwrap(),
(start + to + O::usize_as(1)).to_usize().unwrap(),
);
offsets.push(offsets[row_index] + (to - from + O::usize_as(1)));
if stride.eq(&O::one()) {
// stride is default to 1
mutable.extend(
0,
(start + from).to_usize().unwrap(),
(start + to + O::usize_as(1)).to_usize().unwrap(),
);
offsets.push(offsets[row_index] + (to - from + O::usize_as(1)));
continue;
}
let mut index = start + from;
let mut cnt = 0;
while index <= start + to {
mutable.extend(
0,
index.to_usize().unwrap(),
index.to_usize().unwrap() + 1,
);
index += stride;
cnt += 1;
}
offsets.push(offsets[row_index] + O::usize_as(cnt));
} else {
let mut index = start + from;
let mut cnt = 0;
while index >= start + to {
mutable.extend(
0,
index.to_usize().unwrap(),
index.to_usize().unwrap() + 1,
);
index += stride;
cnt += 1;
}
// invalid range, return empty array
offsets.push(offsets[row_index]);
offsets.push(offsets[row_index] + O::usize_as(cnt));
}
} else {
// invalid range, return empty array
Expand Down Expand Up @@ -741,7 +800,7 @@ where
.map(|arr| arr.map_or(0, |arr| arr.len() as i64))
.collect::<Vec<i64>>(),
);
general_array_slice::<O>(array, &from_array, &to_array)
general_array_slice::<O>(array, &from_array, &to_array, None)
}

fn general_pop_back_list<O: OffsetSizeTrait>(
Expand All @@ -757,7 +816,7 @@ where
.map(|arr| arr.map_or(0, |arr| arr.len() as i64 - 1))
.collect::<Vec<i64>>(),
);
general_array_slice::<O>(array, &from_array, &to_array)
general_array_slice::<O>(array, &from_array, &to_array, None)
}

/// array_pop_front SQL function
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1456,6 +1456,7 @@ pub fn parse_expr(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
parse_expr(&args[2], registry)?,
parse_expr(&args[3], registry)?,
)),
ScalarFunction::ArrayToString => Ok(array_to_string(
parse_expr(&args[0], registry)?,
Expand Down
55 changes: 37 additions & 18 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1234,6 +1234,25 @@ select array_slice(make_array(1, 2, 3, 4, 5), 2, 4), array_slice(make_array('h',
----
[2, 3, 4] [h, e]

query ????
select array_slice(make_array(1, 2, 3, 4, 5), 1, 5, 2), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 1, 5, 2),
array_slice(make_array(1, 2, 3, 4, 5), 0, 5, 2), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 5, 2);
----
[1, 3, 5] [h, l, o] [1, 3, 5] [h, l, o]

query ??
select array_slice(make_array(1, 2, 3, 4, 5), 1, 5, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 1, 5, -1);
----
[] []

query error Execution error: array_slice got invalid stride: 0, it cannot be 0
select array_slice(make_array(1, 2, 3, 4, 5), 1, 5, 0), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 1, 5, 0);

query ??
select array_slice(make_array(1, 2, 3, 4, 5), 5, 1, -2), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 5, 1, -2);
----
[5, 3, 1] [o, l, h]

query ??
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 1, 2);
----
Expand Down Expand Up @@ -1342,12 +1361,12 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NU
query ??
select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, -3);
----
[1] [h, e]
[1, 2] [h, e, l]

query ??
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, -3);
----
[1] [h, e]
[1, 2] [h, e, l]

# array_slice scalar function #13 (with negative number and NULL)
query error
Expand All @@ -1367,34 +1386,34 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NU
query ??
select array_slice(make_array(1, 2, 3, 4, 5), -4, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -1);
----
[2, 3, 4] [l, l]
[2, 3, 4, 5] [l, l, o]

query ??
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -1);
----
[2, 3, 4] [l, l]
[2, 3, 4, 5] [l, l, o]

# array_slice scalar function #16 (with negative indexes; almost full array (only with negative indices cannot return full array))
query ??
select array_slice(make_array(1, 2, 3, 4, 5), -5, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -5, -1);
----
[1, 2, 3, 4] [h, e, l, l]
[1, 2, 3, 4, 5] [h, e, l, l, o]

query ??
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -5, -1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -5, -1);
----
[1, 2, 3, 4] [h, e, l, l]
[1, 2, 3, 4, 5] [h, e, l, l, o]

# array_slice scalar function #17 (with negative indexes; first index = second index)
query ??
select array_slice(make_array(1, 2, 3, 4, 5), -4, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -3);
----
[] []
[2] [l]

query ??
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -3);
----
[] []
[2] [l]

# array_slice scalar function #18 (with negative indexes; first index > second_index)
query ??
Expand Down Expand Up @@ -1422,24 +1441,24 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -7
query ??
select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), -2, -1), array_slice(make_array(make_array(1, 2, 3), make_array(6, 7, 8)), -1, -1);
----
[[1, 2, 3, 4, 5]] []
[[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]] [[6, 7, 8]]

query ??
select array_slice(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), -2, -1), array_slice(arrow_cast(make_array(make_array(1, 2, 3), make_array(6, 7, 8)), 'LargeList(List(Int64))'), -1, -1);
----
[[1, 2, 3, 4, 5]] []
[[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]] [[6, 7, 8]]


# array_slice scalar function #21 (with first positive index and last negative index)
query ??
select array_slice(make_array(1, 2, 3, 4, 5), 2, -3), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 2, -2);
----
[2] [e, l]
[2, 3] [e, l, l]

query ??
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, -3), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 2, -2);
----
[2] [e, l]
[2, 3] [e, l, l]

# array_slice scalar function #22 (with first negative index and last positive index)
query ??
Expand Down Expand Up @@ -1468,7 +1487,7 @@ query ?
select array_slice(column1, column2, column3) from slices;
----
[]
[12, 13, 14, 15, 16]
[12, 13, 14, 15, 16, 17]
[]
[]
[]
Expand All @@ -1479,7 +1498,7 @@ query ?
select array_slice(arrow_cast(column1, 'LargeList(Int64)'), column2, column3) from slices;
----
[]
[12, 13, 14, 15, 16]
[12, 13, 14, 15, 16, 17]
[]
[]
[]
Expand All @@ -1492,9 +1511,9 @@ query ???
select array_slice(make_array(1, 2, 3, 4, 5), column2, column3), array_slice(column1, 3, column3), array_slice(column1, column2, 5) from slices;
----
[1] [] [, 2, 3, 4, 5]
[] [13, 14, 15, 16] [12, 13, 14, 15]
[2] [13, 14, 15, 16, 17] [12, 13, 14, 15]
[] [] [21, 22, 23, , 25]
[] [33] []
[] [33, 34] []
[4, 5] [] []
[1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45]
[5] [, 54, 55, 56, 57, 58, 59, 60] [55]
Expand All @@ -1503,9 +1522,9 @@ query ???
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2, column3), array_slice(arrow_cast(column1, 'LargeList(Int64)'), 3, column3), array_slice(arrow_cast(column1, 'LargeList(Int64)'), column2, 5) from slices;
----
[1] [] [, 2, 3, 4, 5]
[] [13, 14, 15, 16] [12, 13, 14, 15]
[2] [13, 14, 15, 16, 17] [12, 13, 14, 15]
[] [] [21, 22, 23, , 25]
[] [33] []
[] [33, 34] []
[4, 5] [] []
[1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45]
[5] [, 54, 55, 56, 57, 58, 59, 60] [55]
Expand Down
Loading

0 comments on commit 0116e2a

Please sign in to comment.