Skip to content

Commit

Permalink
[arrow-cast] Support cast numeric to string view (alternate) (#6816)
Browse files Browse the repository at this point in the history
* [arrow-cast] Support cast numeric to string view

Signed-off-by: Tai Le Manh <[email protected]>

* fix test

---------

Signed-off-by: Tai Le Manh <[email protected]>
Co-authored-by: Tai Le Manh <[email protected]>
  • Loading branch information
alamb and tlm365 authored Dec 1, 2024
1 parent d7581ef commit 8a8c10d
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 107 deletions.
290 changes: 183 additions & 107 deletions arrow-cast/src/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Decimal128(_, _) | Decimal256(_, _), UInt8 | UInt16 | UInt32 | UInt64) |
// decimal to signed numeric
(Decimal128(_, _) | Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) => true,
// decimal to Utf8
(Decimal128(_, _) | Decimal256(_, _), Utf8 | LargeUtf8) => true,
// decimal to string
(Decimal128(_, _) | Decimal256(_, _), Utf8View | Utf8 | LargeUtf8) => true,
// string to decimal
(Utf8View | Utf8 | LargeUtf8, Decimal128(_, _) | Decimal256(_, _)) => true,
(Struct(from_fields), Struct(to_fields)) => {
Expand Down Expand Up @@ -232,6 +232,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(BinaryView, Binary | LargeBinary | Utf8 | LargeUtf8 | Utf8View ) => true,
(Utf8View | Utf8 | LargeUtf8, _) => to_type.is_numeric() && to_type != &Float16,
(_, Utf8 | LargeUtf8) => from_type.is_primitive(),
(_, Utf8View) => from_type.is_numeric(),

(_, Binary | LargeBinary) => from_type.is_integer(),

Expand Down Expand Up @@ -917,6 +918,7 @@ pub fn cast_with_options(
Float64 => cast_decimal_to_float::<Decimal128Type, Float64Type, _>(array, |x| {
x as f64 / 10_f64.powi(*scale as i32)
}),
Utf8View => value_to_string_view(array, cast_options),
Utf8 => value_to_string::<i32>(array, cast_options),
LargeUtf8 => value_to_string::<i64>(array, cast_options),
Null => Ok(new_null_array(to_type, array.len())),
Expand Down Expand Up @@ -982,6 +984,7 @@ pub fn cast_with_options(
Float64 => cast_decimal_to_float::<Decimal256Type, Float64Type, _>(array, |x| {
x.to_f64().unwrap() / 10_f64.powi(*scale as i32)
}),
Utf8View => value_to_string_view(array, cast_options),
Utf8 => value_to_string::<i32>(array, cast_options),
LargeUtf8 => value_to_string::<i64>(array, cast_options),
Null => Ok(new_null_array(to_type, array.len())),
Expand Down Expand Up @@ -1462,6 +1465,9 @@ pub fn cast_with_options(
(BinaryView, _) => Err(ArrowError::CastError(format!(
"Casting from {from_type:?} to {to_type:?} not supported",
))),
(from_type, Utf8View) if from_type.is_numeric() => {
value_to_string_view(array, cast_options)
}
(from_type, LargeUtf8) if from_type.is_primitive() => {
value_to_string::<i64>(array, cast_options)
}
Expand Down Expand Up @@ -3707,6 +3713,55 @@ mod tests {
assert_eq!(10.0, c.value(3));
}

#[test]
fn test_cast_int_to_utf8view() {
let inputs = vec![
Arc::new(Int8Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef,
Arc::new(Int16Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef,
Arc::new(Int32Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef,
Arc::new(Int64Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef,
Arc::new(UInt8Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef,
Arc::new(UInt16Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef,
Arc::new(UInt32Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef,
Arc::new(UInt64Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef,
];
let expected: ArrayRef = Arc::new(StringViewArray::from(vec![
None,
Some("8"),
Some("9"),
Some("10"),
]));

for array in inputs {
assert!(can_cast_types(array.data_type(), &DataType::Utf8View));
let arr = cast(&array, &DataType::Utf8View).unwrap();
assert_eq!(expected.as_ref(), arr.as_ref());
}
}

#[test]
fn test_cast_float_to_utf8view() {
let inputs = vec![
Arc::new(Float16Array::from(vec![
Some(f16::from_f64(1.5)),
Some(f16::from_f64(2.5)),
None,
])) as ArrayRef,
Arc::new(Float32Array::from(vec![Some(1.5), Some(2.5), None])) as ArrayRef,
Arc::new(Float64Array::from(vec![Some(1.5), Some(2.5), None])) as ArrayRef,
];

let expected: ArrayRef =
Arc::new(StringViewArray::from(vec![Some("1.5"), Some("2.5"), None]));

for array in inputs {
println!("type: {}", array.data_type());
assert!(can_cast_types(array.data_type(), &DataType::Utf8View));
let arr = cast(&array, &DataType::Utf8View).unwrap();
assert_eq!(expected.as_ref(), arr.as_ref());
}
}

#[test]
fn test_cast_utf8_to_i32() {
let array = StringArray::from(vec!["5", "6", "seven", "8", "9.1"]);
Expand Down Expand Up @@ -5178,41 +5233,46 @@ mod tests {
assert_eq!("2018-12-25T00:00:00", c.value(1));
}

// Cast Timestamp to Utf8View is not supported yet
// TODO: Implement casting from Timestamp to Utf8View
// https://github.com/apache/arrow-rs/issues/6734
macro_rules! assert_cast_timestamp_to_string {
($array:expr, $datatype:expr, $output_array_type: ty, $expected:expr) => {{
let out = cast(&$array, &$datatype).unwrap();
let actual = out
.as_any()
.downcast_ref::<$output_array_type>()
.unwrap()
.into_iter()
.collect::<Vec<_>>();
assert_eq!(actual, $expected);
}};
($array:expr, $datatype:expr, $output_array_type: ty, $options:expr, $expected:expr) => {{
let out = cast_with_options(&$array, &$datatype, &$options).unwrap();
let actual = out
.as_any()
.downcast_ref::<$output_array_type>()
.unwrap()
.into_iter()
.collect::<Vec<_>>();
assert_eq!(actual, $expected);
}};
}

#[test]
fn test_cast_timestamp_to_strings() {
// "2018-12-25T00:00:02.001", "1997-05-19T00:00:03.005", None
let array =
TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), None]);
let out = cast(&array, &DataType::Utf8).unwrap();
let out = out
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.into_iter()
.collect::<Vec<_>>();
assert_eq!(
out,
vec![
Some("1997-05-19T00:00:03.005"),
Some("2018-12-25T00:00:02.001"),
None
]
);
let out = cast(&array, &DataType::LargeUtf8).unwrap();
let out = out
.as_any()
.downcast_ref::<LargeStringArray>()
.unwrap()
.into_iter()
.collect::<Vec<_>>();
assert_eq!(
out,
vec![
Some("1997-05-19T00:00:03.005"),
Some("2018-12-25T00:00:02.001"),
None
]
);
let expected = vec![
Some("1997-05-19T00:00:03.005"),
Some("2018-12-25T00:00:02.001"),
None,
];

// assert_cast_timestamp_to_string!(array, DataType::Utf8View, StringViewArray, expected);
assert_cast_timestamp_to_string!(array, DataType::Utf8, StringArray, expected);
assert_cast_timestamp_to_string!(array, DataType::LargeUtf8, LargeStringArray, expected);
}

#[test]
Expand All @@ -5225,73 +5285,53 @@ mod tests {
.with_timestamp_format(Some(ts_format))
.with_timestamp_tz_format(Some(ts_format)),
};

// "2018-12-25T00:00:02.001", "1997-05-19T00:00:03.005", None
let array_without_tz =
TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), None]);
let out = cast_with_options(&array_without_tz, &DataType::Utf8, &cast_options).unwrap();
let out = out
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.into_iter()
.collect::<Vec<_>>();
assert_eq!(
out,
vec![
Some("1997-05-19 00:00:03.005000"),
Some("2018-12-25 00:00:02.001000"),
None
]
let expected = vec![
Some("1997-05-19 00:00:03.005000"),
Some("2018-12-25 00:00:02.001000"),
None,
];
// assert_cast_timestamp_to_string!(array_without_tz, DataType::Utf8View, StringViewArray, cast_options, expected);
assert_cast_timestamp_to_string!(
array_without_tz,
DataType::Utf8,
StringArray,
cast_options,
expected
);
let out =
cast_with_options(&array_without_tz, &DataType::LargeUtf8, &cast_options).unwrap();
let out = out
.as_any()
.downcast_ref::<LargeStringArray>()
.unwrap()
.into_iter()
.collect::<Vec<_>>();
assert_eq!(
out,
vec![
Some("1997-05-19 00:00:03.005000"),
Some("2018-12-25 00:00:02.001000"),
None
]
assert_cast_timestamp_to_string!(
array_without_tz,
DataType::LargeUtf8,
LargeStringArray,
cast_options,
expected
);

let array_with_tz =
TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), None])
.with_timezone(tz.to_string());
let out = cast_with_options(&array_with_tz, &DataType::Utf8, &cast_options).unwrap();
let out = out
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.into_iter()
.collect::<Vec<_>>();
assert_eq!(
out,
vec![
Some("1997-05-19 05:45:03.005000"),
Some("2018-12-25 05:45:02.001000"),
None
]
let expected = vec![
Some("1997-05-19 05:45:03.005000"),
Some("2018-12-25 05:45:02.001000"),
None,
];
// assert_cast_timestamp_to_string!(array_with_tz, DataType::Utf8View, StringViewArray, cast_options, expected);
assert_cast_timestamp_to_string!(
array_with_tz,
DataType::Utf8,
StringArray,
cast_options,
expected
);
let out = cast_with_options(&array_with_tz, &DataType::LargeUtf8, &cast_options).unwrap();
let out = out
.as_any()
.downcast_ref::<LargeStringArray>()
.unwrap()
.into_iter()
.collect::<Vec<_>>();
assert_eq!(
out,
vec![
Some("1997-05-19 05:45:03.005000"),
Some("2018-12-25 05:45:02.001000"),
None
]
assert_cast_timestamp_to_string!(
array_with_tz,
DataType::LargeUtf8,
LargeStringArray,
cast_options,
expected
);
}

Expand Down Expand Up @@ -9146,26 +9186,51 @@ mod tests {
}

#[test]
fn test_cast_decimal_to_utf8() {
fn test_cast_decimal_to_string() {
assert!(can_cast_types(
&DataType::Decimal128(10, 4),
&DataType::Utf8View
));
assert!(can_cast_types(
&DataType::Decimal256(38, 10),
&DataType::Utf8View
));

macro_rules! assert_decimal_values {
($array:expr) => {
let c = $array;
assert_eq!("1123.454", c.value(0));
assert_eq!("2123.456", c.value(1));
assert_eq!("-3123.453", c.value(2));
assert_eq!("-3123.456", c.value(3));
assert_eq!("0.000", c.value(4));
assert_eq!("0.123", c.value(5));
assert_eq!("1234.567", c.value(6));
assert_eq!("-1234.567", c.value(7));
assert!(c.is_null(8));
};
}

fn test_decimal_to_string<IN: ArrowPrimitiveType, OffsetSize: OffsetSizeTrait>(
output_type: DataType,
array: PrimitiveArray<IN>,
) {
let b = cast(&array, &output_type).unwrap();

assert_eq!(b.data_type(), &output_type);
let c = b.as_string::<OffsetSize>();

assert_eq!("1123.454", c.value(0));
assert_eq!("2123.456", c.value(1));
assert_eq!("-3123.453", c.value(2));
assert_eq!("-3123.456", c.value(3));
assert_eq!("0.000", c.value(4));
assert_eq!("0.123", c.value(5));
assert_eq!("1234.567", c.value(6));
assert_eq!("-1234.567", c.value(7));
assert!(c.is_null(8));
match b.data_type() {
DataType::Utf8View => {
let c = b.as_string_view();
assert_decimal_values!(c);
}
DataType::Utf8 | DataType::LargeUtf8 => {
let c = b.as_string::<OffsetSize>();
assert_decimal_values!(c);
}
_ => (),
}
}

let array128: Vec<Option<i128>> = vec![
Some(1123454),
Some(2123456),
Expand All @@ -9177,22 +9242,33 @@ mod tests {
Some(-123456789),
None,
];
let array256: Vec<Option<i256>> = array128
.iter()
.map(|num| num.map(i256::from_i128))
.collect();

let array256: Vec<Option<i256>> = array128.iter().map(|v| v.map(i256::from_i128)).collect();

test_decimal_to_string::<arrow_array::types::Decimal128Type, i32>(
test_decimal_to_string::<Decimal128Type, i32>(
DataType::Utf8View,
create_decimal_array(array128.clone(), 7, 3).unwrap(),
);
test_decimal_to_string::<Decimal128Type, i32>(
DataType::Utf8,
create_decimal_array(array128.clone(), 7, 3).unwrap(),
);
test_decimal_to_string::<arrow_array::types::Decimal128Type, i64>(
test_decimal_to_string::<Decimal128Type, i64>(
DataType::LargeUtf8,
create_decimal_array(array128, 7, 3).unwrap(),
);
test_decimal_to_string::<arrow_array::types::Decimal256Type, i32>(

test_decimal_to_string::<Decimal256Type, i32>(
DataType::Utf8View,
create_decimal256_array(array256.clone(), 7, 3).unwrap(),
);
test_decimal_to_string::<Decimal256Type, i32>(
DataType::Utf8,
create_decimal256_array(array256.clone(), 7, 3).unwrap(),
);
test_decimal_to_string::<arrow_array::types::Decimal256Type, i64>(
test_decimal_to_string::<Decimal256Type, i64>(
DataType::LargeUtf8,
create_decimal256_array(array256, 7, 3).unwrap(),
);
Expand Down
Loading

0 comments on commit 8a8c10d

Please sign in to comment.