Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer data type from schema for Values and add struct coercion to coalesce #12864

Merged
merged 16 commits into from
Oct 24, 2024
Merged
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion datafusion/common/src/dfschema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,6 @@ impl DFSchema {
None => self_unqualified_names.contains(field.name().as_str()),
};
if !duplicated_field {
// self.inner.fields.push(field.clone());
schema_builder.push(Arc::clone(field));
qualifiers.push(qualifier.cloned());
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@ path = "src/lib.rs"
[dependencies]
arrow = { workspace = true }
datafusion-common = { workspace = true }
itertools = { workspace = true }
paste = "^1.0"
91 changes: 90 additions & 1 deletion datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ use arrow::datatypes::{
DataType, Field, FieldRef, Fields, TimeUnit, DECIMAL128_MAX_PRECISION,
DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
};
use datafusion_common::{exec_datafusion_err, plan_datafusion_err, plan_err, Result};
use datafusion_common::{
exec_datafusion_err, exec_err, internal_err, plan_datafusion_err, plan_err,
DataFusionError, Result,
};
use itertools::Itertools;

/// The type signature of an instantiation of binary operator expression such as
/// `lhs + rhs`
Expand Down Expand Up @@ -372,6 +376,8 @@ impl From<&DataType> for TypeCategory {
/// decimal precision and scale when coercing decimal types.
///
/// This function doesn't preserve correct field name and nullability for the struct type, we only care about data type.
///
/// Returns Option because we might want to continue on the code even if the data types are not coercible to the common type
pub fn type_union_resolution(data_types: &[DataType]) -> Option<DataType> {
if data_types.is_empty() {
return None;
Expand Down Expand Up @@ -529,6 +535,89 @@ fn type_union_resolution_coercion(
}
}

pub fn try_type_union_resolution(data_types: &[DataType]) -> Result<Vec<DataType>> {
jayzhan211 marked this conversation as resolved.
Show resolved Hide resolved
let mut err = None;
match try_type_union_resolution_with_struct(data_types) {
Ok(struct_types) => return Ok(struct_types),
Err(e) => err = Some(e),
}

if let Some(new_type) = type_union_resolution(data_types) {
Ok(vec![new_type; data_types.len()])
} else {
exec_err!("Fail to find the coerced type, errors: {:?}", err)
}
}

// Handle struct where we only change the data type but preserve the field name and nullability.
// Since field name is the key of the struct, so it shouldn't be updated to the common column name like "c0" or "c1"
pub fn try_type_union_resolution_with_struct(
data_types: &[DataType],
) -> Result<Vec<DataType>> {
let mut keys_string: Option<String> = None;
for data_type in data_types {
if let DataType::Struct(fields) = data_type {
let keys = fields.iter().map(|f| f.name().to_owned()).join(",");
jayzhan211 marked this conversation as resolved.
Show resolved Hide resolved
if let Some(ref k) = keys_string {
if *k != keys {
return exec_err!("Expect same keys for struct type but got mismatched pair {} and {}", *k, keys);
}
} else {
keys_string = Some(keys);
}
} else {
return exec_err!("Expect to get struct but got {}", data_type);
}
}

let mut struct_types: Vec<DataType> = if let DataType::Struct(fields) = &data_types[0]
{
fields.iter().map(|f| f.data_type().to_owned()).collect()
} else {
return internal_err!("Struct type is checked is the previous function, so this should be unreachable");
};

for data_type in data_types.iter().skip(1) {
if let DataType::Struct(fields) = data_type {
let incoming_struct_types: Vec<DataType> =
fields.iter().map(|f| f.data_type().to_owned()).collect();
// The order of field is verified above
for (lhs_type, rhs_type) in
struct_types.iter_mut().zip(incoming_struct_types.iter())
{
if let Some(coerced_type) =
type_union_resolution_coercion(lhs_type, rhs_type)
{
*lhs_type = coerced_type;
} else {
return exec_err!(
"Fail to find the coerced type for {} and {}",
lhs_type,
rhs_type
);
}
}
} else {
return exec_err!("Expect to get struct but got {}", data_type);
}
}

let mut final_struct_types = vec![];
for s in data_types {
let mut new_fields = vec![];
if let DataType::Struct(fields) = s {
for (i, f) in fields.iter().enumerate() {
let field = Arc::unwrap_or_clone(Arc::clone(f))
.with_data_type(struct_types[i].to_owned());
new_fields.push(Arc::new(field));
}
}
final_struct_types.push(DataType::Struct(new_fields.into()))
}

Ok(final_struct_types)
}

/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a
/// comparison operation
///
Expand Down
72 changes: 61 additions & 11 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@ use crate::{

use super::dml::InsertOp;
use super::plan::ColumnUnnestList;
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
use datafusion_common::display::ToStringifiedPlan;
use datafusion_common::file_options::file_type::FileType;
use datafusion_common::{
get_target_functional_dependencies, internal_err, not_impl_err, plan_datafusion_err,
plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies,
Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions,
exec_err, get_target_functional_dependencies, internal_err, not_impl_err,
plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError,
FunctionalDependencies, Result, ScalarValue, TableReference, ToDFSchema,
UnnestOptions,
};
use datafusion_expr_common::type_coercion::binary::type_union_resolution;

Expand Down Expand Up @@ -168,7 +170,7 @@ impl LogicalPlanBuilder {
})))
}

