Skip to content

Commit

Permalink
support null for get field (#10665)
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 authored May 25, 2024
1 parent 4709fc6 commit ea92ae7
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 110 deletions.
100 changes: 0 additions & 100 deletions datafusion/expr/src/field_util.rs

This file was deleted.

1 change: 0 additions & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ pub mod expr;
pub mod expr_fn;
pub mod expr_rewriter;
pub mod expr_schema;
pub mod field_util;
pub mod function;
pub mod groups_accumulator;
pub mod interval_arithmetic;
Expand Down
43 changes: 34 additions & 9 deletions datafusion/functions/src/core/getfield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ use arrow::array::{
};
use arrow::datatypes::DataType;
use datafusion_common::cast::{as_map_array, as_struct_array};
use datafusion_common::{exec_err, ExprSchema, Result, ScalarValue};
use datafusion_expr::field_util::GetFieldAccessSchema;
use datafusion_common::{
exec_err, plan_datafusion_err, plan_err, ExprSchema, Result, ScalarValue,
};
use datafusion_expr::{ColumnarValue, Expr, ExprSchemable};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
Expand Down Expand Up @@ -104,14 +105,37 @@ impl ScalarUDFImpl for GetFieldFunc {
);
}
};
let access_schema = GetFieldAccessSchema::NamedStructField { name: name.clone() };
let arg_dt = args[0].get_type(schema)?;
if arg_dt.is_null() {
return Ok(DataType::Null);
let data_type = args[0].get_type(schema)?;
match (data_type, name) {
(DataType::Map(fields, _), _) => {
match fields.data_type() {
DataType::Struct(fields) if fields.len() == 2 => {
// Arrow's MapArray is essentially a ListArray of structs with two columns. They are
// often named "key" and "value", but we don't require any specific naming here;
// instead, we assume that the second column is the "value" column both here and in
// execution.
let value_field = fields.get(1).expect("fields should have exactly two members");
Ok(value_field.data_type().clone())
},
_ => plan_err!("Map fields must contain a Struct with exactly 2 fields"),
}
}
(DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => {
if s.is_empty() {
plan_err!(
"Struct based indexed access requires a non empty string"
)
} else {
let field = fields.iter().find(|f| f.name() == s);
field.ok_or(plan_datafusion_err!("Field {s} not found in struct")).map(|f| f.data_type().clone())
}
}
(DataType::Struct(_), _) => plan_err!(
"Only UTF8 strings are valid as an indexed field in a struct"
),
(DataType::Null, _) => Ok(DataType::Null),
(other, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `Struct`, `Map` or `Null` types, got {other}"),
}
access_schema
.get_accessed_field(&arg_dt)
.map(|f| f.data_type().clone())
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
Expand Down Expand Up @@ -182,6 +206,7 @@ impl ScalarUDFImpl for GetFieldFunc {
"get indexed field is only possible on struct with utf8 indexes. \
Tried with {name:?} index"
),
(DataType::Null, _) => Ok(ColumnarValue::Scalar(ScalarValue::Null)),
(dt, name) => exec_err!(
"get indexed field is only possible on lists with int64 indexes or struct \
with utf8 indexes. Tried {dt:?} with {name:?} index"
Expand Down

0 comments on commit ea92ae7

Please sign in to comment.