Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer data type from schema for Values and add struct coercion to coalesce #12864

Merged
merged 16 commits into from
Oct 24, 2024
Merged
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion datafusion/common/src/dfschema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,6 @@ impl DFSchema {
None => self_unqualified_names.contains(field.name().as_str()),
};
if !duplicated_field {
// self.inner.fields.push(field.clone());
schema_builder.push(Arc::clone(field));
qualifiers.push(qualifier.cloned());
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@ path = "src/lib.rs"
[dependencies]
arrow = { workspace = true }
datafusion-common = { workspace = true }
itertools = { workspace = true }
paste = "^1.0"
90 changes: 89 additions & 1 deletion datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ use arrow::datatypes::{
DataType, Field, FieldRef, Fields, TimeUnit, DECIMAL128_MAX_PRECISION,
DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
};
use datafusion_common::{exec_datafusion_err, plan_datafusion_err, plan_err, Result};
use datafusion_common::{
exec_datafusion_err, exec_err, internal_err, plan_datafusion_err, plan_err, Result,
};
use itertools::Itertools;

/// The type signature of an instantiation of binary operator expression such as
/// `lhs + rhs`
Expand Down Expand Up @@ -372,6 +375,8 @@ impl From<&DataType> for TypeCategory {
/// decimal precision and scale when coercing decimal types.
///
/// This function doesn't preserve correct field name and nullability for the struct type, we only care about data type.
///
/// Returns `None` if the data types cannot be coerced to a common type, so callers can decide how to proceed instead of failing immediately
pub fn type_union_resolution(data_types: &[DataType]) -> Option<DataType> {
if data_types.is_empty() {
return None;
Expand Down Expand Up @@ -529,6 +534,89 @@ fn type_union_resolution_coercion(
}
}

/// Handle type union resolution including struct type and others.
pub fn try_type_union_resolution(data_types: &[DataType]) -> Result<Vec<DataType>> {
jayzhan211 marked this conversation as resolved.
Show resolved Hide resolved
let err = match try_type_union_resolution_with_struct(data_types) {
Ok(struct_types) => return Ok(struct_types),
Err(e) => Some(e),
};

if let Some(new_type) = type_union_resolution(data_types) {
Ok(vec![new_type; data_types.len()])
} else {
exec_err!("Fail to find the coerced type, errors: {:?}", err)
}
}

// Handle struct where we only change the data type but preserve the field name and nullability.
// Since the field name is the key of the struct, it shouldn't be updated to the common column name like "c0" or "c1".
pub fn try_type_union_resolution_with_struct(
    data_types: &[DataType],
) -> Result<Vec<DataType>> {
    // Guard against an empty slice: the indexing below (`data_types[0]`)
    // would otherwise panic.
    if data_types.is_empty() {
        return exec_err!("Expect at least one struct data type, but got none");
    }

    // Verify every input is a struct and that all structs share the same
    // ordered list of field names.
    let mut keys_string: Option<String> = None;
    for data_type in data_types {
        if let DataType::Struct(fields) = data_type {
            let keys = fields.iter().map(|f| f.name().to_owned()).join(",");
            if let Some(ref k) = keys_string {
                if *k != keys {
                    return exec_err!("Expect same keys for struct type but got mismatched pair {} and {}", *k, keys);
                }
            } else {
                keys_string = Some(keys);
            }
        } else {
            return exec_err!("Expect to get struct but got {}", data_type);
        }
    }

    // Seed the coercion with the field types of the first struct.
    let mut struct_types: Vec<DataType> = if let DataType::Struct(fields) = &data_types[0]
    {
        fields.iter().map(|f| f.data_type().to_owned()).collect()
    } else {
        return internal_err!("Struct type was checked in the loop above, so this should be unreachable");
    };

    // Fold each subsequent struct in, coercing field-by-field to a common type.
    for data_type in data_types.iter().skip(1) {
        if let DataType::Struct(fields) = data_type {
            let incoming_struct_types: Vec<DataType> =
                fields.iter().map(|f| f.data_type().to_owned()).collect();
            // The order of fields is verified above
            for (lhs_type, rhs_type) in
                struct_types.iter_mut().zip(incoming_struct_types.iter())
            {
                if let Some(coerced_type) =
                    type_union_resolution_coercion(lhs_type, rhs_type)
                {
                    *lhs_type = coerced_type;
                } else {
                    return exec_err!(
                        "Fail to find the coerced type for {} and {}",
                        lhs_type,
                        rhs_type
                    );
                }
            }
        } else {
            return exec_err!("Expect to get struct but got {}", data_type);
        }
    }

    // Rebuild each input struct with the coerced field types while keeping
    // the original field names and nullability.
    let mut final_struct_types = vec![];
    for s in data_types {
        let mut new_fields = vec![];
        if let DataType::Struct(fields) = s {
            for (i, f) in fields.iter().enumerate() {
                let field = Arc::unwrap_or_clone(Arc::clone(f))
                    .with_data_type(struct_types[i].to_owned());
                new_fields.push(Arc::new(field));
            }
        }
        final_struct_types.push(DataType::Struct(new_fields.into()))
    }

    Ok(final_struct_types)
}

/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a
/// comparison operation
///
Expand Down
97 changes: 89 additions & 8 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@ use crate::{

use super::dml::InsertOp;
use super::plan::ColumnUnnestList;
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
use datafusion_common::display::ToStringifiedPlan;
use datafusion_common::file_options::file_type::FileType;
use datafusion_common::{
get_target_functional_dependencies, internal_err, not_impl_err, plan_datafusion_err,
plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies,
Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions,
exec_err, get_target_functional_dependencies, internal_err, not_impl_err,
plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError,
FunctionalDependencies, Result, ScalarValue, TableReference, ToDFSchema,
UnnestOptions,
};
use datafusion_expr_common::type_coercion::binary::type_union_resolution;

Expand Down Expand Up @@ -172,12 +174,45 @@ impl LogicalPlanBuilder {
/// `value`. See the [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html)
/// documentation for more details.
///
/// so it's usually better to override the default names with a table alias list.
///
/// If the values include params/binders such as $1, $2, $3, etc, then the `param_data_types` should be provided.
pub fn values(values: Vec<Vec<Expr>>) -> Result<Self> {
    if values.is_empty() {
        return plan_err!("Values list cannot be empty");
    }
    let n_cols = values[0].len();
    if n_cols == 0 {
        return plan_err!("Values list cannot be zero length");
    }
    // Every row must have the same arity as the first one.
    if let Some((i, row)) = values
        .iter()
        .enumerate()
        .find(|(_, row)| row.len() != n_cols)
    {
        return plan_err!(
            "Inconsistent data length across values list: got {} values in row {} but expected {}",
            row.len(),
            i,
            n_cols
        );
    }

    // Infer from data itself
    Self::infer_data(values)
}

/// Create a values list based relation, and the schema is inferred from data itself or table schema if provided, consuming
/// `value`. See the [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html)
/// documentation for more details.
///
/// By default, it assigns the names column1, column2, etc. to the columns of a VALUES table.
/// The column names are not specified by the SQL standard and different database systems do it differently,
/// so it's usually better to override the default names with a table alias list.
///
/// If the values include params/binders such as $1, $2, $3, etc, then the `param_data_types` should be provided.
pub fn values(mut values: Vec<Vec<Expr>>) -> Result<Self> {
pub fn values_with_schema(
values: Vec<Vec<Expr>>,
schema: &DFSchemaRef,
) -> Result<Self> {
if values.is_empty() {
return plan_err!("Values list cannot be empty");
}
Expand All @@ -196,16 +231,53 @@ impl LogicalPlanBuilder {
}
}

let empty_schema = DFSchema::empty();
// Check the type of value against the schema
Self::infer_values_from_schema(values, schema)
}

fn infer_values_from_schema(
values: Vec<Vec<Expr>>,
schema: &DFSchema,
) -> Result<Self> {
let n_cols = values[0].len();
let mut field_types: Vec<DataType> = Vec::with_capacity(n_cols);
for j in 0..n_cols {
let field_type = schema.field(j).data_type();
for row in values.iter() {
let value = &row[j];
let data_type = value.get_type(schema)?;

if !data_type.equals_datatype(field_type) {
jayzhan211 marked this conversation as resolved.
Show resolved Hide resolved
if can_cast_types(&data_type, field_type) {
} else {
jayzhan211 marked this conversation as resolved.
Show resolved Hide resolved
return exec_err!(
"type mistmatch and can't cast to got {} and {}",
data_type,
field_type
);
}
}
}
field_types.push(field_type.to_owned());
}

Self::infer_inner(values, &field_types, schema)
}

fn infer_data(values: Vec<Vec<Expr>>) -> Result<Self> {
let n_cols = values[0].len();
let schema = DFSchema::empty();

let mut field_types: Vec<DataType> = Vec::with_capacity(n_cols);
for j in 0..n_cols {
let mut common_type: Option<DataType> = None;
for (i, row) in values.iter().enumerate() {
let value = &row[j];
let data_type = value.get_type(&empty_schema)?;
let data_type = value.get_type(&schema)?;
if data_type == DataType::Null {
continue;
}

if let Some(prev_type) = common_type {
// get common type of each column values.
let data_types = vec![prev_type.clone(), data_type.clone()];
Expand All @@ -221,14 +293,22 @@ impl LogicalPlanBuilder {
// since the code loop skips NULL
field_types.push(common_type.unwrap_or(DataType::Null));
}

Self::infer_inner(values, &field_types, &schema)
}

fn infer_inner(
mut values: Vec<Vec<Expr>>,
field_types: &[DataType],
schema: &DFSchema,
) -> Result<Self> {
// wrap cast if data type is not same as common type.
for row in &mut values {
for (j, field_type) in field_types.iter().enumerate() {
if let Expr::Literal(ScalarValue::Null) = row[j] {
row[j] = Expr::Literal(ScalarValue::try_from(field_type)?);
} else {
row[j] =
std::mem::take(&mut row[j]).cast_to(field_type, &empty_schema)?;
row[j] = std::mem::take(&mut row[j]).cast_to(field_type, schema)?;
}
}
}
Expand All @@ -243,6 +323,7 @@ impl LogicalPlanBuilder {
.collect::<Vec<_>>();
let dfschema = DFSchema::from_unqualified_fields(fields.into(), HashMap::new())?;
let schema = DFSchemaRef::new(dfschema);

Ok(Self::new(LogicalPlan::Values(Values { schema, values })))
}

Expand Down
64 changes: 14 additions & 50 deletions datafusion/functions-nested/src/make_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ use arrow_array::{
use arrow_buffer::OffsetBuffer;
use arrow_schema::DataType::{LargeList, List, Null};
use arrow_schema::{DataType, Field};
use datafusion_common::{exec_err, internal_err};
use datafusion_common::{plan_err, utils::array_into_list_array_nullable, Result};
use datafusion_expr::binary::type_union_resolution;
use datafusion_expr::binary::{
try_type_union_resolution_with_struct, type_union_resolution,
};
use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
use datafusion_expr::TypeSignature;
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
use itertools::Itertools;

use crate::utils::make_scalar_function;

Expand Down Expand Up @@ -111,33 +111,16 @@ impl ScalarUDFImpl for MakeArray {
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if let Some(new_type) = type_union_resolution(arg_types) {
// TODO: Move the logic to type_union_resolution if this applies to other functions as well
// Handle struct where we only change the data type but preserve the field name and nullability.
// Since field name is the key of the struct, so it shouldn't be updated to the common column name like "c0" or "c1"
let is_struct_and_has_same_key = are_all_struct_and_have_same_key(arg_types)?;
if is_struct_and_has_same_key {
let data_types: Vec<_> = if let DataType::Struct(fields) = &arg_types[0] {
fields.iter().map(|f| f.data_type().to_owned()).collect()
} else {
return internal_err!("Struct type is checked is the previous function, so this should be unreachable");
};

let mut final_struct_types = vec![];
for s in arg_types {
let mut new_fields = vec![];
if let DataType::Struct(fields) = s {
for (i, f) in fields.iter().enumerate() {
let field = Arc::unwrap_or_clone(Arc::clone(f))
.with_data_type(data_types[i].to_owned());
new_fields.push(Arc::new(field));
}
}
final_struct_types.push(DataType::Struct(new_fields.into()))
}
return Ok(final_struct_types);
let mut errors = vec![];
match try_type_union_resolution_with_struct(arg_types) {
Ok(r) => return Ok(r),
Err(e) => {
errors.push(e);
}
}

if let Some(new_type) = type_union_resolution(arg_types) {
// TODO: Move FixedSizeList to List in type_union_resolution
if let DataType::FixedSizeList(field, _) = new_type {
Ok(vec![DataType::List(field); arg_types.len()])
} else if new_type.is_null() {
Expand All @@ -147,9 +130,10 @@ impl ScalarUDFImpl for MakeArray {
}
} else {
plan_err!(
"Fail to find the valid type between {:?} for {}",
"Fail to find the valid type between {:?} for {}, errors are {:?}",
arg_types,
self.name()
self.name(),
errors
)
}
}
Expand Down Expand Up @@ -188,26 +172,6 @@ fn get_make_array_doc() -> &'static Documentation {
})
}

/// Check whether every type in `data_types` is a struct and all structs have
/// the same (ordered) field names.
///
/// Returns `Ok(false)` as soon as a non-struct type is seen; returns an error
/// when two structs have differing field names.
fn are_all_struct_and_have_same_key(data_types: &[DataType]) -> Result<bool> {
    let mut expected_keys: Option<String> = None;
    for data_type in data_types {
        match data_type {
            DataType::Struct(fields) => {
                // Joining the field names keeps the comparison order-sensitive.
                let keys = fields.iter().map(|f| f.name().to_owned()).join(",");
                match expected_keys {
                    Some(ref k) if *k != keys => {
                        return exec_err!("Expect same keys for struct type but got mismatched pair {} and {}", *k, keys);
                    }
                    Some(_) => {}
                    None => expected_keys = Some(keys),
                }
            }
            _ => return Ok(false),
        }
    }

    Ok(true)
}

// Empty array is a special case that is useful for many other array functions
pub(super) fn empty_array_type() -> DataType {
DataType::List(Arc::new(Field::new("item", DataType::Int64, true)))
Expand Down
Loading
Loading