From b8b0c5584f9f3a3aeca730ef1ac23dafc3e76dde Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Sat, 4 Jan 2025 07:05:04 -0800 Subject: [PATCH] feat(substrait): introduce consume_rel and consume_expression (#13963) * feat(substrait): introduce consume_rel and consume_expression Route calls to from_substrait_rel and from_substrait_rex through the SubstraitConsumer in order to allow users to provide their own behaviour * feat(substrait): consume nulls of user-defined types * docs(substrait): consume_rel and consume_expression docstrings --- .../substrait/src/logical_plan/consumer.rs | 373 +++++------------- 1 file changed, 106 insertions(+), 267 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 515553152659..0ee87afe3286 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -17,7 +17,7 @@ use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, OffsetBuffer}; use async_recursion::async_recursion; -use datafusion::arrow::array::{GenericListArray, MapArray}; +use datafusion::arrow::array::MapArray; use datafusion::arrow::datatypes::{ DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, }; @@ -173,9 +173,9 @@ use substrait::proto::{ /// /// // You can implement a fully custom consumer method if you need special handling /// async fn consume_filter(&self, rel: &FilterRel) -> Result { -/// let input = from_substrait_rel(self, rel.input.as_ref().unwrap()).await?; +/// let input = self.consume_rel(rel.input.as_ref().unwrap()).await?; /// let expression = -/// from_substrait_rex(self, rel.condition.as_ref().unwrap(), input.schema()) +/// self.consume_expression(rel.condition.as_ref().unwrap(), input.schema()) /// .await?; /// // though this one is quite boring /// LogicalPlanBuilder::from(input).filter(expression)?.build() @@ -233,6 +233,12 @@ pub trait SubstraitConsumer: Send + Sync + Sized { // These methods have default implementations calling the common handler code, to allow for users // to re-use common handling logic. + /// All [Rel]s to be converted pass through this method. + /// You can provide your own implementation if you wish to customize the conversion behaviour. + async fn consume_rel(&self, rel: &Rel) -> Result { + from_substrait_rel(self, rel).await + } + async fn consume_read(&self, rel: &ReadRel) -> Result { from_read_rel(self, rel).await } @@ -285,6 +291,16 @@ pub trait SubstraitConsumer: Send + Sync + Sized { // These methods have default implementations calling the common handler code, to allow for users // to re-use common handling logic. + /// All [Expression]s to be converted pass through this method. + /// You can provide your own implementation if you wish to customize the conversion behaviour. + async fn consume_expression( + &self, + expr: &Expression, + input_schema: &DFSchema, + ) -> Result { + from_substrait_rex(self, expr, input_schema).await + } + async fn consume_literal(&self, expr: &Literal) -> Result { from_literal(self, expr).await } @@ -535,7 +551,7 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> { "ExtensionSingleRel missing input rel, try using ExtensionLeafRel instead" ); }; - let input_plan = from_substrait_rel(self, input_rel).await?; + let input_plan = self.consume_rel(input_rel).await?; let plan = plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?; Ok(LogicalPlan::Extension(Extension { node: plan })) } @@ -553,7 +569,7 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> { .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; let mut inputs = Vec::with_capacity(rel.inputs.len()); for input in &rel.inputs { - let input_plan = from_substrait_rel(self, input).await?; + let input_plan = self.consume_rel(input).await?; inputs.push(input_plan); } let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; @@ -666,10 +682,10 @@ async fn union_rels( is_all: bool, ) -> Result { let mut union_builder = Ok(LogicalPlanBuilder::from( - from_substrait_rel(consumer, &rels[0]).await?, + consumer.consume_rel(&rels[0]).await?, )); for input in &rels[1..] { - let rel_plan = from_substrait_rel(consumer, input).await?; + let rel_plan = consumer.consume_rel(input).await?; union_builder = if is_all { union_builder?.union(rel_plan) @@ -685,12 +701,12 @@ async fn intersect_rels( rels: &[Rel], is_all: bool, ) -> Result { - let mut rel = from_substrait_rel(consumer, &rels[0]).await?; + let mut rel = consumer.consume_rel(&rels[0]).await?; for input in &rels[1..] { rel = LogicalPlanBuilder::intersect( rel, - from_substrait_rel(consumer, input).await?, + consumer.consume_rel(input).await?, is_all, )? } @@ -703,14 +719,10 @@ async fn except_rels( rels: &[Rel], is_all: bool, ) -> Result { - let mut rel = from_substrait_rel(consumer, &rels[0]).await?; + let mut rel = consumer.consume_rel(&rels[0]).await?; for input in &rels[1..] { - rel = LogicalPlanBuilder::except( - rel, - from_substrait_rel(consumer, input).await?, - is_all, - )? + rel = LogicalPlanBuilder::except(rel, consumer.consume_rel(input).await?, is_all)? } Ok(rel) @@ -743,11 +755,9 @@ pub async fn from_substrait_plan_with_consumer( 1 => { match plan.relations[0].rel_type.as_ref() { Some(rt) => match rt { - plan_rel::RelType::Rel(rel) => { - Ok(from_substrait_rel(consumer, rel).await?) - }, + plan_rel::RelType::Rel(rel) => Ok(consumer.consume_rel(rel).await?), plan_rel::RelType::Root(root) => { - let plan = from_substrait_rel(consumer, root.input.as_ref().unwrap()).await?; + let plan = consumer.consume_rel(root.input.as_ref().unwrap()).await?; if root.names.is_empty() { // Backwards compatibility for plans missing names return Ok(plan); @@ -841,7 +851,9 @@ pub async fn from_substrait_extended_expr( plan_err!("required property `expr_type` missing from Substrait ExpressionReference message") } }?; - let expr = from_substrait_rex(&consumer, scalar_expr, &input_schema).await?; + let expr = consumer + .consume_expression(scalar_expr, &input_schema) + .await?; let (output_type, expected_nullability) = expr.data_type_and_nullable(&input_schema)?; let output_field = Field::new("", output_type, expected_nullability); @@ -1034,8 +1046,7 @@ pub async fn from_project_rel( p: &ProjectRel, ) -> Result { if let Some(input) = p.input.as_ref() { - let mut input = - LogicalPlanBuilder::from(from_substrait_rel(consumer, input).await?); + let mut input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); let original_schema = input.schema().clone(); // Ensure that all expressions have a unique display name, so that @@ -1052,7 +1063,9 @@ pub async fn from_project_rel( let mut explicit_exprs: Vec = vec![]; for expr in &p.expressions { - let e = from_substrait_rex(consumer, expr, input.clone().schema()).await?; + let e = consumer + .consume_expression(expr, input.clone().schema()) + .await?; // if the expression is WindowFunction, wrap in a Window relation if let Expr::WindowFunction(_) = &e { // Adding the same expression here and in the project below @@ -1081,9 +1094,11 @@ pub async fn from_filter_rel( filter: &FilterRel, ) -> Result { if let Some(input) = filter.input.as_ref() { - let input = LogicalPlanBuilder::from(from_substrait_rel(consumer, input).await?); + let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); if let Some(condition) = filter.condition.as_ref() { - let expr = from_substrait_rex(consumer, condition, input.schema()).await?; + let expr = consumer + .consume_expression(condition, input.schema()) + .await?; input.filter(expr)?.build() } else { not_impl_err!("Filter without an condition is not valid") @@ -1099,12 +1114,12 @@ pub async fn from_fetch_rel( fetch: &FetchRel, ) -> Result { if let Some(input) = fetch.input.as_ref() { - let input = LogicalPlanBuilder::from(from_substrait_rel(consumer, input).await?); + let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); let empty_schema = DFSchemaRef::new(DFSchema::empty()); let offset = match &fetch.offset_mode { Some(fetch_rel::OffsetMode::Offset(offset)) => Some(lit(*offset)), Some(fetch_rel::OffsetMode::OffsetExpr(expr)) => { - Some(from_substrait_rex(consumer, expr, &empty_schema).await?) + Some(consumer.consume_expression(expr, &empty_schema).await?) } None => None, }; @@ -1114,7 +1129,7 @@ pub async fn from_fetch_rel( (*count != -1).then(|| lit(*count)) } Some(fetch_rel::CountMode::CountExpr(expr)) => { - Some(from_substrait_rex(consumer, expr, &empty_schema).await?) + Some(consumer.consume_expression(expr, &empty_schema).await?) } None => None, }; @@ -1129,7 +1144,7 @@ pub async fn from_sort_rel( sort: &SortRel, ) -> Result { if let Some(input) = sort.input.as_ref() { - let input = LogicalPlanBuilder::from(from_substrait_rel(consumer, input).await?); + let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); let sorts = from_substrait_sorts(consumer, &sort.sorts, input.schema()).await?; input.sort(sorts)?.build() } else { @@ -1142,11 +1157,11 @@ pub async fn from_aggregate_rel( agg: &AggregateRel, ) -> Result { if let Some(input) = agg.input.as_ref() { - let input = LogicalPlanBuilder::from(from_substrait_rel(consumer, input).await?); + let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); let mut ref_group_exprs = vec![]; for e in &agg.grouping_expressions { - let x = from_substrait_rex(consumer, e, input.schema()).await?; + let x = consumer.consume_expression(e, input.schema()).await?; ref_group_exprs.push(x); } @@ -1189,7 +1204,7 @@ pub async fn from_aggregate_rel( for m in &agg.measures { let filter = match &m.filter { Some(fil) => Some(Box::new( - from_substrait_rex(consumer, fil, input.schema()).await?, + consumer.consume_expression(fil, input.schema()).await?, )), None => None, }; @@ -1242,10 +1257,10 @@ pub async fn from_join_rel( } let left: LogicalPlanBuilder = LogicalPlanBuilder::from( - from_substrait_rel(consumer, join.left.as_ref().unwrap()).await?, + consumer.consume_rel(join.left.as_ref().unwrap()).await?, ); let right = LogicalPlanBuilder::from( - from_substrait_rel(consumer, join.right.as_ref().unwrap()).await?, + consumer.consume_rel(join.right.as_ref().unwrap()).await?, ); let (left, right) = requalify_sides_if_needed(left, right)?; @@ -1258,7 +1273,7 @@ pub async fn from_join_rel( // Otherwise, build join with only the filter, without join keys match &join.expression.as_ref() { Some(expr) => { - let on = from_substrait_rex(consumer, expr, &in_join_schema).await?; + let on = consumer.consume_expression(expr, &in_join_schema).await?; // The join expression can contain both equal and non-equal ops. // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. // So we extract each part as follows: @@ -1290,10 +1305,10 @@ pub async fn from_cross_rel( cross: &CrossRel, ) -> Result { let left = LogicalPlanBuilder::from( - from_substrait_rel(consumer, cross.left.as_ref().unwrap()).await?, + consumer.consume_rel(cross.left.as_ref().unwrap()).await?, ); let right = LogicalPlanBuilder::from( - from_substrait_rel(consumer, cross.right.as_ref().unwrap()).await?, + consumer.consume_rel(cross.right.as_ref().unwrap()).await?, ); let (left, right) = requalify_sides_if_needed(left, right)?; left.cross_join(right.build()?)?.build() @@ -1466,7 +1481,7 @@ pub async fn from_set_rel( SetOp::UnionAll => union_rels(consumer, &set.inputs, true).await, SetOp::UnionDistinct => union_rels(consumer, &set.inputs, false).await, SetOp::IntersectionPrimary => LogicalPlanBuilder::intersect( - from_substrait_rel(consumer, &set.inputs[0]).await?, + consumer.consume_rel(&set.inputs[0]).await?, union_rels(consumer, &set.inputs[1..], true).await?, false, ), @@ -1490,7 +1505,7 @@ pub async fn from_exchange_rel( let Some(input) = exchange.input.as_ref() else { return substrait_err!("Unexpected empty input in ExchangeRel"); }; - let input = Arc::new(from_substrait_rel(consumer, input).await?); + let input = Arc::new(consumer.consume_rel(input).await?); let Some(exchange_kind) = &exchange.exchange_kind else { return substrait_err!("Unexpected empty input in ExchangeRel"); @@ -1822,8 +1837,9 @@ pub async fn from_substrait_sorts( ) -> Result> { let mut sorts: Vec = vec![]; for s in substrait_sorts { - let expr = - from_substrait_rex(consumer, s.expr.as_ref().unwrap(), input_schema).await?; + let expr = consumer + .consume_expression(s.expr.as_ref().unwrap(), input_schema) + .await?; let asc_nullfirst = match &s.sort_kind { Some(k) => match k { Direction(d) => { @@ -1870,7 +1886,7 @@ pub async fn from_substrait_rex_vec( ) -> Result> { let mut expressions: Vec = vec![]; for expr in exprs { - let expression = from_substrait_rex(consumer, expr, input_schema).await?; + let expression = consumer.consume_expression(expr, input_schema).await?; expressions.push(expression); } Ok(expressions) @@ -1885,9 +1901,7 @@ pub async fn from_substrait_func_args( let mut args: Vec = vec![]; for arg in arguments { let arg_expr = match &arg.arg_type { - Some(ArgType::Value(e)) => { - from_substrait_rex(consumer, e, input_schema).await - } + Some(ArgType::Value(e)) => consumer.consume_expression(e, input_schema).await, _ => not_impl_err!("Function argument non-Value type not supported"), }; args.push(arg_expr?); @@ -1991,7 +2005,11 @@ pub async fn from_singular_or_list( let substrait_expr = expr.value.as_ref().unwrap(); let substrait_list = expr.options.as_ref(); Ok(Expr::InList(InList { - expr: Box::new(from_substrait_rex(consumer, substrait_expr, input_schema).await?), + expr: Box::new( + consumer + .consume_expression(substrait_expr, input_schema) + .await?, + ), list: from_substrait_rex_vec(consumer, substrait_list, input_schema).await?, negated: false, })) @@ -2019,39 +2037,30 @@ pub async fn from_if_then( // Check if the first element is type base expression if if_expr.then.is_none() { expr = Some(Box::new( - from_substrait_rex( - consumer, - if_expr.r#if.as_ref().unwrap(), - input_schema, - ) - .await?, + consumer + .consume_expression(if_expr.r#if.as_ref().unwrap(), input_schema) + .await?, )); continue; } } when_then_expr.push(( Box::new( - from_substrait_rex( - consumer, - if_expr.r#if.as_ref().unwrap(), - input_schema, - ) - .await?, + consumer + .consume_expression(if_expr.r#if.as_ref().unwrap(), input_schema) + .await?, ), Box::new( - from_substrait_rex( - consumer, - if_expr.then.as_ref().unwrap(), - input_schema, - ) - .await?, + consumer + .consume_expression(if_expr.then.as_ref().unwrap(), input_schema) + .await?, ), )); } // Parse `else` let else_expr = match &if_then.r#else { Some(e) => Some(Box::new( - from_substrait_rex(consumer, e, input_schema).await?, + consumer.consume_expression(e, input_schema).await?, )), None => None, }; @@ -2134,12 +2143,12 @@ pub async fn from_cast( match cast.r#type.as_ref() { Some(output_type) => { let input_expr = Box::new( - from_substrait_rex( - consumer, - cast.input.as_ref().unwrap().as_ref(), - input_schema, - ) - .await?, + consumer + .consume_expression( + cast.input.as_ref().unwrap().as_ref(), + input_schema, + ) + .await?, ); let data_type = from_substrait_type_without_names(consumer, output_type)?; if cast.failure_behavior() == ReturnNull { @@ -2229,12 +2238,12 @@ pub async fn from_subquery( let needle_expr = &in_predicate.needles[0]; let haystack_expr = &in_predicate.haystack; if let Some(haystack_expr) = haystack_expr { - let haystack_expr = - from_substrait_rel(consumer, haystack_expr).await?; + let haystack_expr = consumer.consume_rel(haystack_expr).await?; let outer_refs = haystack_expr.all_out_ref_exprs(); Ok(Expr::InSubquery(InSubquery { expr: Box::new( - from_substrait_rex(consumer, needle_expr, input_schema) + consumer + .consume_expression(needle_expr, input_schema) .await?, ), subquery: Subquery { @@ -2251,11 +2260,9 @@ pub async fn from_subquery( } } SubqueryType::Scalar(query) => { - let plan = from_substrait_rel( - consumer, - &(query.input.clone()).unwrap_or_default(), - ) - .await?; + let plan = consumer + .consume_rel(&(query.input.clone()).unwrap_or_default()) + .await?; let outer_ref_columns = plan.all_out_ref_exprs(); Ok(Expr::ScalarSubquery(Subquery { subquery: Arc::new(plan), @@ -2267,11 +2274,9 @@ pub async fn from_subquery( // exist PredicateOp::Exists => { let relation = &predicate.tuples; - let plan = from_substrait_rel( - consumer, - &relation.clone().unwrap_or_default(), - ) - .await?; + let plan = consumer + .consume_rel(&relation.clone().unwrap_or_default()) + .await?; let outer_ref_columns = plan.all_out_ref_exprs(); Ok(Expr::Exists(Exists::new( Subquery { @@ -2909,8 +2914,10 @@ fn from_substrait_literal( } builder.build()? } - Some(LiteralType::Null(ntype)) => { - from_substrait_null(consumer, ntype, dfs_names, name_idx)? + Some(LiteralType::Null(null_type)) => { + let data_type = + from_substrait_type(consumer, null_type, dfs_names, name_idx)?; + ScalarValue::try_from(&data_type)? } Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond { days, @@ -3082,180 +3089,6 @@ fn from_substrait_literal( Ok(scalar_value) } -fn from_substrait_null( - consumer: &impl SubstraitConsumer, - null_type: &Type, - dfs_names: &[String], - name_idx: &mut usize, -) -> Result { - if let Some(kind) = &null_type.kind { - match kind { - r#type::Kind::Bool(_) => Ok(ScalarValue::Boolean(None)), - r#type::Kind::I8(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => Ok(ScalarValue::Int8(None)), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(ScalarValue::UInt8(None)), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {kind:?}" - ), - }, - r#type::Kind::I16(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => Ok(ScalarValue::Int16(None)), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(ScalarValue::UInt16(None)), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {kind:?}" - ), - }, - r#type::Kind::I32(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => Ok(ScalarValue::Int32(None)), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(ScalarValue::UInt32(None)), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {kind:?}" - ), - }, - r#type::Kind::I64(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => Ok(ScalarValue::Int64(None)), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(ScalarValue::UInt64(None)), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {kind:?}" - ), - }, - r#type::Kind::Fp32(_) => Ok(ScalarValue::Float32(None)), - r#type::Kind::Fp64(_) => Ok(ScalarValue::Float64(None)), - r#type::Kind::Timestamp(ts) => { - // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead - #[allow(deprecated)] - match ts.type_variation_reference { - TIMESTAMP_SECOND_TYPE_VARIATION_REF => { - Ok(ScalarValue::TimestampSecond(None, None)) - } - TIMESTAMP_MILLI_TYPE_VARIATION_REF => { - Ok(ScalarValue::TimestampMillisecond(None, None)) - } - TIMESTAMP_MICRO_TYPE_VARIATION_REF => { - Ok(ScalarValue::TimestampMicrosecond(None, None)) - } - TIMESTAMP_NANO_TYPE_VARIATION_REF => { - Ok(ScalarValue::TimestampNanosecond(None, None)) - } - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {kind:?}" - ), - } - } - r#type::Kind::PrecisionTimestamp(pts) => match pts.precision { - 0 => Ok(ScalarValue::TimestampSecond(None, None)), - 3 => Ok(ScalarValue::TimestampMillisecond(None, None)), - 6 => Ok(ScalarValue::TimestampMicrosecond(None, None)), - 9 => Ok(ScalarValue::TimestampNanosecond(None, None)), - p => not_impl_err!( - "Unsupported Substrait precision {p} for PrecisionTimestamp" - ), - }, - r#type::Kind::PrecisionTimestampTz(pts) => match pts.precision { - 0 => Ok(ScalarValue::TimestampSecond( - None, - Some(DEFAULT_TIMEZONE.into()), - )), - 3 => Ok(ScalarValue::TimestampMillisecond( - None, - Some(DEFAULT_TIMEZONE.into()), - )), - 6 => Ok(ScalarValue::TimestampMicrosecond( - None, - Some(DEFAULT_TIMEZONE.into()), - )), - 9 => Ok(ScalarValue::TimestampNanosecond( - None, - Some(DEFAULT_TIMEZONE.into()), - )), - p => not_impl_err!( - "Unsupported Substrait precision {p} for PrecisionTimestamp" - ), - }, - r#type::Kind::Date(date) => match date.type_variation_reference { - DATE_32_TYPE_VARIATION_REF => Ok(ScalarValue::Date32(None)), - DATE_64_TYPE_VARIATION_REF => Ok(ScalarValue::Date64(None)), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {kind:?}" - ), - }, - r#type::Kind::Binary(binary) => match binary.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::Binary(None)), - LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::LargeBinary(None)), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {kind:?}" - ), - }, - // FixedBinary is not supported because `None` doesn't have length - r#type::Kind::String(string) => match string.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::Utf8(None)), - LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::LargeUtf8(None)), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {kind:?}" - ), - }, - r#type::Kind::Decimal(d) => Ok(ScalarValue::Decimal128( - None, - d.precision as u8, - d.scale as i8, - )), - r#type::Kind::List(l) => { - let field = Field::new_list_field( - from_substrait_type( - consumer, - l.r#type.clone().unwrap().as_ref(), - dfs_names, - name_idx, - )?, - true, - ); - match l.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::List( - Arc::new(GenericListArray::new_null(field.into(), 1)), - )), - LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::LargeList( - Arc::new(GenericListArray::new_null(field.into(), 1)), - )), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {kind:?}" - ), - } - } - r#type::Kind::Map(map) => { - let key_type = map.key.as_ref().ok_or_else(|| { - substrait_datafusion_err!("Map type must have key type") - })?; - let value_type = map.value.as_ref().ok_or_else(|| { - substrait_datafusion_err!("Map type must have value type") - })?; - - let key_type = - from_substrait_type(consumer, key_type, dfs_names, name_idx)?; - let value_type = - from_substrait_type(consumer, value_type, dfs_names, name_idx)?; - let entries_field = Arc::new(Field::new_struct( - "entries", - vec![ - Field::new("key", key_type, false), - Field::new("value", value_type, true), - ], - false, - )); - - DataType::Map(entries_field, false /* keys sorted */).try_into() - } - r#type::Kind::Struct(s) => { - let fields = - from_substrait_struct_type(consumer, s, dfs_names, name_idx)?; - Ok(ScalarStructBuilder::new_null(fields)) - } - _ => not_impl_err!("Unsupported Substrait type for null: {kind:?}"), - } - } else { - not_impl_err!("Null type without kind is not supported") - } -} - #[allow(deprecated)] async fn from_substrait_grouping( consumer: &impl SubstraitConsumer, @@ -3266,7 +3099,7 @@ async fn from_substrait_grouping( let mut group_exprs = vec![]; if !grouping.grouping_expressions.is_empty() { for e in &grouping.grouping_expressions { - let expr = from_substrait_rex(consumer, e, input_schema).await?; + let expr = consumer.consume_expression(e, input_schema).await?; group_exprs.push(expr); } return Ok(group_exprs); @@ -3349,7 +3182,9 @@ impl BuiltinExprBuilder { let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { return substrait_err!("Invalid arguments type for {fn_name} expr"); }; - let arg = from_substrait_rex(consumer, expr_substrait, input_schema).await?; + let arg = consumer + .consume_expression(expr_substrait, input_schema) + .await?; let arg = Box::new(arg); let expr = match fn_name { @@ -3383,12 +3218,15 @@ impl BuiltinExprBuilder { let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; - let expr = from_substrait_rex(consumer, expr_substrait, input_schema).await?; + let expr = consumer + .consume_expression(expr_substrait, input_schema) + .await?; let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; - let pattern = - from_substrait_rex(consumer, pattern_substrait, input_schema).await?; + let pattern = consumer + .consume_expression(pattern_substrait, input_schema) + .await?; // Default case: escape character is Literal(Utf8(None)) let escape_char = if f.arguments.len() == 3 { @@ -3397,8 +3235,9 @@ impl BuiltinExprBuilder { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; - let escape_char_expr = - from_substrait_rex(consumer, escape_char_substrait, input_schema).await?; + let escape_char_expr = consumer + .consume_expression(escape_char_substrait, input_schema) + .await?; match escape_char_expr { Expr::Literal(ScalarValue::Utf8(escape_char_string)) => {