Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: new create_one ExpressionHandler API #662

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
284 changes: 283 additions & 1 deletion kernel/src/engine/arrow_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,87 @@ impl ExpressionHandler for ArrowExpressionHandler {
output_type,
})
}

fn create_one(&self, schema: SchemaRef, expr: &Expression) -> DeltaResult<Box<dyn EngineData>> {
if let Expression::Struct(child_exprs) = expr {
if schema.len() != child_exprs.len() {
return Err(Error::Generic(format!(
"Schema has {} top-level fields, but struct expr has {} children",
schema.len(),
child_exprs.len()
)));
}

let arrays: Vec<ArrayRef> = schema
.fields()
.zip(child_exprs.iter())
.map(|(field, child_expr)| create_single_row_array(child_expr, field))
.collect::<Result<Vec<_>, Error>>()?;

let record_batch = RecordBatch::try_new(Arc::new(schema.as_ref().try_into()?), arrays)?;
Ok(Box::new(ArrowEngineData::new(record_batch)))
} else {
Err(Error::generic(
"ArrowExpressionHandler::create_one() requires a top-level struct expression",
))
}
}
}

/// Materializes `expr` as a single-row Arrow array matching `field`.
///
/// Two expression shapes are supported:
///
/// * **Literals** become a one-element array whose Arrow type must match the data type of
///   `field` (no casting is attempted).
/// * **Struct expressions** are co-traversed with the struct type of `field`: children are
///   paired positionally with the struct's subfields, converted recursively, and assembled
///   into a one-row struct array.
///
/// # Errors
///
/// Returns an error if:
/// * a literal's Arrow type does not match the field's data type,
/// * a struct expression is paired with a non-struct field, or the child counts differ,
/// * the struct array cannot be assembled (for example, a child array contains a null for a
///   non-nullable subfield),
/// * `expr` is neither a literal nor a struct expression.
fn create_single_row_array(expr: &Expression, field: &StructField) -> DeltaResult<ArrayRef> {
    match expr {
        // simple case: for literals, just create a single-row array and ensure the data types match
        Expression::Literal(scalar) => {
            let array = scalar.to_array(1)?;
            // TODO(zach): we could do better and cast here instead of always failing
            ensure_data_types(field.data_type(), array.data_type(), true)?;
            Ok(array)
        }
        // recursive case: for struct expressions, build a struct array by recursing into each child
        Expression::Struct(child_exprs) => {
            // co-traverse the expression and schema: the field must itself be a struct.
            let struct_type = match field.data_type() {
                DataType::Struct(struct_type) => struct_type,
                other_type => {
                    return Err(Error::Generic(format!(
                        "Expected struct type in schema, but got {:?}",
                        other_type
                    )))
                }
            };

            if struct_type.len() != child_exprs.len() {
                return Err(Error::Generic(format!(
                    "Schema struct field has {} children, but expression has {} children",
                    struct_type.len(),
                    child_exprs.len()
                )));
            }

            let child_arrays = struct_type
                .fields()
                .zip(child_exprs.iter())
                .map(|(subfield, subexpr)| create_single_row_array(subexpr, subfield))
                .collect::<Result<Vec<ArrayRef>, Error>>()?;

            let arrow_fields = struct_type
                .fields()
                .map(ArrowField::try_from)
                .collect::<Result<Vec<ArrowField>, ArrowError>>()?;

            // Use the fallible constructor: `StructArray::new` panics when the children are
            // inconsistent with the fields (e.g. a null value in a non-nullable child), and
            // that is reachable from caller-supplied expressions.
            // FIXME: no null bitmap is supplied, so the struct row itself is always valid.
            let struct_array = StructArray::try_new(arrow_fields.into(), child_arrays, None)?;

            Ok(Arc::new(struct_array))
        }
        // fail for any non-literal, non-struct expressions
        non_literal_non_struct => Err(Error::Unsupported(format!(
            "create_single_row_array: unhandled expr variant: {:?}",
            non_literal_non_struct
        ))),
    }
}

