From dfdda7cb04f7f9b640da4f297ce1a16b08f3bf7b Mon Sep 17 00:00:00 2001 From: Arttu Date: Wed, 12 Jun 2024 18:40:40 +0200 Subject: [PATCH] fix: Ignore nullability of list elements when consuming Substrait (#10874) * Ignore nullability of list elements when consuming Substrait DataFusion (= Arrow) is quite strict about nullability, specifically, when using e.g. LogicalPlan::Values, the given schema must match the given literals exactly - including nullability. This is non-trivial to do when converting schema and literals separately. The existing implementation for from_substrait_literal already creates lists that are always nullable (see ScalarValue::new_list => array_into_list_array). This reverts part of https://github.com/apache/datafusion/pull/10640 to align from_substrait_type with that behavior. This is the error I was hitting: ``` ArrowError(InvalidArgumentError("column types must match schema types, expected List(Field { name: \"item\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }) but found List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) at column index 0"), None) ``` * use `Field::new_list_field` in `array_into_(large_)list_array` just for consistency, to reduce the places where "item" is written out * add a test for non-nullable lists --- datafusion/common/src/utils/mod.rs | 14 ++-- .../substrait/src/logical_plan/consumer.rs | 4 +- .../substrait/src/logical_plan/producer.rs | 14 ++-- .../substrait/tests/cases/logical_plans.rs | 32 +++++++-- .../non_nullable_lists.substrait.json | 71 +++++++++++++++++++ 5 files changed, 114 insertions(+), 21 deletions(-) create mode 100644 datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index ae444c2cb285..a0e4d1a76c03 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ 
-354,7 +354,7 @@ pub fn longest_consecutive_prefix>( pub fn array_into_list_array(arr: ArrayRef) -> ListArray { let offsets = OffsetBuffer::from_lengths([arr.len()]); ListArray::new( - Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)), offsets, arr, None, @@ -366,7 +366,7 @@ pub fn array_into_list_array(arr: ArrayRef) -> ListArray { pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray { let offsets = OffsetBuffer::from_lengths([arr.len()]); LargeListArray::new( - Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)), offsets, arr, None, @@ -379,7 +379,7 @@ pub fn array_into_fixed_size_list_array( ) -> FixedSizeListArray { let list_size = list_size as i32; FixedSizeListArray::new( - Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)), list_size, arr, None, @@ -420,7 +420,7 @@ pub fn arrays_into_list_array( let data_type = arr[0].data_type().to_owned(); let values = arr.iter().map(|x| x.as_ref()).collect::>(); Ok(ListArray::new( - Arc::new(Field::new("item", data_type, true)), + Arc::new(Field::new_list_field(data_type, true)), OffsetBuffer::from_lengths(lens), arrow::compute::concat(values.as_slice())?, None, @@ -435,7 +435,7 @@ pub fn arrays_into_list_array( /// use datafusion_common::utils::base_type; /// use std::sync::Arc; /// -/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); +/// let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); /// assert_eq!(base_type(&data_type), DataType::Int32); /// /// let data_type = DataType::Int32; @@ -458,10 +458,10 @@ pub fn base_type(data_type: &DataType) -> DataType { /// use datafusion_common::utils::coerced_type_with_base_type_only; /// use std::sync::Arc; /// -/// let data_type = 
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); +/// let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); /// let base_type = DataType::Float64; /// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type); -/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new("item", DataType::Float64, true)))); +/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true)))); pub fn coerced_type_with_base_type_only( data_type: &DataType, base_type: &DataType, diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 648a281832e1..3f9a895d951c 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1395,7 +1395,9 @@ fn from_substrait_type( })?; let field = Arc::new(Field::new_list_field( from_substrait_type(inner_type, dfs_names, name_idx)?, - is_substrait_type_nullable(inner_type)?, + // We ignore Substrait's nullability here to match from_substrait_literal + // which always creates nullable lists + true, )); match list.type_variation_reference { DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::List(field)), diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 88dc894eccd2..c0469d333164 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -2309,14 +2309,12 @@ mod test { round_trip_type(DataType::Decimal128(10, 2))?; round_trip_type(DataType::Decimal256(30, 2))?; - for nullable in [true, false] { - round_trip_type(DataType::List( - Field::new_list_field(DataType::Int32, nullable).into(), - ))?; - round_trip_type(DataType::LargeList( - Field::new_list_field(DataType::Int32, nullable).into(), - ))?; - } + round_trip_type(DataType::List( + Field::new_list_field(DataType::Int32, true).into(), + 
))?; + round_trip_type(DataType::LargeList( + Field::new_list_field(DataType::Int32, true).into(), + ))?; round_trip_type(DataType::Struct( vec![ diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 994a932c30e0..94572e098b2c 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -20,6 +20,7 @@ #[cfg(test)] mod tests { use datafusion::common::Result; + use datafusion::dataframe::DataFrame; use datafusion::prelude::{CsvReadOptions, SessionContext}; use datafusion_substrait::logical_plan::consumer::from_substrait_plan; use std::fs::File; @@ -38,11 +39,7 @@ mod tests { // File generated with substrait-java's Isthmus: // ./isthmus-cli/build/graal/isthmus "select not d from data" -c "create table data (d boolean)" - let path = "tests/testdata/select_not_bool.substrait.json"; - let proto = serde_json::from_reader::<_, Plan>(BufReader::new( - File::open(path).expect("file not found"), - )) - .expect("failed to parse json"); + let proto = read_json("tests/testdata/select_not_bool.substrait.json"); let plan = from_substrait_plan(&ctx, &proto).await?; @@ -54,6 +51,31 @@ mod tests { Ok(()) } + #[tokio::test] + async fn non_nullable_lists() -> Result<()> { + // DataFusion's Substrait consumer treats all lists as nullable, even if the Substrait plan specifies them as non-nullable. + // That's because implementing the non-nullability consistently is non-trivial. + // This test confirms that reading a plan with non-nullable lists works as expected. 
+ let ctx = create_context().await?; + let proto = read_json("tests/testdata/non_nullable_lists.substrait.json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + + assert_eq!(format!("{:?}", &plan), "Values: (List([1, 2]))"); + + // Need to trigger execution to ensure that Arrow has validated the plan + DataFrame::new(ctx.state(), plan).show().await?; + + Ok(()) + } + + fn read_json(path: &str) -> Plan { + serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json") + } + async fn create_context() -> datafusion::common::Result { let ctx = SessionContext::new(); ctx.register_csv("DATA", "tests/testdata/data.csv", CsvReadOptions::new()) diff --git a/datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json b/datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json new file mode 100644 index 000000000000..e1c5574f8bec --- /dev/null +++ b/datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json @@ -0,0 +1,71 @@ +{ + "extensionUris": [], + "extensions": [], + "relations": [ + { + "root": { + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "col" + ], + "struct": { + "types": [ + { + "list": { + "type": { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [ + { + "fields": [ + { + "list": { + "values": [ + { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + }, + { + "i32": 2, + "nullable": false, + "typeVariationReference": 0 + } + ] + }, + "nullable": false, + "typeVariationReference": 0 + } + ] + } + ] + } + } + }, + "names": [ + "col" + ] + } + } + ], + "expectedTypeUrls": [] +}