Skip to content

Commit

Permalink
introduce build_array_list_primitive
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed Dec 4, 2023
1 parent a73be00 commit 994c97b
Showing 1 changed file with 31 additions and 89 deletions.
120 changes: 31 additions & 89 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1368,6 +1368,28 @@ impl ScalarValue {
}};
}

fn build_array_list_primitive(
scalars: impl IntoIterator<Item = ScalarValue>,
) -> Result<ArrayRef> {
let arrays = scalars
.into_iter()
.map(|s| s.to_array())
.collect::<Result<Vec<_>>>()?;
let capacity = Capacities::Array(arrays.iter().map(|arr| arr.len()).sum());
let arrays_data = arrays.iter().map(|arr| arr.to_data()).collect::<Vec<_>>();

let arrays_ref = arrays_data.iter().collect::<Vec<_>>();
let mut mutable =
MutableArrayData::with_capacities(arrays_ref, false, capacity);

// ScalarValue::List contains a single element ListArray.
for i in 0..arrays.len() {
mutable.extend(i, 0, 1);
}
let data = mutable.freeze();
Ok(arrow_array::make_array(data))
}

macro_rules! build_array_list_primitive {
($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident, $LIST_TY:ident, $SCALAR_LIST:pat) => {{
Ok::<ArrayRef, DataFusionError>(Arc::new($LIST_TY::from_iter_primitive::<$ARRAY_TY, _, _>(
Expand Down Expand Up @@ -1541,95 +1563,15 @@ impl ScalarValue {
DataType::Interval(IntervalUnit::MonthDayNano) => {
build_array_primitive!(IntervalMonthDayNanoArray, IntervalMonthDayNano)
}
DataType::List(fields) if fields.data_type() == &DataType::Int8 => {
build_array_list_primitive!(
Int8Type,
Int8,
i8,
ListArray,
ScalarValue::List(_)
)?
}
DataType::List(fields) if fields.data_type() == &DataType::Int16 => {
build_array_list_primitive!(
Int16Type,
Int16,
i16,
ListArray,
ScalarValue::List(_)
)?
}
DataType::List(fields) if fields.data_type() == &DataType::Int32 => {
build_array_list_primitive!(
Int32Type,
Int32,
i32,
ListArray,
ScalarValue::List(_)
)?
}
DataType::List(fields) if fields.data_type() == &DataType::Int64 => {
build_array_list_primitive!(
Int64Type,
Int64,
i64,
ListArray,
ScalarValue::List(_)
)?
}
DataType::List(fields) if fields.data_type() == &DataType::UInt8 => {
build_array_list_primitive!(
UInt8Type,
UInt8,
u8,
ListArray,
ScalarValue::List(_)
)?
}
DataType::List(fields) if fields.data_type() == &DataType::UInt16 => {
build_array_list_primitive!(
UInt16Type,
UInt16,
u16,
ListArray,
ScalarValue::List(_)
)?
}
DataType::List(fields) if fields.data_type() == &DataType::UInt32 => {
build_array_list_primitive!(
UInt32Type,
UInt32,
u32,
ListArray,
ScalarValue::List(_)
)?
}
DataType::List(fields) if fields.data_type() == &DataType::UInt64 => {
build_array_list_primitive!(
UInt64Type,
UInt64,
u64,
ListArray,
ScalarValue::List(_)
)?
}
DataType::List(fields) if fields.data_type() == &DataType::Float32 => {
build_array_list_primitive!(
Float32Type,
Float32,
f32,
ListArray,
ScalarValue::List(_)
)?
}
DataType::List(fields) if fields.data_type() == &DataType::Float64 => {
build_array_list_primitive!(
Float64Type,
Float64,
f64,
ListArray,
ScalarValue::List(_)
)?
DataType::List(field)
if field.data_type().is_numeric()
&& !matches!(
field.data_type().to_owned(),
DataType::Decimal128(_, _) |
DataType::Decimal256(_, _)
) =>
{
build_array_list_primitive(scalars)?
}
DataType::List(fields) if fields.data_type() == &DataType::Utf8 => {
build_array_list_string!(
Expand Down

0 comments on commit 994c97b

Please sign in to comment.