/// Create a values list based relation, and the schema is inferred from data, consuming
/// Create a values list based relation, and the schema is inferred from data itself or table schema if provided, consuming
/// `value`. See the [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html)
/// documentation for more details.
///
Expand All @@ -177,7 +179,7 @@ impl LogicalPlanBuilder {
/// so it's usually better to override the default names with a table alias list.
///
/// If the values include params/binders such as $1, $2, $3, etc, then the `param_data_types` should be provided.
pub fn values(mut values: Vec<Vec<Expr>>) -> Result<Self> {
pub fn values(values: Vec<Vec<Expr>>, schema: Option<&DFSchemaRef>) -> Result<Self> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 This makes a lot of sense to me to pass in the schema too if known.

I think it might be nicer on users if we didn't make an API change and instead added a new API like

pub fn values_with_schema(values: Vec<Vec<Expr>>, schema: Option<&DFSchemaRef>) -> Result<Self> {
...
}

Even if we also deprecated values it would help users prepare for upgrade

if values.is_empty() {
return plan_err!("Values list cannot be empty");
}
Expand All @@ -196,16 +198,55 @@ impl LogicalPlanBuilder {
}
}

let empty_schema = DFSchema::empty();
// Check the type of value against the schema
if let Some(schema) = schema {
Self::infer_from_schema(values, schema)
} else {
// Infer from data itself
Self::infer_data(values)
}
}

fn infer_from_schema(values: Vec<Vec<Expr>>, schema: &DFSchema) -> Result<Self> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can give this a name to make it clear it is related to VALUES processing

