Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add support for Decimal128 and Decimal256 types in interval arithmetic #14126

Merged
merged 8 commits into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 84 additions & 4 deletions datafusion/expr-common/src/interval_arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ macro_rules! get_extreme_value {
DataType::Interval(IntervalUnit::MonthDayNano) => {
ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano::$extreme))
}
DataType::Decimal128(precision, scale) => {
ScalarValue::Decimal128(Some(i128::$extreme), *precision, *scale)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this correct for min/max of i128? It seems like the minimum value of Decimal128(precision, scale) would be the minimum value for the precision and scale separately rather than the min value of the overall i128 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh, that makes sense. These huge arrays remind me it should be all of 9s...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sadly the decimal 256 version is not public 😢 let me find a workaround

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe get_extreme_value should be part of arrow?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sadly the decimal 256 version is not public 😢 let me find a workaround

I recommend

  1. copy/pasting the value into DataFusion for now
  2. File a ticket in arrow-rs to make it public (or maybe even a PR!)
  3. leave a comment next to the copy in DataFUsion referencing the upstream ticket

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I filed a PR in arrow to expose these constants:

}
DataType::Decimal256(precision, scale) => ScalarValue::Decimal256(
Some(arrow::datatypes::i256::$extreme),
*precision,
*scale,
),
_ => unreachable!(),
}
};
Expand Down Expand Up @@ -1008,17 +1016,20 @@ fn handle_overflow<const UPPER: bool>(
lhs: &ScalarValue,
rhs: &ScalarValue,
) -> ScalarValue {
let zero = ScalarValue::new_zero(dt).unwrap();
let lhs_zero = ScalarValue::new_zero(&lhs.data_type()).unwrap();
let rhs_zero = ScalarValue::new_zero(&rhs.data_type()).unwrap();
let positive_sign = match op {
Operator::Multiply | Operator::Divide => {
lhs.lt(&zero) && rhs.lt(&zero) || lhs.gt(&zero) && rhs.gt(&zero)
lhs.lt(&lhs_zero) && rhs.lt(&rhs_zero)
|| lhs.gt(&lhs_zero) && rhs.gt(&rhs_zero)
}
Operator::Plus => lhs.ge(&zero),
Operator::Plus => lhs.ge(&lhs_zero),
Operator::Minus => lhs.ge(rhs),
_ => {
unreachable!()
}
};

match (UPPER, positive_sign) {
(true, true) | (false, false) => ScalarValue::try_from(dt).unwrap(),
(true, false) => {
Expand Down Expand Up @@ -1832,7 +1843,12 @@ impl NullableInterval {

#[cfg(test)]
mod tests {
use crate::interval_arithmetic::{next_value, prev_value, satisfy_greater, Interval};
use crate::{
interval_arithmetic::{
handle_overflow, next_value, prev_value, satisfy_greater, Interval,
},
operator::Operator,
};

use arrow::datatypes::DataType;
use datafusion_common::{Result, ScalarValue};
Expand Down Expand Up @@ -3108,6 +3124,70 @@ mod tests {
Ok(())
}

#[test]
fn test_overflow_handling() -> Result<()> {
// Test integer overflow handling:
let dt = DataType::Int32;
let op = Operator::Plus;
let lhs = ScalarValue::Int32(Some(i32::MAX));
let rhs = ScalarValue::Int32(Some(1));
let result = handle_overflow::<true>(&dt, op, &lhs, &rhs);
assert_eq!(result, ScalarValue::Int32(None));
let result = handle_overflow::<false>(&dt, op, &lhs, &rhs);
assert_eq!(result, ScalarValue::Int32(Some(i32::MAX)));

// Test float overflow handling:
let dt = DataType::Float32;
let op = Operator::Multiply;
let lhs = ScalarValue::Float32(Some(f32::MAX));
let rhs = ScalarValue::Float32(Some(2.0));
let result = handle_overflow::<true>(&dt, op, &lhs, &rhs);
assert_eq!(result, ScalarValue::Float32(None));
let result = handle_overflow::<false>(&dt, op, &lhs, &rhs);
assert_eq!(result, ScalarValue::Float32(Some(f32::MAX)));

// Test float underflow handling:
let lhs = ScalarValue::Float32(Some(f32::MIN));
let rhs = ScalarValue::Float32(Some(2.0));
let result = handle_overflow::<true>(&dt, op, &lhs, &rhs);
assert_eq!(result, ScalarValue::Float32(Some(f32::MIN)));
let result = handle_overflow::<false>(&dt, op, &lhs, &rhs);
assert_eq!(result, ScalarValue::Float32(None));

// Test integer underflow handling:
let dt = DataType::Int64;
let op = Operator::Minus;
let lhs = ScalarValue::Int64(Some(i64::MIN));
let rhs = ScalarValue::Int64(Some(1));
let result = handle_overflow::<true>(&dt, op, &lhs, &rhs);
assert_eq!(result, ScalarValue::Int64(Some(i64::MIN)));
let result = handle_overflow::<false>(&dt, op, &lhs, &rhs);
assert_eq!(result, ScalarValue::Int64(None));

// Test unsigned integer handling:
let dt = DataType::UInt32;
let op = Operator::Minus;
let lhs = ScalarValue::UInt32(Some(0));
let rhs = ScalarValue::UInt32(Some(1));
let result = handle_overflow::<true>(&dt, op, &lhs, &rhs);
assert_eq!(result, ScalarValue::UInt32(Some(0)));
let result = handle_overflow::<false>(&dt, op, &lhs, &rhs);
assert_eq!(result, ScalarValue::UInt32(None));

// Test decimal handling:
let dt = DataType::Decimal128(38, 35);
let op = Operator::Plus;
let lhs =
ScalarValue::Decimal128(Some(54321543215432154321543215432154321), 35, 35);
let rhs = ScalarValue::Decimal128(Some(10000), 20, 0);
let result = handle_overflow::<true>(&dt, op, &lhs, &rhs);
assert_eq!(result, ScalarValue::Decimal128(None, 38, 35));
let result = handle_overflow::<false>(&dt, op, &lhs, &rhs);
assert_eq!(result, ScalarValue::Decimal128(Some(i128::MAX), 38, 35));

Ok(())
}

#[test]
fn test_cardinality_of_intervals() -> Result<()> {
// In IEEE 754 standard for floating-point arithmetic, if we keep the sign and exponent fields same,
Expand Down
4 changes: 4 additions & 0 deletions datafusion/sqllogictest/test_files/select.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1836,3 +1836,7 @@ DROP TABLE test;
# Can't reference an unqualified column by a qualified name
query error DataFusion error: Schema error: No field named t1\.v1\. Valid fields are "t1\.v1"\.
SELECT t1.v1 FROM (SELECT 1 AS "t1.v1");

# Test issue: https://github.com/apache/datafusion/issues/14124
query error DataFusion error: Arrow error: Arithmetic overflow: Overflow happened on: 10000 \* 100000000000000000000000000000000000
SELECT ('0.54321543215432154321543215432154321'::DECIMAL(35,35) + 10000)::VARCHAR
Loading