Skip to content

Commit

Permalink
feat: support the ergonomics of getting list slice with stride (#8946)
Browse files Browse the repository at this point in the history
* support list stride

* add test

* fix fmt

* rename and extend ListRange to ListStride

* fix ci

* fix doctest

* fix conflict and keep ListRange

* clean up thde code

* chore

* fix ci
  • Loading branch information
Weijun-H authored Jan 29, 2024
1 parent 1097dc0 commit fffc8be
Show file tree
Hide file tree
Showing 17 changed files with 282 additions and 80 deletions.
9 changes: 7 additions & 2 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,15 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
let key = create_physical_name(key, false)?;
format!("{expr}[{key}]")
}
GetFieldAccess::ListRange { start, stop } => {
GetFieldAccess::ListRange {
start,
stop,
stride,
} => {
let start = create_physical_name(start, false)?;
let stop = create_physical_name(stop, false)?;
format!("{expr}[{start}:{stop}]")
let stride = create_physical_name(stride, false)?;
format!("{expr}[{start}:{stop}:{stride}]")
}
};

Expand Down
28 changes: 21 additions & 7 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -421,8 +421,12 @@ pub enum GetFieldAccess {
NamedStructField { name: ScalarValue },
/// Single list index, for example: `list[i]`
ListIndex { key: Box<Expr> },
/// List range, for example `list[i:j]`
ListRange { start: Box<Expr>, stop: Box<Expr> },
/// List stride, for example `list[i:j:k]`
ListRange {
start: Box<Expr>,
stop: Box<Expr>,
stride: Box<Expr>,
},
}