Suggested change
fn infer_from_schema(values: Vec<Vec<Expr>>, schema: &DFSchema) -> Result<Self> {
fn infer_values_from_schema(values: Vec<Vec<Expr>>, schema: &DFSchema) -> Result<Self> {

let n_cols = values[0].len();
let mut field_types: Vec<DataType> = Vec::with_capacity(n_cols);
for j in 0..n_cols {
let field_type = schema.field(j).data_type();
for row in values.iter() {
let value = &row[j];
let data_type = value.get_type(schema)?;

if !data_type.equals_datatype(field_type) {
jayzhan211 marked this conversation as resolved.
Show resolved Hide resolved
if can_cast_types(&data_type, field_type) {
} else {
jayzhan211 marked this conversation as resolved.
Show resolved Hide resolved
return exec_err!(
"type mistmatch and can't cast to got {} and {}",
data_type,
field_type
);
}
}
}
field_types.push(field_type.to_owned());
}

Self::infer_inner(values, &field_types, schema)
}

fn infer_data(values: Vec<Vec<Expr>>) -> Result<Self> {
let n_cols = values[0].len();
let schema = DFSchema::empty();

let mut field_types: Vec<DataType> = Vec::with_capacity(n_cols);
for j in 0..n_cols {
let mut common_type: Option<DataType> = None;
for (i, row) in values.iter().enumerate() {
let value = &row[j];
let data_type = value.get_type(&empty_schema)?;
let data_type = value.get_type(&schema)?;
if data_type == DataType::Null {
continue;
}

if let Some(prev_type) = common_type {
// get common type of each column values.
let data_types = vec![prev_type.clone(), data_type.clone()];
Expand All @@ -221,14 +262,22 @@ impl LogicalPlanBuilder {
// since the code loop skips NULL
field_types.push(common_type.unwrap_or(DataType::Null));
}

Self::infer_inner(values, &field_types, &schema)
}

fn infer_inner(
mut values: Vec<Vec<Expr>>,
field_types: &[DataType],
schema: &DFSchema,
) -> Result<Self> {
// wrap cast if data type is not same as common type.
for row in &mut values {
for (j, field_type) in field_types.iter().enumerate() {
if let Expr::Literal(ScalarValue::Null) = row[j] {
row[j] = Expr::Literal(ScalarValue::try_from(field_type)?);
} else {
row[j] =
std::mem::take(&mut row[j]).cast_to(field_type, &empty_schema)?;
row[j] = std::mem::take(&mut row[j]).cast_to(field_type, schema)?;
}
}
}
Expand All @@ -243,6 +292,7 @@ impl LogicalPlanBuilder {
.collect::<Vec<_>>();
let dfschema = DFSchema::from_unqualified_fields(fields.into(), HashMap::new())?;
let schema = DFSchemaRef::new(dfschema);

Ok(Self::new(LogicalPlan::Values(Values { schema, values })))
}

Expand Down Expand Up @@ -2314,10 +2364,10 @@ mod tests {
fn test_union_after_join() -> Result<()> {
let values = vec![vec![lit(1)]];

let left = LogicalPlanBuilder::values(values.clone())?
let left = LogicalPlanBuilder::values(values.clone(), None)?
.alias("left")?
.build()?;
let right = LogicalPlanBuilder::values(values)?
let right = LogicalPlanBuilder::values(values, None)?
.alias("right")?
.build()?;

Expand Down
64 changes: 14 additions & 50 deletions datafusion/functions-nested/src/make_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ use arrow_array::{
use arrow_buffer::OffsetBuffer;
use arrow_schema::DataType::{LargeList, List, Null};
use arrow_schema::{DataType, Field};
use datafusion_common::{exec_err, internal_err};
use datafusion_common::{plan_err, utils::array_into_list_array_nullable, Result};
use datafusion_expr::binary::type_union_resolution;
use datafusion_expr::binary::{
try_type_union_resolution_with_struct, type_union_resolution,
};
use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
use datafusion_expr::TypeSignature;
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
use itertools::Itertools;

use crate::utils::make_scalar_function;

Expand Down Expand Up @@ -111,33 +111,16 @@ impl ScalarUDFImpl for MakeArray {
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if let Some(new_type) = type_union_resolution(arg_types) {
// TODO: Move the logic to type_union_resolution if this applies to other functions as well
// Handle struct where we only change the data type but preserve the field name and nullability.
// Since field name is the key of the struct, so it shouldn't be updated to the common column name like "c0" or "c1"
let is_struct_and_has_same_key = are_all_struct_and_have_same_key(arg_types)?;
if is_struct_and_has_same_key {
let data_types: Vec<_> = if let DataType::Struct(fields) = &arg_types[0] {
fields.iter().map(|f| f.data_type().to_owned()).collect()
} else {
return internal_err!("Struct type is checked is the previous function, so this should be unreachable");
};

let mut final_struct_types = vec![];
for s in arg_types {
let mut new_fields = vec![];
if let DataType::Struct(fields) = s {
for (i, f) in fields.iter().enumerate() {
let field = Arc::unwrap_or_clone(Arc::clone(f))
.with_data_type(data_types[i].to_owned());
new_fields.push(Arc::new(field));
}
}
final_struct_types.push(DataType::Struct(new_fields.into()))
}
return Ok(final_struct_types);
let mut errors = vec![];
match try_type_union_resolution_with_struct(arg_types) {
Ok(r) => return Ok(r),
Err(e) => {
errors.push(e);
}
}

if let Some(new_type) = type_union_resolution(arg_types) {
// TODO: Move FixedSizeList to List in type_union_resolution
if let DataType::FixedSizeList(field, _) = new_type {
Ok(vec![DataType::List(field); arg_types.len()])
} else if new_type.is_null() {
Expand All @@ -147,9 +130,10 @@ impl ScalarUDFImpl for MakeArray {
}
} else {
plan_err!(
"Fail to find the valid type between {:?} for {}",
"Fail to find the valid type between {:?} for {}, errors are {:?}",
arg_types,
self.name()
self.name(),
errors
)
}
}
Expand Down Expand Up @@ -188,26 +172,6 @@ fn get_make_array_doc() -> &'static Documentation {
})
}

/// Reports whether every entry of `data_types` is a struct.
///
/// Returns `Ok(false)` as soon as a non-struct entry is encountered and `Ok(true)` when
/// the loop finishes without finding one (including for an empty slice).
///
/// # Errors
///
/// Returns an execution error when a struct's comma-joined field names differ from
/// those of the first struct seen.
fn are_all_struct_and_have_same_key(data_types: &[DataType]) -> Result<bool> {
    let mut expected_keys: Option<String> = None;

    for data_type in data_types {
        let fields = match data_type {
            DataType::Struct(fields) => fields,
            _ => return Ok(false),
        };

        let current_keys = fields.iter().map(|f| f.name().to_owned()).join(",");
        match expected_keys.as_deref() {
            Some(expected) if expected != current_keys => {
                return exec_err!("Expect same keys for struct type but got mismatched pair {} and {}", expected, current_keys);
            }
            // Same keys as the first struct: nothing to record.
            Some(_) => {}
            // First struct seen: its keys become the reference for the rest.
            None => expected_keys = Some(current_keys),
        }
    }

    Ok(true)
}

// Empty array is a special case that is useful for many other array functions
pub(super) fn empty_array_type() -> DataType {
DataType::List(Arc::new(Field::new("item", DataType::Int64, true)))
Expand Down
7 changes: 3 additions & 4 deletions datafusion/functions/src/core/coalesce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use arrow::compute::kernels::zip::zip;
use arrow::compute::{and, is_not_null, is_null};
use arrow::datatypes::DataType;
use datafusion_common::{exec_err, ExprSchema, Result};
use datafusion_expr::binary::try_type_union_resolution;
use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL;
use datafusion_expr::type_coercion::binary::type_union_resolution;
use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use itertools::Itertools;
Expand Down Expand Up @@ -137,9 +137,8 @@ impl ScalarUDFImpl for CoalesceFunc {
if arg_types.is_empty() {
return exec_err!("coalesce must have at least one argument");
}
let new_type = type_union_resolution(arg_types)
.unwrap_or(arg_types.first().unwrap().clone());
Ok(vec![new_type; arg_types.len()])

try_type_union_resolution(arg_types)
}

fn documentation(&self) -> Option<&Documentation> {
Expand Down
Loading
Loading