From 048db576c15e20fa4d5dc699f1daedaa79172bbf Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Mon, 4 Mar 2024 23:33:17 +0000 Subject: [PATCH] Cast Bool (#61) --- vortex/src/array/bool/compute.rs | 11 +++++++++++ vortex/src/compute/cast.rs | 14 ++++++++++++++ vortex/src/compute/mod.rs | 7 ++++++- 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/vortex/src/array/bool/compute.rs b/vortex/src/array/bool/compute.rs index 806f4f410a..a9a80a653a 100644 --- a/vortex/src/array/bool/compute.rs +++ b/vortex/src/array/bool/compute.rs @@ -1,16 +1,27 @@ use crate::array::bool::BoolArray; use crate::array::Array; +use crate::compute::cast::CastBoolFn; use crate::compute::scalar_at::ScalarAtFn; use crate::compute::ArrayCompute; use crate::error::VortexResult; use crate::scalar::{NullableScalar, Scalar}; impl ArrayCompute for BoolArray { + fn cast_bool(&self) -> Option<&dyn CastBoolFn> { + Some(self) + } + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { Some(self) } } +impl CastBoolFn for BoolArray { + fn cast_bool(&self) -> VortexResult { + Ok(self.clone()) + } +} + impl ScalarAtFn for BoolArray { fn scalar_at(&self, index: usize) -> VortexResult> { if self.is_valid(index) { diff --git a/vortex/src/compute/cast.rs b/vortex/src/compute/cast.rs index 6fd6c0ddae..d5657f2ae1 100644 --- a/vortex/src/compute/cast.rs +++ b/vortex/src/compute/cast.rs @@ -1,3 +1,4 @@ +use crate::array::bool::BoolArray; use crate::array::primitive::PrimitiveArray; use crate::array::Array; use crate::error::{VortexError, VortexResult}; @@ -19,3 +20,16 @@ pub fn cast_primitive(array: &dyn Array, ptype: &PType) -> VortexResult VortexResult; +} + +pub fn cast_bool(array: &dyn Array) -> VortexResult { + array.cast_bool().map(|t| t.cast_bool()).unwrap_or_else(|| { + Err(VortexError::NotImplemented( + "cast_bool", + array.encoding().id(), + )) + }) +} diff --git a/vortex/src/compute/mod.rs b/vortex/src/compute/mod.rs index 429f2f69e6..4391e4ccf7 100644 --- a/vortex/src/compute/mod.rs +++ b/vortex/src/compute/mod.rs @@ -1,4 +1,4 @@ -use cast::CastPrimitiveFn; +use cast::{CastBoolFn, CastPrimitiveFn}; use patch::PatchFn; use scalar_at::ScalarAtFn; use take::TakeFn; @@ -13,6 +13,10 @@ pub mod search_sorted; pub mod take; pub trait ArrayCompute { + fn cast_bool(&self) -> Option<&dyn CastBoolFn> { + None + } + fn cast_primitive(&self) -> Option<&dyn CastPrimitiveFn> { None } @@ -24,6 +28,7 @@ pub trait ArrayCompute { fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { None } + fn take(&self) -> Option<&dyn TakeFn> { None }