From ea92ae72f7ec2e941d35aa077c6a39f74523ab63 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sat, 25 May 2024 21:12:05 +0800 Subject: [PATCH] support null for get field (#10665) Signed-off-by: jayzhan211 --- datafusion/expr/src/field_util.rs | 100 ---------------------- datafusion/expr/src/lib.rs | 1 - datafusion/functions/src/core/getfield.rs | 43 ++++++++-- 3 files changed, 34 insertions(+), 110 deletions(-) delete mode 100644 datafusion/expr/src/field_util.rs diff --git a/datafusion/expr/src/field_util.rs b/datafusion/expr/src/field_util.rs deleted file mode 100644 index f0ce61ee9bbb..000000000000 --- a/datafusion/expr/src/field_util.rs +++ /dev/null @@ -1,100 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Utility functions for complex field access - -use arrow::datatypes::{DataType, Field}; -use datafusion_common::{plan_datafusion_err, plan_err, Result, ScalarValue}; - -/// Types of the field access expression of a nested type, such as `Field` or `List` -pub enum GetFieldAccessSchema { - /// Named field, For example `struct["name"]` - NamedStructField { name: ScalarValue }, - /// Single list index, for example: `list[i]` - ListIndex { key_dt: DataType }, - /// List stride, for example `list[i:j:k]` - ListRange { - start_dt: DataType, - stop_dt: DataType, - stride_dt: DataType, - }, -} - -impl GetFieldAccessSchema { - /// Returns the schema [`Field`] from a [`DataType::List`] or - /// [`DataType::Struct`] indexed by this structure - /// - /// # Error - /// Errors if - /// * the `data_type` is not a Struct or a List, - /// * the `data_type` of the name/index/start-stop do not match a supported index type - pub fn get_accessed_field(&self, data_type: &DataType) -> Result { - match self { - Self::NamedStructField{ name } => { - match (data_type, name) { - (DataType::Map(fields, _), _) => { - match fields.data_type() { - DataType::Struct(fields) if fields.len() == 2 => { - // Arrow's MapArray is essentially a ListArray of structs with two columns. They are - // often named "key", and "value", but we don't require any specific naming here; - // instead, we assume that the second columnis the "value" column both here and in - // execution. - let value_field = fields.get(1).expect("fields should have exactly two members"); - Ok(Field::new("map", value_field.data_type().clone(), true)) - }, - _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), - } - } - (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => { - if s.is_empty() { - plan_err!( - "Struct based indexed access requires a non empty string" - ) - } else { - let field = fields.iter().find(|f| f.name() == s); - field.ok_or(plan_datafusion_err!("Field {s} not found in struct")).map(|f| f.as_ref().clone()) - } - } - (DataType::Struct(_), _) => plan_err!( - "Only utf8 strings are valid as an indexed field in a struct" - ), - (other, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `Struct`, or `Map` types, got {other}"), - } - } - Self::ListIndex{ key_dt } => { - match (data_type, key_dt) { - (DataType::List(lt), DataType::Int64) => Ok(Field::new("list", lt.data_type().clone(), true)), - (DataType::LargeList(lt), DataType::Int64) => Ok(Field::new("large_list", lt.data_type().clone(), true)), - (DataType::List(_), _) | (DataType::LargeList(_), _) => plan_err!( - "Only ints are valid as an indexed field in a List/LargeList" - ), - (other, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `LargeList` or `Struct` types, got {other}"), - } - } - Self::ListRange { start_dt, stop_dt, stride_dt } => { - match (data_type, start_dt, stop_dt, stride_dt) { - (DataType::List(_), DataType::Int64, DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)), - (DataType::LargeList(_), DataType::Int64, DataType::Int64, DataType::Int64) => Ok(Field::new("large_list", data_type.clone(), true)), - (DataType::List(_), _, _, _) | (DataType::LargeList(_), _, _, _)=> plan_err!( - "Only ints are valid as an indexed field in a List/LargeList" - ), - (other, _, _, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `LargeList` or `Struct` types, got {other}"), - } - } - } - } -} diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index bac2f9c14541..c491a2656470 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -44,7 +44,6 @@ pub mod expr; pub mod expr_fn; pub mod expr_rewriter; pub mod expr_schema; -pub mod field_util; pub mod function; pub mod groups_accumulator; pub mod interval_arithmetic; diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 0013655e6dd9..b76da15c52ca 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -20,8 +20,9 @@ use arrow::array::{ }; use arrow::datatypes::DataType; use datafusion_common::cast::{as_map_array, as_struct_array}; -use datafusion_common::{exec_err, ExprSchema, Result, ScalarValue}; -use datafusion_expr::field_util::GetFieldAccessSchema; +use datafusion_common::{ + exec_err, plan_datafusion_err, plan_err, ExprSchema, Result, ScalarValue, +}; use datafusion_expr::{ColumnarValue, Expr, ExprSchemable}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; @@ -104,14 +105,37 @@ impl ScalarUDFImpl for GetFieldFunc { ); } }; - let access_schema = GetFieldAccessSchema::NamedStructField { name: name.clone() }; - let arg_dt = args[0].get_type(schema)?; - if arg_dt.is_null() { - return Ok(DataType::Null); + let data_type = args[0].get_type(schema)?; + match (data_type, name) { + (DataType::Map(fields, _), _) => { + match fields.data_type() { + DataType::Struct(fields) if fields.len() == 2 => { + // Arrow's MapArray is essentially a ListArray of structs with two columns. They are + // often named "key", and "value", but we don't require any specific naming here; + // instead, we assume that the second columnis the "value" column both here and in + // execution. + let value_field = fields.get(1).expect("fields should have exactly two members"); + Ok(value_field.data_type().clone()) + }, + _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), + } + } + (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => { + if s.is_empty() { + plan_err!( + "Struct based indexed access requires a non empty string" + ) + } else { + let field = fields.iter().find(|f| f.name() == s); + field.ok_or(plan_datafusion_err!("Field {s} not found in struct")).map(|f| f.data_type().clone()) + } + } + (DataType::Struct(_), _) => plan_err!( + "Only UTF8 strings are valid as an indexed field in a struct" + ), + (DataType::Null, _) => Ok(DataType::Null), + (other, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `Struct`, `Map` or `Null` types, got {other}"), } - access_schema - .get_accessed_field(&arg_dt) - .map(|f| f.data_type().clone()) } fn invoke(&self, args: &[ColumnarValue]) -> Result { @@ -182,6 +206,7 @@ impl ScalarUDFImpl for GetFieldFunc { "get indexed field is only possible on struct with utf8 indexes. \ Tried with {name:?} index" ), + (DataType::Null, _) => Ok(ColumnarValue::Scalar(ScalarValue::Null)), (dt, name) => exec_err!( "get indexed field is only possible on lists with int64 indexes or struct \ with utf8 indexes. Tried {dt:?} with {name:?} index"