Skip to content

Commit

Permalink
feat: teach ScalarValue and PValue is_instance_of (#958)
Browse files Browse the repository at this point in the history
  • Loading branch information
danking authored Oct 2, 2024
1 parent 4300661 commit 1aa072c
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 0 deletions.
7 changes: 7 additions & 0 deletions vortex-array/src/array/sparse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ impl SparseArray {
if !matches!(indices.dtype(), &DType::IDX) {
vortex_bail!("Cannot use {} as indices", indices.dtype());
}
if !fill_value.is_instance_of(values.dtype()) {
vortex_bail!(
"fill value, {:?}, should be instance of values dtype, {}",
fill_value,
values.dtype(),
);
}
if indices.len() != values.len() {
vortex_bail!(
"Mismatched indices {} and values {} length",
Expand Down
30 changes: 30 additions & 0 deletions vortex-scalar/src/pvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ impl PValue {
}
}

pub fn is_instance_of(&self, ptype: &PType) -> bool {
&self.ptype() == ptype
}

#[allow(clippy::transmute_int_to_float, clippy::transmute_float_to_int)]
pub fn reinterpret_cast(&self, ptype: PType) -> Self {
if ptype == self.ptype() {
Expand Down Expand Up @@ -262,3 +266,29 @@ impl_pvalue!(i64, I64);
impl_pvalue!(f16, F16);
impl_pvalue!(f32, F32);
impl_pvalue!(f64, F64);

#[cfg(test)]
mod test {
use vortex_dtype::half::f16;
use vortex_dtype::PType;

use crate::PValue;

#[test]
pub fn test_is_instance_of() {
assert!(PValue::U8(10).is_instance_of(&PType::U8));
assert!(!PValue::U8(10).is_instance_of(&PType::U16));
assert!(!PValue::U8(10).is_instance_of(&PType::I8));
assert!(!PValue::U8(10).is_instance_of(&PType::F16));

assert!(PValue::I8(10).is_instance_of(&PType::I8));
assert!(!PValue::I8(10).is_instance_of(&PType::I16));
assert!(!PValue::I8(10).is_instance_of(&PType::U8));
assert!(!PValue::I8(10).is_instance_of(&PType::F16));

assert!(PValue::F16(f16::from_f32(10.0)).is_instance_of(&PType::F16));
assert!(!PValue::F16(f16::from_f32(10.0)).is_instance_of(&PType::F32));
assert!(!PValue::F16(f16::from_f32(10.0)).is_instance_of(&PType::U16));
assert!(!PValue::F16(f16::from_f32(10.0)).is_instance_of(&PType::I16));
}
}
109 changes: 109 additions & 0 deletions vortex-scalar/src/value.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::sync::Arc;

use vortex_buffer::{Buffer, BufferString};
use vortex_dtype::DType;
use vortex_error::{vortex_err, VortexResult};

use crate::pvalue::PValue;
Expand Down Expand Up @@ -28,6 +29,26 @@ impl ScalarValue {
matches!(self, Self::Null)
}

pub fn is_instance_of(&self, dtype: &DType) -> bool {
match (self, dtype) {
(ScalarValue::Bool(_), DType::Bool(_)) => true,
(ScalarValue::Primitive(pvalue), DType::Primitive(ptype, _)) => {
pvalue.is_instance_of(ptype)
}
(ScalarValue::Buffer(_), DType::Binary(_)) => true,
(ScalarValue::BufferString(_), DType::Utf8(_)) => true,
(ScalarValue::List(values), DType::List(dtype, _)) => {
values.iter().all(|v| v.is_instance_of(dtype))
}
(ScalarValue::List(values), DType::Struct(structdt, _)) => values
.iter()
.zip(structdt.dtypes().to_vec())
.all(|(v, dt)| v.is_instance_of(&dt)),
(ScalarValue::Null, dtype) => dtype.is_nullable(),
(..) => false,
}
}

pub fn as_bool(&self) -> VortexResult<Option<bool>> {
match self {
Self::Null => Ok(None),
Expand Down Expand Up @@ -69,3 +90,91 @@ impl ScalarValue {
}
}
}

#[cfg(test)]
mod test {
use vortex_dtype::{DType, Nullability, PType, StructDType};

use crate::{PValue, ScalarValue};

#[test]
pub fn test_is_instance_of_bool() {
assert!(ScalarValue::Bool(true).is_instance_of(&DType::Bool(Nullability::Nullable)));
assert!(ScalarValue::Bool(true).is_instance_of(&DType::Bool(Nullability::NonNullable)));
assert!(ScalarValue::Bool(false).is_instance_of(&DType::Bool(Nullability::Nullable)));
assert!(ScalarValue::Bool(false).is_instance_of(&DType::Bool(Nullability::NonNullable)));
}

#[test]
pub fn test_is_instance_of_primitive() {
assert!(ScalarValue::Primitive(PValue::F64(0.0))
.is_instance_of(&DType::Primitive(PType::F64, Nullability::NonNullable)));
}

#[test]
pub fn test_is_instance_of_list_and_struct() {
let tbool = DType::Bool(Nullability::NonNullable);
let tboolnull = DType::Bool(Nullability::Nullable);
let tnull = DType::Null;

let bool_null = ScalarValue::List(vec![ScalarValue::Bool(true), ScalarValue::Null].into());
let bool_bool =
ScalarValue::List(vec![ScalarValue::Bool(true), ScalarValue::Bool(false)].into());

fn tlist(element: &DType) -> DType {
DType::List(element.clone().into(), Nullability::NonNullable)
}

assert!(bool_null.is_instance_of(&tlist(&tboolnull)));
assert!(!bool_null.is_instance_of(&tlist(&tbool)));
assert!(bool_bool.is_instance_of(&tlist(&tbool)));
assert!(bool_bool.is_instance_of(&tlist(&tbool)));

fn tstruct(left: &DType, right: &DType) -> DType {
DType::Struct(
StructDType::new(
vec!["left".into(), "right".into()].into(),
vec![left.clone(), right.clone()],
),
Nullability::NonNullable,
)
}

assert!(bool_null.is_instance_of(&tstruct(&tboolnull, &tboolnull)));
assert!(bool_null.is_instance_of(&tstruct(&tbool, &tboolnull)));
assert!(!bool_null.is_instance_of(&tstruct(&tboolnull, &tbool)));
assert!(!bool_null.is_instance_of(&tstruct(&tbool, &tbool)));

assert!(bool_null.is_instance_of(&tstruct(&tbool, &tnull)));
assert!(!bool_null.is_instance_of(&tstruct(&tnull, &tbool)));

assert!(bool_bool.is_instance_of(&tstruct(&tboolnull, &tboolnull)));
assert!(bool_bool.is_instance_of(&tstruct(&tbool, &tboolnull)));
assert!(bool_bool.is_instance_of(&tstruct(&tboolnull, &tbool)));
assert!(bool_bool.is_instance_of(&tstruct(&tbool, &tbool)));

assert!(!bool_bool.is_instance_of(&tstruct(&tbool, &tnull)));
assert!(!bool_bool.is_instance_of(&tstruct(&tnull, &tbool)));
}

#[test]
pub fn test_is_instance_of_null() {
assert!(ScalarValue::Null.is_instance_of(&DType::Bool(Nullability::Nullable)));
assert!(!ScalarValue::Null.is_instance_of(&DType::Bool(Nullability::NonNullable)));

assert!(
ScalarValue::Null.is_instance_of(&DType::Primitive(PType::U8, Nullability::Nullable))
);
assert!(ScalarValue::Null.is_instance_of(&DType::Utf8(Nullability::Nullable)));
assert!(ScalarValue::Null.is_instance_of(&DType::Binary(Nullability::Nullable)));
assert!(ScalarValue::Null.is_instance_of(&DType::Struct(
StructDType::new([].into(), [].into()),
Nullability::Nullable,
)));
assert!(ScalarValue::Null.is_instance_of(&DType::List(
DType::Utf8(Nullability::NonNullable).into(),
Nullability::Nullable
)));
assert!(ScalarValue::Null.is_instance_of(&DType::Null));
}
}

0 comments on commit 1aa072c

Please sign in to comment.