Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Support for e notation using existing parse_decimal in string to decimal conversion #6905

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
24 changes: 11 additions & 13 deletions arrow-cast/src/cast/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use crate::cast::*;
use crate::parse::*;

/// A utility trait that provides checked conversions between
/// decimal types inspired by [`NumCast`]
Expand Down Expand Up @@ -230,6 +231,7 @@ where
)?))
}

#[allow(dead_code)]
Copy link
Contributor Author

@himadripal himadripal Dec 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clippy fails here because this function is no longer used, hence the added `#[allow(dead_code)]`. If required, we can remove the function and cover its existing tests with `parse_decimal` instead.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should remove this and port the tests, to ensure we aren't losing test coverage / accidentally changing behaviour

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

/// Parses given string to specified decimal native (i128/i256) based on given
/// scale. Returns an `Err` if it cannot parse given string.
pub(crate) fn parse_string_to_decimal_native<T: DecimalType>(
Expand Down Expand Up @@ -342,10 +344,9 @@ where
&'a S: StringArrayType<'a>,
{
if cast_options.safe {
let iter = from.iter().map(|v| {
v.and_then(|v| parse_string_to_decimal_native::<T>(v, scale as usize).ok())
.and_then(|v| T::is_valid_decimal_precision(v, precision).then_some(v))
});
let iter = from
.iter()
.map(|v| v.and_then(|v| parse_decimal::<T>(v, precision, scale).ok()));
// Benefit:
// 20% performance improvement
// Soundness:
Expand All @@ -359,15 +360,12 @@ where
.iter()
.map(|v| {
v.map(|v| {
parse_string_to_decimal_native::<T>(v, scale as usize)
.map_err(|_| {
ArrowError::CastError(format!(
"Cannot cast string '{}' to value of {:?} type",
v,
T::DATA_TYPE,
))
})
.and_then(|v| T::validate_decimal_precision(v, precision).map(|_| v))
parse_decimal::<T>(v, precision, scale).map_err(|_| {
ArrowError::CastError(format!(
"Cannot cast string '{}' to decimal type of precision {} and scale {}",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`T::DATA_TYPE` shows the default `Decimal128(38, 10)` or `Decimal256(76, ..)` in the error message, hiding the precision and scale that were actually provided for the cast.

v, precision, scale
))
})
})
.transpose()
})
Expand Down
39 changes: 31 additions & 8 deletions arrow-cast/src/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2501,6 +2501,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::parse::parse_decimal;
use arrow_buffer::{Buffer, IntervalDayTime, NullBuffer};
use chrono::NaiveDate;
use half::f16;
Expand Down Expand Up @@ -3843,6 +3844,22 @@ mod tests {
}
}
}
/// Casting a UTF-8 value written in e-notation (`"4e7"`) to `Decimal128(10, 2)`
/// with `safe: false` must yield the same native value as parsing the expanded
/// literal `"40000000"` with `parse_decimal` at the same precision and scale.
#[test]
fn test_cast_with_options_utf8_to_decimal() {
    let options = CastOptions {
        safe: false,
        format_options: FormatOptions::default(),
    };
    let input = StringArray::from(vec!["4e7"]);
    let target_type = DataType::Decimal128(10, 2);

    let casted = cast_with_options(&input, &target_type, &options).unwrap();
    let decimals = casted.as_any().downcast_ref::<Decimal128Array>();

    // Reference value: the same number spelled out without an exponent.
    let expected = parse_decimal::<Decimal128Type>("40000000", 10, 2);
    assert_eq!(decimals.unwrap().value(0), expected.unwrap());
}

#[test]
fn test_cast_utf8_to_bool() {
Expand Down Expand Up @@ -8832,16 +8849,16 @@ mod tests {
format_options: FormatOptions::default(),
};
let casted_err = cast_with_options(&array, &output_type, &option).unwrap_err();
assert!(casted_err
.to_string()
.contains("Cannot cast string '4.4.5' to value of Decimal128(38, 10) type"));
assert!(casted_err.to_string().contains(
"Cast error: Cannot cast string '4.4.5' to decimal type of precision 38 and scale 2"
));

let str_array = StringArray::from(vec![". 0.123"]);
let array = Arc::new(str_array) as ArrayRef;
let casted_err = cast_with_options(&array, &output_type, &option).unwrap_err();
assert!(casted_err
.to_string()
.contains("Cannot cast string '. 0.123' to value of Decimal128(38, 10) type"));
assert!(casted_err.to_string().contains(
"Cast error: Cannot cast string '. 0.123' to decimal type of precision 38 and scale 2"
));
}

fn test_cast_string_to_decimal128_overflow(overflow_array: ArrayRef) {
Expand Down Expand Up @@ -8885,7 +8902,10 @@ mod tests {
format_options: FormatOptions::default(),
},
);
assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal128 of precision 10. Max is 9999999999", err.unwrap_err().to_string());
assert_eq!(
"Cast error: Cannot cast string '1000' to decimal type of precision 10 and scale 8",
err.unwrap_err().to_string()
);
}

#[test]
Expand Down Expand Up @@ -8968,7 +8988,10 @@ mod tests {
format_options: FormatOptions::default(),
},
);
assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal256 of precision 10. Max is 9999999999", err.unwrap_err().to_string());
assert_eq!(
"Cast error: Cannot cast string '1000' to decimal type of precision 10 and scale 8",
err.unwrap_err().to_string()
);
}

#[test]
Expand Down
21 changes: 18 additions & 3 deletions arrow-cast/src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,7 @@ pub fn parse_decimal<T: DecimalType>(
let mut result = T::Native::usize_as(0);
let mut fractionals: i8 = 0;
let mut digits: u8 = 0;
let mut rounding_digit = -1; // to store digit after the scale for rounding
let base = T::Native::usize_as(10);

let bs = s.as_bytes();
Expand Down Expand Up @@ -871,6 +872,13 @@ pub fn parse_decimal<T: DecimalType>(
// Ignore leading zeros.
continue;
}
if fractionals == scale && scale != 0 && rounding_digit < 0 {
// Capture the rounding digit once
if rounding_digit < 0 {
rounding_digit = (b - b'0') as i8;
}
continue;
}
digits += 1;
result = result.mul_wrapping(base);
result = result.add_wrapping(T::Native::usize_as((b - b'0') as usize));
Expand Down Expand Up @@ -903,9 +911,10 @@ pub fn parse_decimal<T: DecimalType>(
)));
}
if fractionals == scale && scale != 0 {
// We have processed all the digits that we need. All that
// is left is to validate that the rest of the string contains
// valid digits.
// Capture the rounding digit once
if rounding_digit < 0 {
rounding_digit = (b - b'0') as i8;
}
continue;
}
fractionals += 1;
Expand Down Expand Up @@ -966,6 +975,10 @@ pub fn parse_decimal<T: DecimalType>(
"parse decimal overflow ({s})"
)));
}
//add one if >=5
if rounding_digit >= 5 {
result = result.add_wrapping(T::Native::usize_as(1));
}
}

Ok(if negative {
Expand Down Expand Up @@ -2547,6 +2560,8 @@ mod tests {
let e_notation_tests = [
("1.23e3", "1230.0", 2),
("5.6714e+2", "567.14", 4),
("4e+5", "400000", 4),
("4e7", "40000000", 2),
("5.6714e-2", "0.056714", 4),
("5.6714e-2", "0.056714", 3),
("5.6741214125e2", "567.41214125", 4),
Expand Down
Loading