/// Returns the field of a [`arrow::array::ListArray`] or
Expand Down Expand Up @@ -1209,14 +1213,15 @@ impl Expr {
/// # use datafusion_expr::{lit, col};
/// let expr = col("c1")
/// .range(lit(2), lit(4));
/// assert_eq!(expr.display_name().unwrap(), "c1[Int32(2):Int32(4)]");
/// assert_eq!(expr.display_name().unwrap(), "c1[Int32(2):Int32(4):Int64(1)]");
/// ```
pub fn range(self, start: Expr, stop: Expr) -> Self {
Expr::GetIndexedField(GetIndexedField {
expr: Box::new(self),
field: GetFieldAccess::ListRange {
start: Box::new(start),
stop: Box::new(stop),
stride: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
},
})
}
Expand Down Expand Up @@ -1530,8 +1535,12 @@ impl fmt::Display for Expr {
write!(f, "({expr})[{name}]")
}
GetFieldAccess::ListIndex { key } => write!(f, "({expr})[{key}]"),
GetFieldAccess::ListRange { start, stop } => {
write!(f, "({expr})[{start}:{stop}]")
GetFieldAccess::ListRange {
start,
stop,
stride,
} => {
write!(f, "({expr})[{start}:{stop}:{stride}]")
}
},
Expr::GroupingSet(grouping_sets) => match grouping_sets {
Expand Down Expand Up @@ -1732,10 +1741,15 @@ fn create_name(e: &Expr) -> Result<String> {
let key = create_name(key)?;
Ok(format!("{expr}[{key}]"))
}
GetFieldAccess::ListRange { start, stop } => {
GetFieldAccess::ListRange {
start,
stop,
stride,
} => {
let start = create_name(start)?;
let stop = create_name(stop)?;
Ok(format!("{expr}[{start}:{stop}]"))
let stride = create_name(stride)?;
Ok(format!("{expr}[{start}:{stop}:{stride}]"))
}
}
}
Expand Down
7 changes: 6 additions & 1 deletion datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,9 +374,14 @@ fn field_for_index<S: ExprSchema>(
GetFieldAccess::ListIndex { key } => GetFieldAccessSchema::ListIndex {
key_dt: key.get_type(schema)?,
},
GetFieldAccess::ListRange { start, stop } => GetFieldAccessSchema::ListRange {
GetFieldAccess::ListRange {
start,
stop,
stride,
} => GetFieldAccessSchema::ListRange {
start_dt: start.get_type(schema)?,
stop_dt: stop.get_type(schema)?,
stride_dt: stride.get_type(schema)?,
},
}
.get_accessed_field(&expr_dt)
Expand Down
13 changes: 7 additions & 6 deletions datafusion/expr/src/field_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ pub enum GetFieldAccessSchema {
NamedStructField { name: ScalarValue },
/// Single list index, for example: `list[i]`
ListIndex { key_dt: DataType },
/// List range, for example `list[i:j]`
/// List stride, for example `list[i:j:k]`
ListRange {
start_dt: DataType,
stop_dt: DataType,
stride_dt: DataType,
},
}

Expand Down Expand Up @@ -85,13 +86,13 @@ impl GetFieldAccessSchema {
(other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
}
}
Self::ListRange{ start_dt, stop_dt } => {
match (data_type, start_dt, stop_dt) {
(DataType::List(_), DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)),
(DataType::List(_), _, _) => plan_err!(
Self::ListRange { start_dt, stop_dt, stride_dt } => {
match (data_type, start_dt, stop_dt, stride_dt) {
(DataType::List(_), DataType::Int64, DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)),
(DataType::List(_), _, _, _) => plan_err!(
"Only ints are valid as an indexed field in a list"
),
(other, _, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
(other, _, _, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/tree_node/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ impl TreeNode for Expr {
let expr = expr.as_ref();
match field {
GetFieldAccess::ListIndex {key} => vec![key.as_ref(), expr],
GetFieldAccess::ListRange {start, stop} => {
vec![start.as_ref(), stop.as_ref(), expr]
GetFieldAccess::ListRange {start, stop, stride} => {
vec![start.as_ref(), stop.as_ref(),stride.as_ref(), expr]
}
GetFieldAccess::NamedStructField { .. } => vec![expr],
}
Expand Down
85 changes: 64 additions & 21 deletions datafusion/physical-expr/src/expressions/get_indexed_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::PhysicalExpr;
use datafusion_common::exec_err;

use crate::array_expressions::{array_element, array_slice};
use crate::expressions::Literal;
use crate::physical_expr::down_cast_any_ref;
use arrow::{
array::{Array, Scalar, StringArray},
Expand All @@ -43,10 +44,11 @@ pub enum GetFieldAccessExpr {
NamedStructField { name: ScalarValue },
/// Single list index, for example: `list[i]`
ListIndex { key: Arc<dyn PhysicalExpr> },
/// List range, for example `list[i:j]`
/// List stride, for example `list[i:j:k]`
ListRange {
start: Arc<dyn PhysicalExpr>,
stop: Arc<dyn PhysicalExpr>,
stride: Arc<dyn PhysicalExpr>,
},
}

Expand All @@ -55,8 +57,12 @@ impl std::fmt::Display for GetFieldAccessExpr {
match self {
GetFieldAccessExpr::NamedStructField { name } => write!(f, "[{}]", name),
GetFieldAccessExpr::ListIndex { key } => write!(f, "[{}]", key),
GetFieldAccessExpr::ListRange { start, stop } => {
write!(f, "[{}:{}]", start, stop)
GetFieldAccessExpr::ListRange {
start,
stop,
stride,
} => {
write!(f, "[{}:{}:{}]", start, stop, stride)
}
}
}
Expand All @@ -76,12 +82,18 @@ impl PartialEq<dyn Any> for GetFieldAccessExpr {
ListRange {
start: start_lhs,
stop: stop_lhs,
stride: stride_lhs,
},
ListRange {
start: start_rhs,
stop: stop_rhs,
stride: stride_rhs,
},
) => start_lhs.eq(start_rhs) && stop_lhs.eq(stop_rhs),
) => {
start_lhs.eq(start_rhs)
&& stop_lhs.eq(stop_rhs)
&& stride_lhs.eq(stride_rhs)
}
(NamedStructField { .. }, ListIndex { .. } | ListRange { .. }) => false,
(ListIndex { .. }, NamedStructField { .. } | ListRange { .. }) => false,
(ListRange { .. }, NamedStructField { .. } | ListIndex { .. }) => false,
Expand Down Expand Up @@ -126,7 +138,32 @@ impl GetIndexedFieldExpr {
start: Arc<dyn PhysicalExpr>,
stop: Arc<dyn PhysicalExpr>,
) -> Self {
Self::new(arg, GetFieldAccessExpr::ListRange { start, stop })
Self::new(
arg,
GetFieldAccessExpr::ListRange {
start,
stop,
stride: Arc::new(Literal::new(ScalarValue::Int64(Some(1))))
as Arc<dyn PhysicalExpr>,
},
)
}

/// Create a new [`GetIndexedFieldExpr`] for accessing the stride
pub fn new_stride(
arg: Arc<dyn PhysicalExpr>,
start: Arc<dyn PhysicalExpr>,
stop: Arc<dyn PhysicalExpr>,
stride: Arc<dyn PhysicalExpr>,
) -> Self {
Self::new(
arg,
GetFieldAccessExpr::ListRange {
start,
stop,
stride,
},
)
}

/// Get the description of what field should be accessed
Expand All @@ -147,12 +184,15 @@ impl GetIndexedFieldExpr {
GetFieldAccessExpr::ListIndex { key } => GetFieldAccessSchema::ListIndex {
key_dt: key.data_type(input_schema)?,
},
GetFieldAccessExpr::ListRange { start, stop } => {
GetFieldAccessSchema::ListRange {
start_dt: start.data_type(input_schema)?,
stop_dt: stop.data_type(input_schema)?,
}
}
GetFieldAccessExpr::ListRange {
start,
stop,
stride,
} => GetFieldAccessSchema::ListRange {
start_dt: start.data_type(input_schema)?,
stop_dt: stop.data_type(input_schema)?,
stride_dt: stride.data_type(input_schema)?,
},
})
}
}
Expand Down Expand Up @@ -223,21 +263,24 @@ impl PhysicalExpr for GetIndexedFieldExpr {
with utf8 indexes. Tried {dt:?} with {key:?} index"),
}
},
GetFieldAccessExpr::ListRange{start, stop} => {
GetFieldAccessExpr::ListRange { start, stop, stride } => {
let start = start.evaluate(batch)?.into_array(batch.num_rows())?;
let stop = stop.evaluate(batch)?.into_array(batch.num_rows())?;
match (array.data_type(), start.data_type(), stop.data_type()) {
(DataType::List(_), DataType::Int64, DataType::Int64) => Ok(ColumnarValue::Array(array_slice(&[
array, start, stop
])?)),
(DataType::List(_), start, stop) => exec_err!(
let stride = stride.evaluate(batch)?.into_array(batch.num_rows())?;
match (array.data_type(), start.data_type(), stop.data_type(), stride.data_type()) {
(DataType::List(_), DataType::Int64, DataType::Int64, DataType::Int64) => {
Ok(ColumnarValue::Array((array_slice(&[
array, start, stop, stride
]))?))
},
(DataType::List(_), start, stop, stride) => exec_err!(
"get indexed field is only possible on lists with int64 indexes. \
Tried with {start:?} and {stop:?} indices"),
(dt, start, stop) => exec_err!(
Tried with {start:?}, {stop:?} and {stride:?} indices"),
(dt, start, stop, stride) => exec_err!(
"get indexed field is only possible on lists with int64 indexes or struct \
with utf8 indexes. Tried {dt:?} with {start:?} and {stop:?} indices"),
with utf8 indexes. Tried {dt:?} with {start:?}, {stop:?} and {stride:?} indices"),
}
},
}
}
}

Expand Down
27 changes: 13 additions & 14 deletions datafusion/physical-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,20 +238,19 @@ pub fn create_physical_expr(
GetFieldAccess::ListIndex { key } => GetFieldAccessExpr::ListIndex {
key: create_physical_expr(key, input_dfschema, execution_props)?,
},
GetFieldAccess::ListRange { start, stop } => {
GetFieldAccessExpr::ListRange {
start: create_physical_expr(
start,
input_dfschema,
execution_props,
)?,
stop: create_physical_expr(
stop,
input_dfschema,
execution_props,
)?,
}
}
GetFieldAccess::ListRange {
start,
stop,
stride,
} => GetFieldAccessExpr::ListRange {
start: create_physical_expr(start, input_dfschema, execution_props)?,
stop: create_physical_expr(stop, input_dfschema, execution_props)?,
stride: create_physical_expr(
stride,
input_dfschema,
execution_props,
)?,
},
};
Ok(Arc::new(GetIndexedFieldExpr::new(
create_physical_expr(expr, input_dfschema, execution_props)?,
Expand Down
2 changes: 2 additions & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ message ListIndex {
message ListRange {
LogicalExprNode start = 1;
LogicalExprNode stop = 2;
LogicalExprNode stride = 3;
}

message GetIndexedField {
Expand Down Expand Up @@ -1773,6 +1774,7 @@ message ListIndexExpr {
message ListRangeExpr {
PhysicalExprNode start = 1;
PhysicalExprNode stop = 2;
PhysicalExprNode stride = 3;
}

message PhysicalGetIndexedFieldExprNode {
Expand Down
Loading

0 comments on commit fffc8be

Please sign in to comment.