#[derive(Debug)]
Expand Down Expand Up @@ -568,7 +649,7 @@ impl ExpressionEvaluator for DefaultExpressionEvaluator {
mod tests {
use std::ops::{Add, Div, Mul, Sub};

use arrow_array::{GenericStringArray, Int32Array};
use arrow_array::{create_array, record_batch, GenericStringArray, Int32Array};
use arrow_buffer::ScalarBuffer;
use arrow_schema::{DataType, Field, Fields, Schema};

Expand Down Expand Up @@ -867,4 +948,205 @@ mod tests {
let expected = Arc::new(BooleanArray::from(vec![true, false]));
assert_eq!(results.as_ref(), expected.as_ref());
}

/// A flat struct of three integer literals becomes a one-row batch with three Int32 columns.
#[test]
fn test_create_one() {
    // Three nullable INTEGER columns: a, b, c.
    let schema = Arc::new(crate::schema::StructType::new([
        StructField::new("a", DeltaDataTypes::INTEGER, true),
        StructField::new("b", DeltaDataTypes::INTEGER, true),
        StructField::new("c", DeltaDataTypes::INTEGER, true),
    ]));
    let row_expr = Expression::struct_from([
        Expression::literal(1),
        Expression::literal(2),
        Expression::literal(3),
    ]);

    let engine_data = ArrowExpressionHandler
        .create_one(schema, &row_expr)
        .unwrap();

    // Unwrap the engine data back into the underlying Arrow record batch.
    let arrow_data = engine_data
        .into_any()
        .downcast::<ArrowEngineData>()
        .unwrap();
    let actual_rb: RecordBatch = arrow_data.into();

    let expected =
        record_batch!(("a", Int32, [1]), ("b", Int32, [2]), ("c", Int32, [3])).unwrap();
    assert_eq!(actual_rb, expected);
}

/// A string literal is materialized as a single-row Utf8 column.
#[test]
fn test_create_one_string() {
    let string_field = StructField::new("col_1", DeltaDataTypes::STRING, true);
    let schema = Arc::new(crate::schema::StructType::new([string_field]));
    let row_expr = Expression::struct_from([Expression::literal("a")]);

    let engine_data = ArrowExpressionHandler
        .create_one(schema, &row_expr)
        .unwrap();

    // Unwrap the engine data back into the underlying Arrow record batch.
    let arrow_data = engine_data
        .into_any()
        .downcast::<ArrowEngineData>()
        .unwrap();
    let actual_rb: RecordBatch = arrow_data.into();

    let expected = record_batch!(("col_1", Utf8, ["a"])).unwrap();
    assert_eq!(actual_rb, expected);
}

/// A typed null literal is materialized as a single null row in a nullable Int32 column.
#[test]
fn test_create_one_null() {
    let nullable_int = StructField::new("col_1", DeltaDataTypes::INTEGER, true);
    let schema = Arc::new(crate::schema::StructType::new([nullable_int]));
    let row_expr =
        Expression::struct_from([Expression::null_literal(DeltaDataTypes::INTEGER)]);

    let engine_data = ArrowExpressionHandler
        .create_one(schema, &row_expr)
        .unwrap();

    // Unwrap the engine data back into the underlying Arrow record batch.
    let arrow_data = engine_data
        .into_any()
        .downcast::<ArrowEngineData>()
        .unwrap();
    let actual_rb: RecordBatch = arrow_data.into();

    let expected = record_batch!(("col_1", Int32, [None])).unwrap();
    assert_eq!(actual_rb, expected);
}

/// A non-nullable kernel field produces a non-nullable Arrow field in the output batch.
#[test]
fn test_create_one_non_null() {
    let required_int = StructField::new("a", DeltaDataTypes::INTEGER, false);
    let schema = Arc::new(crate::schema::StructType::new([required_int]));
    let row_expr = Expression::struct_from([Expression::literal(1)]);

    let engine_data = ArrowExpressionHandler
        .create_one(schema, &row_expr)
        .unwrap();

    // Unwrap the engine data back into the underlying Arrow record batch.
    let arrow_data = engine_data
        .into_any()
        .downcast::<ArrowEngineData>()
        .unwrap();
    let actual_rb: RecordBatch = arrow_data.into();

    // The expected batch is built by hand so that the field is marked non-nullable.
    let expected_field = Field::new("a", DataType::Int32, false);
    let expected_schema = Arc::new(arrow_schema::Schema::new(vec![expected_field]));
    let expected =
        RecordBatch::try_new(expected_schema, vec![create_array!(Int32, [1])]).unwrap();
    assert_eq!(actual_rb, expected);
}

/// Column references cannot be evaluated without input data, so `create_one` must fail.
#[test]
fn test_create_one_disallow_column_ref() {
    let int_field = StructField::new("a", DeltaDataTypes::INTEGER, true);
    let schema = Arc::new(crate::schema::StructType::new([int_field]));
    let row_expr = Expression::struct_from([column_expr!("a")]);

    let result = ArrowExpressionHandler.create_one(schema, &row_expr);
    assert!(result.is_err());
}

/// Operator expressions are rejected, even when all of their operands are literals.
#[test]
fn test_create_one_disallow_operator() {
    let int_field = StructField::new("a", DeltaDataTypes::INTEGER, true);
    let schema = Arc::new(crate::schema::StructType::new([int_field]));

    let one_plus_two = Expression::binary(
        BinaryOperator::Plus,
        Expression::literal(1),
        Expression::literal(2),
    );
    let row_expr = Expression::struct_from([one_plus_two]);

    let result = ArrowExpressionHandler.create_one(schema, &row_expr);
    assert!(result.is_err());
}

/// A nested struct column can be produced either from a struct-of-literals expression or from
/// a single struct literal; both must yield the same record batch.
#[test]
fn test_create_one_nested() {
    // Kernel schema used by both halves of the test: a non-nullable struct column `a` with
    // children `b` (nullable) and `c` (non-nullable).
    let nested_schema = || {
        Arc::new(crate::schema::StructType::new([StructField::new(
            "a",
            DeltaDataTypes::struct_type([
                StructField::new("b", DeltaDataTypes::INTEGER, true),
                StructField::new("c", DeltaDataTypes::INTEGER, false),
            ]),
            false,
        )]))
    };

    // Expected Arrow output, shared by both halves.
    let child_b = Field::new("b", DataType::Int32, true);
    let child_c = Field::new("c", DataType::Int32, false);
    let expected_schema = Arc::new(arrow_schema::Schema::new(vec![Field::new(
        "a",
        DataType::Struct(vec![child_b.clone(), child_c.clone()].into()),
        false,
    )]));
    let expected_column = StructArray::from(vec![
        (Arc::new(child_b), create_array!(Int32, [1]) as ArrayRef),
        (Arc::new(child_c), create_array!(Int32, [2]) as ArrayRef),
    ]);
    let expected = RecordBatch::try_new(
        expected_schema,
        vec![Arc::new(expected_column) as ArrayRef],
    )
    .unwrap();

    // First half: struct expression whose children are literals.
    let struct_of_literals = Expression::struct_from([Expression::struct_from([
        Expression::literal(1),
        Expression::literal(2),
    ])]);
    let engine_data = ArrowExpressionHandler
        .create_one(nested_schema(), &struct_of_literals)
        .unwrap();
    let arrow_data = engine_data
        .into_any()
        .downcast::<ArrowEngineData>()
        .unwrap();
    let actual_rb: RecordBatch = arrow_data.into();
    assert_eq!(actual_rb, expected);

    // make the same but with literal struct instead of struct of literal
    let struct_data = StructData::try_new(
        vec![
            StructField::new("b", DeltaDataTypes::INTEGER, true),
            StructField::new("c", DeltaDataTypes::INTEGER, false),
        ],
        vec![Scalar::Integer(1), Scalar::Integer(2)],
    )
    .unwrap();
    let literal_struct =
        Expression::struct_from([Expression::literal(Scalar::Struct(struct_data))]);
    let engine_data = ArrowExpressionHandler
        .create_one(nested_schema(), &literal_struct)
        .unwrap();
    let arrow_data = engine_data
        .into_any()
        .downcast::<ArrowEngineData>()
        .unwrap();
    let actual_rb: RecordBatch = arrow_data.into();
    assert_eq!(actual_rb, expected);
}
}
9 changes: 9 additions & 0 deletions kernel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,15 @@ pub trait ExpressionHandler: AsAny {
expression: Expression,
output_type: DataType,
) -> Arc<dyn ExpressionEvaluator>;

/// Create a single-row [`EngineData`] by evaluating an [`Expression`] with no column
/// references.
///
/// The schema of the output is the `schema` parameter, which must match the output of the
/// expression.
///
/// # Parameters
///
/// * `schema` - the schema of the single-row result.
/// * `expr` - the expression to evaluate; it is evaluated without any input data, so it must
///   not contain column references.
///
/// # Errors
///
/// Implementations are expected to return an error when `expr` cannot be evaluated or its
/// output does not match `schema`. The exact set of supported expressions is
/// implementation-defined.
// Note: this takes a Schema rather than a DataType on purpose. A schema is the more
// constrained choice, and the API can be widened to accept any DataType in the future.
fn create_one(&self, schema: SchemaRef, expr: &Expression) -> DeltaResult<Box<dyn EngineData>>;
}

/// Provides file system related functionalities to Delta Kernel.
Expand Down
5 changes: 5 additions & 0 deletions kernel/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,11 @@ impl StructType {
self.fields.values()
}

pub fn len(&self) -> usize {
// O(1) for indexmap
self.fields.len()
}

/// Extracts the name and type of all leaf columns, in schema order. Caller should pass Some
/// `own_name` if this schema is embedded in a larger struct (e.g. `add.*`) and None if the
/// schema is a top-level result (e.g. `*`).
Expand Down
Loading