From df77be3e2c1539fd57d77a9a394a7d7209b8e7ad Mon Sep 17 00:00:00 2001 From: Dan King Date: Tue, 1 Oct 2024 13:44:03 -0400 Subject: [PATCH] feat: SparseArray uses ScalarValue instead of Scalar (#955) --- encodings/alp/src/compress.rs | 4 +- .../fastlanes/src/bitpacking/compress.rs | 4 +- .../src/bitpacking/compute/scalar_at.rs | 4 +- encodings/fastlanes/src/for/compress.rs | 4 +- encodings/runend/src/compute.rs | 4 +- vortex-array/src/array/sparse/compute/mod.rs | 25 ++++----- .../src/array/sparse/compute/slice.rs | 25 +++------ vortex-array/src/array/sparse/compute/take.rs | 5 +- vortex-array/src/array/sparse/flatten.rs | 8 +-- vortex-array/src/array/sparse/mod.rs | 51 ++++++++----------- vortex-array/src/array/sparse/variants.rs | 8 +-- vortex-array/src/canonical.rs | 4 +- vortex-scalar/src/bool.rs | 18 ++++++- vortex-scalar/src/primitive.rs | 42 +++++++++++---- 14 files changed, 108 insertions(+), 98 deletions(-) diff --git a/encodings/alp/src/compress.rs b/encodings/alp/src/compress.rs index 6c4dfeba7a..12e8388fa9 100644 --- a/encodings/alp/src/compress.rs +++ b/encodings/alp/src/compress.rs @@ -3,7 +3,7 @@ use vortex::validity::Validity; use vortex::{Array, ArrayDType, ArrayDef, IntoArray, IntoArrayVariant}; use vortex_dtype::{NativePType, PType}; use vortex_error::{vortex_bail, VortexExpect as _, VortexResult}; -use vortex_scalar::Scalar; +use vortex_scalar::ScalarValue; use crate::alp::ALPFloat; use crate::array::ALPArray; @@ -42,7 +42,7 @@ where PrimitiveArray::from(exc_pos).into_array(), PrimitiveArray::from_vec(exc, Validity::AllValid).into_array(), len, - Scalar::null(values.dtype().as_nullable()), + ScalarValue::Null, ) .vortex_expect("Failed to create SparseArray for ALP patches") .into_array() diff --git a/encodings/fastlanes/src/bitpacking/compress.rs b/encodings/fastlanes/src/bitpacking/compress.rs index 62c1882e73..3d10335bce 100644 --- a/encodings/fastlanes/src/bitpacking/compress.rs +++ b/encodings/fastlanes/src/bitpacking/compress.rs @@ -9,7 +9,7 @@ use vortex_dtype::{ match_each_integer_ptype, match_each_unsigned_integer_ptype, NativePType, PType, }; use vortex_error::{vortex_bail, vortex_err, VortexResult, VortexUnwrap}; -use vortex_scalar::Scalar; +use vortex_scalar::{Scalar, ScalarValue}; use crate::BitPackedArray; @@ -131,7 +131,7 @@ pub fn bitpack_patches( indices.into_array(), PrimitiveArray::from_vec(values, Validity::AllValid).into_array(), parray.len(), - Scalar::null(parray.dtype().as_nullable()), + ScalarValue::Null, ) .vortex_unwrap() .into_array() diff --git a/encodings/fastlanes/src/bitpacking/compute/scalar_at.rs b/encodings/fastlanes/src/bitpacking/compute/scalar_at.rs index 48bedc9c13..08d3e3cdc8 100644 --- a/encodings/fastlanes/src/bitpacking/compute/scalar_at.rs +++ b/encodings/fastlanes/src/bitpacking/compute/scalar_at.rs @@ -30,7 +30,7 @@ mod test { use vortex::IntoArray; use vortex_buffer::Buffer; use vortex_dtype::{DType, Nullability, PType}; - use vortex_scalar::Scalar; + use vortex_scalar::{Scalar, ScalarValue}; use crate::BitPackedArray; @@ -45,7 +45,7 @@ mod test { PrimitiveArray::from(vec![1u64]).into_array(), PrimitiveArray::from_vec(vec![999u32], Validity::AllValid).into_array(), 8, - Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)), + ScalarValue::Null, ) .unwrap() .into_array(), diff --git a/encodings/fastlanes/src/for/compress.rs b/encodings/fastlanes/src/for/compress.rs index d01837d0f6..a36482b53b 100644 --- a/encodings/fastlanes/src/for/compress.rs +++ b/encodings/fastlanes/src/for/compress.rs @@ -6,7 +6,7 @@ use vortex::validity::LogicalValidity; use vortex::{Array, ArrayDType, IntoArray, IntoArrayVariant}; use vortex_dtype::{match_each_integer_ptype, NativePType}; use vortex_error::{vortex_err, VortexResult}; -use vortex_scalar::Scalar; +use vortex_scalar::{Scalar, ScalarValue}; use crate::FoRArray; @@ -41,7 +41,7 @@ pub fn for_compress(array: &PrimitiveArray) -> VortexResult { ConstantArray::new(Scalar::zero::<$T>(array.dtype().nullability()), valid_len) .into_array(), array.len(), - Scalar::null(array.dtype().clone()), + ScalarValue::Null, )? .into_array() } diff --git a/encodings/runend/src/compute.rs b/encodings/runend/src/compute.rs index 96686a18c2..e5cba2eea4 100644 --- a/encodings/runend/src/compute.rs +++ b/encodings/runend/src/compute.rs @@ -5,7 +5,7 @@ use vortex::validity::Validity; use vortex::{Array, ArrayDType, IntoArray, IntoArrayVariant}; use vortex_dtype::match_each_integer_ptype; use vortex_error::{VortexExpect as _, VortexResult}; -use vortex_scalar::Scalar; +use vortex_scalar::{Scalar, ScalarValue}; use crate::RunEndArray; @@ -86,7 +86,7 @@ impl TakeFn for RunEndArray { dense_nonnull_indices, filtered_values, length, - Scalar::null(self.dtype().clone()), + ScalarValue::Null, )? .into_array() } diff --git a/vortex-array/src/array/sparse/compute/mod.rs b/vortex-array/src/array/sparse/compute/mod.rs index 456ed8d380..a6d96b9855 100644 --- a/vortex-array/src/array/sparse/compute/mod.rs +++ b/vortex-array/src/array/sparse/compute/mod.rs @@ -9,7 +9,7 @@ use crate::compute::{ search_sorted, take, ArrayCompute, FilterFn, SearchResult, SearchSortedFn, SearchSortedSide, SliceFn, TakeFn, }; -use crate::{Array, ArrayDType, IntoArray, IntoArrayVariant}; +use crate::{Array, IntoArray, IntoArrayVariant}; mod slice; mod take; @@ -38,18 +38,16 @@ impl ArrayCompute for SparseArray { impl ScalarAtFn for SparseArray { fn scalar_at(&self, index: usize) -> VortexResult { - match self.search_index(index)?.to_found() { - None => self.fill_value().clone().cast(self.dtype()), - Some(idx) => scalar_at_unchecked(&self.values(), idx).cast(self.dtype()), - } + Ok(match self.search_index(index)?.to_found() { + None => self.fill_scalar(), + Some(idx) => scalar_at_unchecked(&self.values(), idx), + }) } fn scalar_at_unchecked(&self, index: usize) -> Scalar { match self.search_index(index).vortex_unwrap().to_found() { - None => self.fill_value().clone().cast(self.dtype()).vortex_unwrap(), - Some(idx) => scalar_at_unchecked(&self.values(), idx) - .cast(self.dtype()) - .vortex_unwrap(), + None => self.fill_scalar(), + Some(idx) => scalar_at_unchecked(&self.values(), idx), } } } @@ -115,8 +113,7 @@ impl FilterFn for SparseArray { #[cfg(test)] mod test { use rstest::{fixture, rstest}; - use vortex_dtype::{DType, Nullability, PType}; - use vortex_scalar::Scalar; + use vortex_scalar::ScalarValue; use crate::array::primitive::PrimitiveArray; use crate::array::sparse::SparseArray; @@ -131,7 +128,7 @@ mod test { PrimitiveArray::from(vec![2u64, 9, 15]).into_array(), PrimitiveArray::from_vec(vec![33_i32, 44, 55], Validity::AllValid).into_array(), 20, - Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), + ScalarValue::Null, ) .unwrap() .into_array() @@ -176,7 +173,7 @@ mod test { PrimitiveArray::from(vec![0u64]).into_array(), PrimitiveArray::from_vec(vec![0u8], Validity::AllValid).into_array(), 2, - Scalar::null(DType::Primitive(PType::U8, Nullability::Nullable)), + ScalarValue::Null, ) .unwrap() .into_array(); @@ -216,7 +213,7 @@ mod test { PrimitiveArray::from(vec![0_u64, 3, 6]).into_array(), PrimitiveArray::from_vec(vec![33_i32, 44, 55], Validity::AllValid).into_array(), 7, - Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), + ScalarValue::Null, ) .unwrap() .into_array(); diff --git a/vortex-array/src/array/sparse/compute/slice.rs b/vortex-array/src/array/sparse/compute/slice.rs index 799291d49c..97bb7f200f 100644 --- a/vortex-array/src/array/sparse/compute/slice.rs +++ b/vortex-array/src/array/sparse/compute/slice.rs @@ -23,9 +23,6 @@ impl SliceFn for SparseArray { #[cfg(test)] mod tests { - use vortex_dtype::Nullability; - use vortex_scalar::Scalar; - use super::*; use crate::IntoArrayVariant; @@ -34,14 +31,9 @@ mod tests { let values = vec![15_u32, 135, 13531, 42].into_array(); let indices = vec![10_u64, 11, 50, 100].into_array(); - let sparse = SparseArray::try_new( - indices.clone(), - values, - 101, - Scalar::primitive(0_u32, Nullability::NonNullable), - ) - .unwrap() - .into_array(); + let sparse = SparseArray::try_new(indices.clone(), values, 101, 0_u32.into()) + .unwrap() + .into_array(); let sliced = slice(&sparse, 15, 100).unwrap(); assert_eq!(sliced.len(), 100 - 15); @@ -59,14 +51,9 @@ mod tests { let values = vec![15_u32, 135, 13531, 42].into_array(); let indices = vec![10_u64, 11, 50, 100].into_array(); - let sparse = SparseArray::try_new( - indices.clone(), - values, - 101, - Scalar::primitive(0_u32, Nullability::NonNullable), - ) - .unwrap() - .into_array(); + let sparse = SparseArray::try_new(indices.clone(), values, 101, 0_u32.into()) + .unwrap() + .into_array(); let sliced = slice(&sparse, 15, 100).unwrap(); assert_eq!(sliced.len(), 100 - 15); diff --git a/vortex-array/src/array/sparse/compute/take.rs b/vortex-array/src/array/sparse/compute/take.rs index 483ea01deb..2b44ecbfa5 100644 --- a/vortex-array/src/array/sparse/compute/take.rs +++ b/vortex-array/src/array/sparse/compute/take.rs @@ -84,8 +84,7 @@ fn take_search_sorted( #[cfg(test)] mod test { use itertools::Itertools; - use vortex_dtype::{DType, Nullability, PType}; - use vortex_scalar::Scalar; + use vortex_scalar::ScalarValue; use crate::array::primitive::PrimitiveArray; use crate::array::sparse::compute::take::take_map; @@ -100,7 +99,7 @@ mod test { PrimitiveArray::from_vec(vec![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid) .into_array(), 100, - Scalar::null(DType::Primitive(PType::F64, Nullability::Nullable)), + ScalarValue::Null, ) .unwrap() .into_array() diff --git a/vortex-array/src/array/sparse/flatten.rs b/vortex-array/src/array/sparse/flatten.rs index 6110be7999..17049c1a92 100644 --- a/vortex-array/src/array/sparse/flatten.rs +++ b/vortex-array/src/array/sparse/flatten.rs @@ -2,7 +2,7 @@ use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, MutableBuffer}; use itertools::Itertools; use vortex_dtype::{match_each_native_ptype, DType, NativePType}; use vortex_error::{VortexError, VortexResult}; -use vortex_scalar::Scalar; +use vortex_scalar::ScalarValue; use crate::array::primitive::PrimitiveArray; use crate::array::sparse::SparseArray; @@ -46,7 +46,7 @@ fn canonicalize_sparse_bools( values: BooleanBuffer, indices: &[usize], len: usize, - fill_value: &Scalar, + fill_value: &ScalarValue, mut validity_buffer: BooleanBufferBuilder, ) -> VortexResult { let fill_bool: bool = if fill_value.is_null() { @@ -67,12 +67,12 @@ fn canonicalize_sparse_bools( } fn canonicalize_sparse_primitives< - T: NativePType + for<'a> TryFrom<&'a Scalar, Error = VortexError>, + T: NativePType + for<'a> TryFrom<&'a ScalarValue, Error = VortexError>, >( values: &[T], indices: &[usize], len: usize, - fill_value: &Scalar, + fill_value: &ScalarValue, mut validity: BooleanBufferBuilder, ) -> VortexResult { let primitive_fill = if fill_value.is_null() { diff --git a/vortex-array/src/array/sparse/mod.rs b/vortex-array/src/array/sparse/mod.rs index 95bde7fa0b..6853a641e6 100644 --- a/vortex-array/src/array/sparse/mod.rs +++ b/vortex-array/src/array/sparse/mod.rs @@ -1,7 +1,7 @@ use ::serde::{Deserialize, Serialize}; use vortex_dtype::{match_each_integer_ptype, DType}; use vortex_error::{vortex_bail, vortex_panic, VortexExpect as _, VortexResult}; -use vortex_scalar::Scalar; +use vortex_scalar::{Scalar, ScalarValue}; use crate::array::constant::ConstantArray; use crate::compute::unary::scalar_at; @@ -23,7 +23,7 @@ pub struct SparseMetadata { // Offset value for patch indices as a result of slicing indices_offset: usize, indices_len: usize, - fill_value: Scalar, + fill_value: ScalarValue, } impl SparseArray { @@ -31,7 +31,7 @@ impl SparseArray { indices: Array, values: Array, len: usize, - fill_value: Scalar, + fill_value: ScalarValue, ) -> VortexResult { Self::try_new_with_offset(indices, values, len, 0, fill_value) } @@ -41,18 +41,11 @@ impl SparseArray { values: Array, len: usize, indices_offset: usize, - fill_value: Scalar, + fill_value: ScalarValue, ) -> VortexResult { if !matches!(indices.dtype(), &DType::IDX) { vortex_bail!("Cannot use {} as indices", indices.dtype()); } - if values.dtype() != fill_value.dtype() { - vortex_bail!( - "Mismatched fill value dtype {} and values dtype {}", - fill_value.dtype(), - values.dtype(), - ); - } if indices.len() != values.len() { vortex_bail!( "Mismatched indices {} and values {} length", @@ -102,10 +95,15 @@ impl SparseArray { } #[inline] - pub fn fill_value(&self) -> &Scalar { + pub fn fill_value(&self) -> &ScalarValue { &self.metadata().fill_value } + #[inline] + pub fn fill_scalar(&self) -> Scalar { + Scalar::new(self.dtype().clone(), self.fill_value().clone()) + } + /// Returns the position or the insertion point of a given index in the indices array. fn search_index(&self, index: usize) -> VortexResult { search_sorted( @@ -196,7 +194,7 @@ impl ArrayValidity for SparseArray { #[cfg(test)] mod test { use itertools::Itertools; - use vortex_dtype::Nullability::{self, Nullable}; + use vortex_dtype::Nullability::Nullable; use vortex_dtype::{DType, PType}; use vortex_error::VortexError; use vortex_scalar::Scalar; @@ -221,9 +219,14 @@ mod test { let mut values = vec![100i32, 200, 300].into_array(); values = try_cast(&values, fill_value.dtype()).unwrap(); - SparseArray::try_new(vec![2u64, 5, 8].into_array(), values, 10, fill_value) - .unwrap() - .into_array() + SparseArray::try_new( + vec![2u64, 5, 8].into_array(), + values, + 10, + fill_value.value().clone(), + ) + .unwrap() + .into_array() } fn assert_sparse_array(sparse: &Array, values: &[Option]) { @@ -372,13 +375,7 @@ mod test { let values = vec![15_u32, 135, 13531, 42].into_array(); let indices = vec![10_u64, 11, 50, 100].into_array(); - SparseArray::try_new( - indices.clone(), - values, - 100, - Scalar::primitive(0_u32, Nullability::NonNullable), - ) - .unwrap(); + SparseArray::try_new(indices.clone(), values, 100, 0_u32.into()).unwrap(); } #[test] @@ -386,12 +383,6 @@ mod test { let values = vec![15_u32, 135, 13531, 42].into_array(); let indices = vec![10_u64, 11, 50, 100].into_array(); - SparseArray::try_new( - indices.clone(), - values, - 101, - Scalar::primitive(0_u32, Nullability::NonNullable), - ) - .unwrap(); + SparseArray::try_new(indices.clone(), values, 101, 0_u32.into()).unwrap(); } } diff --git a/vortex-array/src/array/sparse/variants.rs b/vortex-array/src/array/sparse/variants.rs index b0109e1519..68db18eb52 100644 --- a/vortex-array/src/array/sparse/variants.rs +++ b/vortex-array/src/array/sparse/variants.rs @@ -68,7 +68,7 @@ impl StructArrayTrait for SparseArray { let values = self .values() .with_dyn(|s| s.as_struct_array().and_then(|s| s.field(idx)))?; - let scalar = StructScalar::try_from(self.fill_value()) + let scalar = StructScalar::try_new(self.dtype(), self.fill_value()) .ok()? .field_by_idx(idx)?; @@ -78,7 +78,7 @@ impl StructArrayTrait for SparseArray { values, self.len(), self.indices_offset(), - scalar, + scalar.value().clone(), ) .ok()? .into_array(), @@ -91,14 +91,14 @@ impl StructArrayTrait for SparseArray { .ok_or_else(|| vortex_err!("Chunk was not a StructArray"))? .project(projection) })?; - let scalar = StructScalar::try_from(self.fill_value())?.project(projection)?; + let scalar = StructScalar::try_new(self.dtype(), self.fill_value())?.project(projection)?; SparseArray::try_new_with_offset( self.indices().clone(), values, self.len(), self.indices_offset(), - scalar, + scalar.value().clone(), ) .map(|a| a.into_array()) } diff --git a/vortex-array/src/canonical.rs b/vortex-array/src/canonical.rs index bc213924ee..96948101b6 100644 --- a/vortex-array/src/canonical.rs +++ b/vortex-array/src/canonical.rs @@ -455,8 +455,6 @@ mod test { }; use arrow_buffer::NullBufferBuilder; use arrow_schema::{DataType, Field}; - use vortex_dtype::Nullability; - use vortex_scalar::Scalar; use crate::array::{PrimitiveArray, SparseArray, StructArray}; use crate::arrow::FromArrowArray; @@ -483,7 +481,7 @@ mod test { PrimitiveArray::from_vec(vec![0u64; 1], Validity::NonNullable).into_array(), PrimitiveArray::from_vec(vec![100i64], Validity::NonNullable).into_array(), 1, - Scalar::primitive(0i64, Nullability::NonNullable), + 0i64.into(), ) .unwrap() .into_array(), diff --git a/vortex-scalar/src/bool.rs b/vortex-scalar/src/bool.rs index 17c712d7d0..263a8ec872 100644 --- a/vortex-scalar/src/bool.rs +++ b/vortex-scalar/src/bool.rs @@ -68,11 +68,27 @@ impl From for Scalar { fn from(value: bool) -> Self { Self { dtype: DType::Bool(NonNullable), - value: ScalarValue::Bool(value), + value: value.into(), } } } +impl TryFrom<&ScalarValue> for bool { + type Error = VortexError; + + fn try_from(value: &ScalarValue) -> VortexResult { + value + .as_bool()? + .ok_or_else(|| vortex_err!("Can't extract present value from null scalar")) + } +} + +impl From for ScalarValue { + fn from(value: bool) -> Self { + ScalarValue::Bool(value) + } +} + #[cfg(test)] mod test { use super::*; diff --git a/vortex-scalar/src/primitive.rs b/vortex-scalar/src/primitive.rs index c84a553748..65cb77101e 100644 --- a/vortex-scalar/src/primitive.rs +++ b/vortex-scalar/src/primitive.rs @@ -133,7 +133,7 @@ macro_rules! primitive_scalar { fn from(value: $T) -> Self { Scalar { dtype: DType::Primitive(<$T>::PTYPE, Nullability::NonNullable), - value: ScalarValue::Primitive(value.into()), + value: value.into(), } } } @@ -142,9 +142,7 @@ macro_rules! primitive_scalar { fn from(value: Option<$T>) -> Self { Scalar { dtype: DType::Primitive(<$T>::PTYPE, Nullability::Nullable), - value: value - .map(|v| ScalarValue::Primitive(v.into())) - .unwrap_or_else(|| ScalarValue::Null), + value: value.into(), } } } @@ -167,6 +165,20 @@ macro_rules! primitive_scalar { } } + impl From<$T> for ScalarValue { + fn from(value: $T) -> Self { + ScalarValue::Primitive(value.into()) + } + } + + impl From> for ScalarValue { + fn from(value: Option<$T>) -> Self { + value + .map(|v| ScalarValue::Primitive(v.into())) + .unwrap_or_else(|| ScalarValue::Null) + } + } + impl TryFrom<&ScalarValue> for $T { type Error = VortexError; @@ -211,11 +223,21 @@ impl TryFrom<&Scalar> for usize { type Error = VortexError; fn try_from(value: &Scalar) -> Result { - u64::try_from( - value - .cast(&DType::Primitive(PType::U64, Nullability::NonNullable))? - .as_ref(), - ) - .map(|v| v as Self) + value.value().try_into() + } +} + +impl From for ScalarValue { + fn from(value: usize) -> Self { + ScalarValue::Primitive(PValue::U64(value as u64)) + } +} + +/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics. +impl TryFrom<&ScalarValue> for usize { + type Error = VortexError; + + fn try_from(value: &ScalarValue) -> Result { + u64::try_from(value).map(|v| v as Self) } }