Skip to content

Commit

Permalink
fix: make get_valid_types handle TypeSignature::Numeric correctly (ap…
Browse files Browse the repository at this point in the history
…ache#14060)

* fix get_valid_types with TypeSignature::Numeric

* fix sqllogictest
  • Loading branch information
niebayes authored Jan 11, 2025
1 parent 334d6ec commit 4b7281e
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 3 deletions.
65 changes: 65 additions & 0 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,10 @@ fn get_valid_types(
// and their default type is double precision
if logical_data_type == NativeType::Null {
valid_type = DataType::Float64;
} else if !logical_data_type.is_numeric() {
return plan_err!(
"The signature expected NativeType::Numeric but received {logical_data_type}"
);
}

vec![vec![valid_type; *number]]
Expand Down Expand Up @@ -997,6 +1001,67 @@ mod tests {
}
}

#[test]
fn test_get_valid_types_numeric() -> Result<()> {
let get_valid_types_flatten =
|signature: &TypeSignature, current_types: &[DataType]| {
get_valid_types(signature, current_types)
.unwrap()
.into_iter()
.flatten()
.collect::<Vec<_>>()
};

// Trivial case.
let got = get_valid_types_flatten(&TypeSignature::Numeric(1), &[DataType::Int32]);
assert_eq!(got, [DataType::Int32]);

// Args are coerced into a common numeric type.
let got = get_valid_types_flatten(
&TypeSignature::Numeric(2),
&[DataType::Int32, DataType::Int64],
);
assert_eq!(got, [DataType::Int64, DataType::Int64]);

// Args are coerced into a common numeric type, specifically, int would be coerced to float.
let got = get_valid_types_flatten(
&TypeSignature::Numeric(3),
&[DataType::Int32, DataType::Int64, DataType::Float64],
);
assert_eq!(
got,
[DataType::Float64, DataType::Float64, DataType::Float64]
);

// Cannot coerce args to a common numeric type.
let got = get_valid_types(
&TypeSignature::Numeric(2),
&[DataType::Int32, DataType::Utf8],
)
.unwrap_err();
assert_contains!(
got.to_string(),
"The signature expected NativeType::Numeric but received NativeType::String"
);

// Fallbacks to float64 if the arg is of type null.
let got = get_valid_types_flatten(&TypeSignature::Numeric(1), &[DataType::Null]);
assert_eq!(got, [DataType::Float64]);

// Rejects non-numeric arg.
let got = get_valid_types(
&TypeSignature::Numeric(1),
&[DataType::Timestamp(TimeUnit::Second, None)],
)
.unwrap_err();
assert_contains!(
got.to_string(),
"The signature expected NativeType::Numeric but received NativeType::Timestamp(Second, None)"
);

Ok(())
}

#[test]
fn test_get_valid_types_one_of() -> Result<()> {
let signature =
Expand Down
6 changes: 3 additions & 3 deletions datafusion/sqllogictest/test_files/math.slt
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,15 @@ statement error
SELECT abs(1, 2);

# abs: unsupported argument type
query error This feature is not implemented: Unsupported data type Utf8 for function abs
query error DataFusion error: Error during planning: The signature expected NativeType::Numeric but received NativeType::String
SELECT abs('foo');

# abs: numeric string
# TODO: In Postgres, '-1.2' is unknown type and interpreted to float8 so they don't fail on this query
query error DataFusion error: This feature is not implemented: Unsupported data type Utf8 for function abs
query error DataFusion error: Error during planning: The signature expected NativeType::Numeric but received NativeType::String
select abs('-1.2');

query error DataFusion error: This feature is not implemented: Unsupported data type Utf8 for function abs
query error DataFusion error: Error during planning: The signature expected NativeType::Numeric but received NativeType::String
select abs(arrow_cast('-1.2', 'Utf8'));

statement ok
Expand Down

0 comments on commit 4b7281e

Please sign in to comment.