diff --git a/deps/fastlanez b/deps/fastlanez index d4ed218868..e1bc9e3ebf 160000 --- a/deps/fastlanez +++ b/deps/fastlanez @@ -1 +1 @@ -Subproject commit d4ed218868fdd8cf5a50f3e13fcbee34bc3af4e4 +Subproject commit e1bc9e3ebfeedaaad19a21db895ed3c458b6ef21 diff --git a/vortex/src/array/constant/compute.rs b/vortex/src/array/constant/compute/mod.rs similarity index 100% rename from vortex/src/array/constant/compute.rs rename to vortex/src/array/constant/compute/mod.rs diff --git a/vortex/src/array/primitive/compute/cast.rs b/vortex/src/array/primitive/compute/cast.rs new file mode 100644 index 0000000000..397f009421 --- /dev/null +++ b/vortex/src/array/primitive/compute/cast.rs @@ -0,0 +1,69 @@ +use crate::array::primitive::PrimitiveArray; +use crate::array::CloneOptionalArray; +use crate::compute::cast::CastPrimitiveFn; +use crate::error::{VortexError, VortexResult}; +use crate::match_each_native_ptype; +use crate::ptype::{NativePType, PType}; + +impl CastPrimitiveFn for PrimitiveArray { + fn cast_primitive(&self, ptype: &PType) -> VortexResult { + if self.ptype() == ptype { + Ok(self.clone()) + } else { + match_each_native_ptype!(ptype, |$T| { + Ok(PrimitiveArray::from_nullable( + cast::<$T>(self)?, + self.validity().clone_optional(), + )) + }) + } + } +} + +fn cast(array: &PrimitiveArray) -> VortexResult> { + match_each_native_ptype!(array.ptype(), |$E| { + array + .typed_data::<$E>() + .iter() + // TODO(ngates): allow configurable checked/unchecked casting + .map(|v| { + T::from(*v).ok_or_else(|| { + VortexError::ComputeError(format!("Failed to cast {} to {:?}", v, T::PTYPE).into()) + }) + }) + .collect() + }) +} + +#[cfg(test)] +mod test { + use crate::array::primitive::PrimitiveArray; + use crate::compute; + use crate::error::VortexError; + use crate::ptype::PType; + + #[test] + fn cast_u32_u8() { + let arr = PrimitiveArray::from_vec(vec![0u32, 10, 200]); + let u8arr = compute::cast::cast_primitive(&arr, &PType::U8).unwrap(); + assert_eq!(u8arr.typed_data::(), vec![0u8, 10, 200]); + } + + #[test] + fn cast_u32_f32() { + let arr = PrimitiveArray::from_vec(vec![0u32, 10, 200]); + let u8arr = compute::cast::cast_primitive(&arr, &PType::F32).unwrap(); + assert_eq!(u8arr.typed_data::(), vec![0.0f32, 10., 200.]); + } + + #[test] + fn cast_i32_u32() { + let arr = PrimitiveArray::from_vec(vec![-1i32]); + assert_eq!( + compute::cast::cast_primitive(&arr, &PType::U32) + .err() + .unwrap(), + VortexError::ComputeError("Failed to cast -1 to U32".into(),) + ) + } +} diff --git a/vortex/src/array/primitive/compute/mod.rs b/vortex/src/array/primitive/compute/mod.rs new file mode 100644 index 0000000000..5516e07ef3 --- /dev/null +++ b/vortex/src/array/primitive/compute/mod.rs @@ -0,0 +1,23 @@ +use crate::array::primitive::PrimitiveArray; +use crate::compute::cast::CastPrimitiveFn; +use crate::compute::patch::PatchFn; +use crate::compute::scalar_at::ScalarAtFn; +use crate::compute::ArrayCompute; + +mod cast; +mod patch; +mod scalar_at; + +impl ArrayCompute for PrimitiveArray { + fn cast_primitive(&self) -> Option<&dyn CastPrimitiveFn> { + Some(self) + } + + fn patch(&self) -> Option<&dyn PatchFn> { + Some(self) + } + + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + Some(self) + } +} diff --git a/vortex/src/array/primitive/compute/patch.rs b/vortex/src/array/primitive/compute/patch.rs new file mode 100644 index 0000000000..67c16fb28a --- /dev/null +++ b/vortex/src/array/primitive/compute/patch.rs @@ -0,0 +1,39 @@ +use itertools::Itertools; + +use crate::array::downcast::DowncastArrayBuiltin; +use crate::array::primitive::PrimitiveArray; +use crate::array::sparse::{SparseArray, SparseEncoding}; +use crate::array::{Array, ArrayRef, CloneOptionalArray}; +use crate::compute::patch::PatchFn; +use crate::error::{VortexError, VortexResult}; +use crate::{compute, match_each_native_ptype}; + +impl PatchFn for PrimitiveArray { + fn patch(&self, patch: &dyn Array) -> VortexResult { + match patch.encoding().id() { + &SparseEncoding::ID => patch_with_sparse(self, patch.as_sparse()), + // TODO(ngates): support a default implementation based on iter_arrow? + _ => Err(VortexError::MissingKernel( + "patch", + self.encoding().id(), + vec![patch.encoding().id()], + )), + } + } +} + +fn patch_with_sparse(array: &PrimitiveArray, patch: &SparseArray) -> VortexResult { + let patch_indices = patch.resolved_indices(); + match_each_native_ptype!(array.ptype(), |$T| { + let mut values = Vec::from(array.typed_data::<$T>()); + let patch_values = compute::cast::cast_primitive(patch.values(), array.ptype())?; + for (idx, value) in patch_indices.iter().zip_eq(patch_values.typed_data::<$T>().iter()) { + values[*idx] = *value; + } + Ok(PrimitiveArray::from_nullable( + values, + // TODO(ngates): if patch values has null, we need to patch into the validity buffer + array.validity().clone_optional(), + ).boxed()) + }) +} diff --git a/vortex/src/array/primitive/compute.rs b/vortex/src/array/primitive/compute/scalar_at.rs similarity index 82% rename from vortex/src/array/primitive/compute.rs rename to vortex/src/array/primitive/compute/scalar_at.rs index b11e959fab..87e86d7fde 100644 --- a/vortex/src/array/primitive/compute.rs +++ b/vortex/src/array/primitive/compute/scalar_at.rs @@ -1,17 +1,10 @@ use crate::array::primitive::PrimitiveArray; use crate::array::Array; use crate::compute::scalar_at::ScalarAtFn; -use crate::compute::ArrayCompute; use crate::error::VortexResult; use crate::match_each_native_ptype; use crate::scalar::{NullableScalar, Scalar}; -impl ArrayCompute for PrimitiveArray { - fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { - Some(self) - } -} - impl ScalarAtFn for PrimitiveArray { fn scalar_at(&self, index: usize) -> VortexResult> { if self.is_valid(index) { diff --git a/vortex/src/array/primitive/mod.rs b/vortex/src/array/primitive/mod.rs index 17f19470b2..90e3dd4577 100644 --- a/vortex/src/array/primitive/mod.rs +++ b/vortex/src/array/primitive/mod.rs @@ -9,7 +9,7 @@ use std::sync::{Arc, RwLock}; use allocator_api2::alloc::Allocator; use arrow::alloc::ALIGNMENT as ARROW_ALIGNMENT; use arrow::array::{make_array, ArrayData, AsArray}; -use arrow::buffer::{Buffer, NullBuffer}; +use arrow::buffer::{Buffer, NullBuffer, ScalarBuffer}; use linkme::distributed_slice; use log::debug; @@ -129,6 +129,17 @@ impl PrimitiveArray { pub fn validity(&self) -> Option<&dyn Array> { self.validity.as_deref() } + + pub fn scalar_buffer(&self) -> ScalarBuffer { + ScalarBuffer::from(self.buffer().clone()) + } + + pub fn typed_data(&self) -> &[T] { + if self.ptype() != &T::PTYPE { + panic!("Invalid PType") + } + self.buffer().typed_data() + } } impl Array for PrimitiveArray { diff --git a/vortex/src/array/sparse/mod.rs b/vortex/src/array/sparse/mod.rs index 4f6c8c0e64..5cfdc0b40b 100644 --- a/vortex/src/array/sparse/mod.rs +++ b/vortex/src/array/sparse/mod.rs @@ -3,8 +3,9 @@ use std::iter; use std::sync::{Arc, RwLock}; use arrow::array::AsArray; -use arrow::array::BooleanBufferBuilder; -use arrow::array::{ArrayRef as ArrowArrayRef, PrimitiveArray as ArrowPrimitiveArray}; +use arrow::array::{ + ArrayRef as ArrowArrayRef, BooleanBufferBuilder, PrimitiveArray as ArrowPrimitiveArray, +}; use arrow::buffer::{NullBuffer, ScalarBuffer}; use arrow::datatypes::UInt64Type; use linkme::distributed_slice; @@ -79,6 +80,22 @@ impl SparseArray { pub fn indices(&self) -> &dyn Array { self.indices.as_ref() } + + /// Return indices as a vector of usize with the indices_offset applied. + pub fn resolved_indices(&self) -> Vec { + let mut indices = Vec::with_capacity(self.len()); + self.indices().iter_arrow().for_each(|c| { + indices.extend( + arrow::compute::cast(c.as_ref(), &arrow::datatypes::DataType::UInt64) + .unwrap() + .as_primitive::() + .values() + .into_iter() + .map(|v| (*v as usize) - self.indices_offset), + ) + }); + indices + } } impl Array for SparseArray { @@ -119,16 +136,7 @@ impl Array for SparseArray { fn iter_arrow(&self) -> Box { // Resolve our indices into a vector of usize applying the offset - let mut indices = Vec::with_capacity(self.len()); - self.indices().iter_arrow().for_each(|c| { - indices.extend( - c.as_primitive::() - .values() - .into_iter() - .map(|v| (*v as usize) - self.indices_offset), - ) - }); - + let indices = self.resolved_indices(); let array: ArrowArrayRef = match_arrow_numeric_type!(self.values().dtype(), |$E| { let mut validity = BooleanBufferBuilder::new(self.len()); validity.append_n(self.len(), false); @@ -147,7 +155,6 @@ impl Array for SparseArray { Some(NullBuffer::from(validity.finish())), )) }); - Box::new(iter::once(array)) } diff --git a/vortex/src/compute/as_contiguous.rs b/vortex/src/compute/as_contiguous.rs index a98375397e..37903e7594 100644 --- a/vortex/src/compute/as_contiguous.rs +++ b/vortex/src/compute/as_contiguous.rs @@ -1,5 +1,6 @@ use arrow::buffer::BooleanBuffer; use itertools::Itertools; +use vortex_alloc::{AlignedVec, ALIGNED_ALLOCATOR}; use crate::array::bool::{BoolArray, BoolEncoding}; use crate::array::downcast::DowncastArrayBuiltin; @@ -7,7 +8,6 @@ use crate::array::primitive::{PrimitiveArray, PrimitiveEncoding}; use crate::array::{Array, ArrayRef, CloneOptionalArray}; use crate::error::{VortexError, VortexResult}; use crate::ptype::{match_each_native_ptype, NativePType}; -use vortex_alloc::{AlignedVec, ALIGNED_ALLOCATOR}; pub fn as_contiguous(arrays: Vec) -> VortexResult { if arrays.is_empty() { diff --git a/vortex/src/compute/cast.rs b/vortex/src/compute/cast.rs index 4cd56bee8c..6fd6c0ddae 100644 --- a/vortex/src/compute/cast.rs +++ b/vortex/src/compute/cast.rs @@ -1,6 +1,21 @@ -use crate::dtype::DType; -use crate::scalar::Scalar; +use crate::array::primitive::PrimitiveArray; +use crate::array::Array; +use crate::error::{VortexError, VortexResult}; +use crate::ptype::PType; -pub fn cast_scalar(_value: &dyn Scalar, _dtype: &DType) -> Box { - todo!() +pub trait CastPrimitiveFn { + fn cast_primitive(&self, ptype: &PType) -> VortexResult; +} + +pub fn cast_primitive(array: &dyn Array, ptype: &PType) -> VortexResult { + PType::try_from(array.dtype()).map_err(|_| VortexError::InvalidDType(array.dtype().clone()))?; + array + .cast_primitive() + .map(|t| t.cast_primitive(ptype)) + .unwrap_or_else(|| { + Err(VortexError::NotImplemented( + "cast_primitive", + array.encoding().id(), + )) + }) } diff --git a/vortex/src/compute/mod.rs b/vortex/src/compute/mod.rs index f9273d7829..429f2f69e6 100644 --- a/vortex/src/compute/mod.rs +++ b/vortex/src/compute/mod.rs @@ -1,19 +1,29 @@ -use crate::compute::scalar_at::ScalarAtFn; +use cast::CastPrimitiveFn; +use patch::PatchFn; +use scalar_at::ScalarAtFn; use take::TakeFn; pub mod add; pub mod as_contiguous; pub mod cast; +pub mod patch; pub mod repeat; pub mod scalar_at; pub mod search_sorted; pub mod take; pub trait ArrayCompute { - fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + fn cast_primitive(&self) -> Option<&dyn CastPrimitiveFn> { + None + } + + fn patch(&self) -> Option<&dyn PatchFn> { None } + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { + None + } fn take(&self) -> Option<&dyn TakeFn> { None } diff --git a/vortex/src/compute/patch.rs b/vortex/src/compute/patch.rs new file mode 100644 index 0000000000..f58a650cb0 --- /dev/null +++ b/vortex/src/compute/patch.rs @@ -0,0 +1,22 @@ +use crate::array::{Array, ArrayRef}; +use crate::error::{VortexError, VortexResult}; + +pub trait PatchFn { + fn patch(&self, patch: &dyn Array) -> VortexResult; +} + +/// Returns a new array where the non-null values from the patch array are replaced in the original. +pub fn patch(array: &dyn Array, patch: &dyn Array) -> VortexResult { + if array.len() != patch.len() { + return Err(VortexError::InvalidArgument( + "patch array must have the same length as the original array".into(), + )); + } + + // TODO(ngates): check the dtype matches + + array + .patch() + .map(|t| t.patch(patch)) + .unwrap_or_else(|| Err(VortexError::NotImplemented("take", array.encoding().id()))) +} diff --git a/vortex/src/compute/scalar_at.rs b/vortex/src/compute/scalar_at.rs index cf37161247..1b6b90bc23 100644 --- a/vortex/src/compute/scalar_at.rs +++ b/vortex/src/compute/scalar_at.rs @@ -15,9 +15,9 @@ pub fn scalar_at(array: &dyn Array, index: usize) -> VortexResult VortexResult { - array.take().map(|t| t.take(indices)).unwrap_or_else(|| { - // TODO(ngates): default implementation of decode and then try again - Err(VortexError::ComputeError( - format!("take not implemented for {}", &array.encoding().id()).into(), - )) - }) + array + .take() + .map(|t| t.take(indices)) + .unwrap_or_else(|| Err(VortexError::NotImplemented("take", array.encoding().id()))) } diff --git a/vortex/src/error.rs b/vortex/src/error.rs index 9fa652a6f9..c82575f124 100644 --- a/vortex/src/error.rs +++ b/vortex/src/error.rs @@ -51,6 +51,14 @@ pub enum VortexError { LengthMismatch, #[error("{0}")] ComputeError(ErrString), + #[error("{0}")] + InvalidArgument(ErrString), + // Used when a function is not implemented for a given array type. + #[error("function {0} not implemented for {1}")] + NotImplemented(&'static str, &'static EncodingId), + // Used when a function is implemented for an array type, but the RHS is not supported. + #[error("missing kernel {0} for {1} and {2:?}")] + MissingKernel(&'static str, &'static EncodingId, Vec<&'static EncodingId>), #[error("invalid data type: {0}")] InvalidDType(DType), #[error("invalid physical type: {0:?}")]