diff --git a/vortex-array/src/array/sparse/compute.rs b/vortex-array/src/array/sparse/compute.rs index 37d102ec02..12c94f1bfb 100644 --- a/vortex-array/src/array/sparse/compute.rs +++ b/vortex-array/src/array/sparse/compute.rs @@ -1,13 +1,13 @@ use arrow_buffer::BooleanBufferBuilder; use itertools::Itertools; -use vortex_error::{vortex_bail, vortex_err, VortexResult}; +use vortex_error::{vortex_bail, VortexResult}; use crate::array::downcast::DowncastArrayBuiltin; use crate::array::primitive::PrimitiveArray; use crate::array::sparse::SparseArray; use crate::array::{Array, ArrayRef}; use crate::compute::as_contiguous::{as_contiguous, AsContiguousFn}; -use crate::compute::flatten::{flatten, FlattenFn, FlattenedArray}; +use crate::compute::flatten::{flatten_primitive, FlattenFn, FlattenedArray}; use crate::compute::scalar_at::{scalar_at, ScalarAtFn}; use crate::compute::ArrayCompute; use crate::match_each_native_ptype; @@ -67,54 +67,45 @@ impl FlattenFn for SparseArray { let mut validity = BooleanBufferBuilder::new(self.len()); validity.append_n(self.len(), false); - let values = flatten(self.values())?; - let null_fill = self.fill_value().is_null(); - if let FlattenedArray::Primitive(ref parray) = values { - match_each_native_ptype!(parray.ptype(), |$P| { - flatten_primitive::<$P>( - self, - parray, - indices, - null_fill, - validity - ) - }) - } else { - Err(vortex_err!( - "Cannot flatten SparseArray with non-primitive values" - )) - } + let values = flatten_primitive(self.values())?; + match_each_native_ptype!(values.ptype(), |$P| { + flatten_sparse_values( + values.typed_data::<$P>(), + &indices, + self.len(), + self.fill_value(), + validity + ) + }) } } -fn flatten_primitive( - sparse_array: &SparseArray, - parray: &PrimitiveArray, - indices: Vec, - null_fill: bool, + +fn flatten_sparse_values( + values: &[T], + indices: &[usize], + len: usize, + fill_value: &Scalar, mut validity: BooleanBufferBuilder, ) -> VortexResult { - let fill_value = if null_fill { + let primitive_fill = if fill_value.is_null() { T::default() } else { - sparse_array.fill_value.clone().try_into()? + fill_value.try_into()? }; - let mut values = vec![fill_value; sparse_array.len()]; + let mut result = vec![primitive_fill; len]; - for (offset, v) in parray.typed_data::().iter().enumerate() { - let idx = indices[offset]; - values[idx] = *v; - validity.set_bit(idx, true); + for (v, idx) in values.iter().zip_eq(indices) { + result[*idx] = *v; + validity.set_bit(*idx, true); } let validity = validity.finish(); - if null_fill { - Ok(FlattenedArray::Primitive(PrimitiveArray::from_nullable( - values, - Some(validity.into()), - ))) + let array = if fill_value.is_null() { + PrimitiveArray::from_nullable(result, Some(validity.into())) } else { - Ok(FlattenedArray::Primitive(PrimitiveArray::from(values))) - } + PrimitiveArray::from(result) + }; + Ok(FlattenedArray::Primitive(array)) } impl ScalarAtFn for SparseArray {