Skip to content

Commit

Permalink
remove top-level nullability on Extension
Browse files Browse the repository at this point in the history
  • Loading branch information
a10y committed Oct 9, 2024
1 parent c6e7add commit 44b49f6
Show file tree
Hide file tree
Showing 21 changed files with 105 additions and 151 deletions.
20 changes: 7 additions & 13 deletions encodings/datetime-parts/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ use vortex::compute::{slice, take, ArrayCompute, SliceFn, TakeFn};
use vortex::validity::ArrayValidity;
use vortex::{Array, ArrayDType, IntoArray, IntoArrayVariant};
use vortex_datetime_dtype::{TemporalMetadata, TimeUnit};
use vortex_dtype::{DType, PType};
use vortex_dtype::DType;
use vortex_error::{vortex_bail, VortexResult, VortexUnwrap as _};
use vortex_scalar::Scalar;
use vortex_scalar::{Scalar, ScalarValue};

use crate::DateTimePartsArray;

Expand Down Expand Up @@ -51,7 +51,7 @@ impl SliceFn for DateTimePartsArray {

impl ScalarAtFn for DateTimePartsArray {
fn scalar_at(&self, index: usize) -> VortexResult<Scalar> {
let DType::Extension(ext, nullability) = self.dtype().clone() else {
let DType::Extension(ext) = self.dtype().clone() else {
vortex_bail!(
"DateTimePartsArray must have extension dtype, found {}",
self.dtype()
Expand All @@ -63,10 +63,7 @@ impl ScalarAtFn for DateTimePartsArray {
};

if !self.is_valid(index) {
return Ok(Scalar::extension(
ext,
Scalar::null(DType::Primitive(PType::I64, nullability)),
));
return Ok(Scalar::extension(ext, ScalarValue::Null));
}

let divisor = match time_unit {
Expand All @@ -83,10 +80,7 @@ impl ScalarAtFn for DateTimePartsArray {

let scalar = days * 86_400 * divisor + seconds * divisor + subseconds;

Ok(Scalar::extension(
ext,
Scalar::primitive(scalar, nullability),
))
Ok(Scalar::extension(ext, scalar.into()))
}

fn scalar_at_unchecked(&self, index: usize) -> Scalar {
Expand All @@ -98,7 +92,7 @@ impl ScalarAtFn for DateTimePartsArray {
///
/// Enforces that the passed array is actually a [DateTimePartsArray] with proper metadata.
pub fn decode_to_temporal(array: &DateTimePartsArray) -> VortexResult<TemporalArray> {
let DType::Extension(ext, _) = array.dtype().clone() else {
let DType::Extension(ext) = array.dtype().clone() else {
vortex_bail!(ComputeError: "expected dtype to be DType::Extension variant")
};

Expand Down Expand Up @@ -187,7 +181,7 @@ mod test {
assert_eq!(validity, raw_millis.validity());

let date_times = DateTimePartsArray::try_new(
DType::Extension(temporal_array.ext_dtype().clone(), validity.nullability()),
DType::Extension(temporal_array.ext_dtype().clone()),
days,
seconds,
subseconds,
Expand Down
11 changes: 8 additions & 3 deletions pyvortex/src/python_repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,18 @@ impl Display for DTypePythonRepr<'_> {
n.python_repr()
),
DType::List(c, n) => write!(f, "list({}, {})", c.python_repr(), n.python_repr()),
DType::Extension(ext, n) => {
write!(f, "ext(\"{}\", ", ext.id().python_repr())?;
DType::Extension(ext) => {
write!(
f,
"ext(\"{}\", {}, ",
ext.id().python_repr(),
ext.scalars_dtype().python_repr()
)?;
match ext.metadata() {
None => write!(f, "None")?,
Some(metadata) => write!(f, "{}", metadata.python_repr())?,
};
write!(f, ", {})", n.python_repr())
write!(f, ", {})", ext.scalars_dtype().nullability().python_repr())
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/array/chunked/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub(crate) fn try_canonicalize_chunks(
// / \
// storage storage
//
DType::Extension(ext_dtype, _) => {
DType::Extension(ext_dtype) => {
// Recursively apply canonicalization and packing to the storage array backing
// each chunk of the extension array.
let storage_chunks: Vec<Array> = chunks
Expand Down
21 changes: 4 additions & 17 deletions vortex-array/src/array/constant/variants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::sync::Arc;

use vortex_dtype::field::Field;
use vortex_dtype::{DType, PType};
use vortex_error::{vortex_panic, VortexError, VortexExpect as _, VortexResult};
use vortex_scalar::{ExtScalar, Scalar, ScalarValue, StructScalar};
use vortex_error::{VortexError, VortexExpect as _, VortexResult};
use vortex_scalar::{Scalar, ScalarValue, StructScalar};

use crate::array::constant::ConstantArray;
use crate::iter::{Accessor, AccessorRef};
Expand Down Expand Up @@ -203,22 +203,9 @@ impl ListArrayTrait for ConstantArray {}

impl ExtensionArrayTrait for ConstantArray {
fn storage_array(&self) -> Array {
let scalar_ext = ExtScalar::try_new(self.dtype(), self.scalar_value())
.vortex_expect("Expected an extension scalar");

// FIXME(ngates): there's not enough information to get the storage array.
let n = self.dtype().nullability();
let storage_dtype = match scalar_ext.value() {
ScalarValue::Bool(_) => DType::Binary(n),
ScalarValue::Primitive(pvalue) => DType::Primitive(pvalue.ptype(), n),
ScalarValue::Buffer(_) => DType::Binary(n),
ScalarValue::BufferString(_) => DType::Utf8(n),
ScalarValue::List(_) => vortex_panic!("List not supported"),
ScalarValue::Null => DType::Null,
};

let storage_dtype = self.ext_dtype().scalars_dtype().clone();
ConstantArray::new(
Scalar::new(storage_dtype, scalar_ext.value().clone()),
Scalar::new(storage_dtype, self.scalar_value().clone()),
self.len(),
)
.into_array()
Expand Down
4 changes: 2 additions & 2 deletions vortex-array/src/array/extension/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ impl ScalarAtFn for ExtensionArray {
fn scalar_at(&self, index: usize) -> VortexResult<Scalar> {
Ok(Scalar::extension(
self.ext_dtype().clone(),
scalar_at(self.storage(), index)?,
scalar_at(self.storage(), index)?.into_value(),
))
}

fn scalar_at_unchecked(&self, index: usize) -> Scalar {
Scalar::extension(
self.ext_dtype().clone(),
scalar_at_unchecked(self.storage(), index),
scalar_at_unchecked(self.storage(), index).into_value(),
)
}
}
Expand Down
18 changes: 10 additions & 8 deletions vortex-array/src/array/extension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ mod compute;
impl_encoding!("vortex.ext", ids::EXTENSION, Extension);

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionMetadata {
storage_dtype: DType,
}
pub struct ExtensionMetadata;

impl Display for ExtensionMetadata {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Expand All @@ -28,12 +26,16 @@ impl Display for ExtensionMetadata {

impl ExtensionArray {
pub fn new(ext_dtype: ExtDType, storage: Array) -> Self {
assert_eq!(
ext_dtype.scalars_dtype(),
storage.dtype(),
"ExtensionArray: scalars_dtype must match storage array DType",
);

Self::try_from_parts(
DType::Extension(ext_dtype, storage.dtype().nullability()),
DType::Extension(ext_dtype),
storage.len(),
ExtensionMetadata {
storage_dtype: storage.dtype().clone(),
},
ExtensionMetadata,
[storage].into(),
Default::default(),
)
Expand All @@ -42,7 +44,7 @@ impl ExtensionArray {

pub fn storage(&self) -> Array {
self.as_ref()
.child(0, &self.metadata().storage_dtype, self.len())
.child(0, self.ext_dtype().scalars_dtype(), self.len())
.vortex_expect("Missing storage array for ExtensionArray")
}

Expand Down
3 changes: 1 addition & 2 deletions vortex-array/src/arrow/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ impl FromArrowType<&Field> for DType {
| DataType::Time32(_)
| DataType::Time64(_)
| DataType::Timestamp(..) => Extension(
make_temporal_ext_dtype(field.data_type()),
field.is_nullable().into(),
make_temporal_ext_dtype(field.data_type()).with_scalars_nullability(nullability),
),
DataType::List(e) | DataType::LargeList(e) => {
List(Arc::new(Self::from_arrow(e.as_ref())), nullability)
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/variants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ pub trait ListArrayTrait: ArrayTrait {}

pub trait ExtensionArrayTrait: ArrayTrait {
fn ext_dtype(&self) -> &ExtDType {
let DType::Extension(ext_dtype, _nullability) = self.dtype() else {
let DType::Extension(ext_dtype) = self.dtype() else {
vortex_panic!("Expected ExtDType")
};
ext_dtype
Expand Down
15 changes: 6 additions & 9 deletions vortex-datafusion/src/datatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ pub(crate) fn infer_data_type(dtype: &DType) -> DataType {
dtype.is_nullable(),
)))
}
DType::Extension(ext_dtype, _) => {
DType::Extension(ext_dtype) => {
// Try and match against the known extension DTypes.
if is_temporal_ext_type(ext_dtype.id()) {
make_arrow_temporal_dtype(ext_dtype)
Expand Down Expand Up @@ -168,14 +168,11 @@ mod test {
#[test]
#[should_panic]
fn test_dtype_conversion_panics() {
let _ = infer_data_type(&DType::Extension(
ExtDType::new(
ExtID::from("my-fake-ext-dtype"),
Arc::new(PType::I32.into()),
None,
),
Nullability::NonNullable,
));
let _ = infer_data_type(&DType::Extension(ExtDType::new(
ExtID::from("my-fake-ext-dtype"),
Arc::new(PType::I32.into()),
None,
)));
}

#[test]
Expand Down
15 changes: 8 additions & 7 deletions vortex-dtype/src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub enum DType {
Binary(Nullability),
Struct(StructDType, Nullability),
List(Arc<DType>, Nullability),
Extension(ExtDType, Nullability),
Extension(ExtDType),
}

impl DType {
Expand All @@ -53,7 +53,7 @@ impl DType {
Binary(n) => matches!(n, Nullable),
Struct(_, n) => matches!(n, Nullable),
List(_, n) => matches!(n, Nullable),
Extension(_, n) => matches!(n, Nullable),
Extension(ext_dtype) => ext_dtype.scalars_dtype().is_nullable(),
}
}

Expand All @@ -74,7 +74,7 @@ impl DType {
Binary(_) => Binary(nullability),
Struct(st, _) => Struct(st.clone(), nullability),
List(c, _) => List(c.clone(), nullability),
Extension(ext, _) => Extension(ext.clone(), nullability),
Extension(ext) => Extension(ext.with_scalars_nullability(nullability)),
}
}

Expand Down Expand Up @@ -133,15 +133,16 @@ impl Display for DType {
n
),
List(c, n) => write!(f, "list({}){}", c, n),
Extension(ext, n) => write!(
Extension(ext) => write!(
f,
"ext({}, {}{}){}",
ext.id(),
ext.scalars_dtype(),
ext.scalars_dtype()
.with_nullability(Nullability::NonNullable),
ext.metadata()
.map(|m| format!(", {:?}", m))
.unwrap_or_else(|| "".to_string()),
n
ext.scalars_dtype().nullability(),
),
}
}
Expand Down Expand Up @@ -208,7 +209,7 @@ mod test {

#[test]
fn size_of() {
assert_eq!(mem::size_of::<DType>(), 40);
assert_eq!(mem::size_of::<DType>(), 48);
}

#[test]
Expand Down
10 changes: 9 additions & 1 deletion vortex-dtype/src/extension.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::fmt::{Display, Formatter};
use std::sync::Arc;

use crate::DType;
use crate::{DType, Nullability};

#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
Expand Down Expand Up @@ -110,6 +110,14 @@ impl ExtDType {
self.scalars_dtype.as_ref()
}

/// Builds a new extension dtype, via `Self::new`, from clones of this one's id and
/// metadata and a scalars dtype adjusted to the requested `nullability`.
///
/// `self` is only borrowed; the result is a separate value.
pub fn with_scalars_nullability(&self, nullability: Nullability) -> Self {
Self::new(
self.id.clone(),
// Re-derive the scalars dtype with the requested nullability and wrap it in a fresh `Arc`.
Arc::new(self.scalars_dtype.with_nullability(nullability)),
self.metadata.clone(),
)
}

#[inline]
pub fn metadata(&self) -> Option<&ExtMetadata> {
self.metadata.as_ref()
Expand Down
26 changes: 10 additions & 16 deletions vortex-dtype/src/serde/flatbuffers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,22 +86,17 @@ impl TryFrom<fb::DType<'_>> for DType {
vortex_err!("failed to parse extension id from flatbuffer")
})?);
let metadata = fb_ext.metadata().map(|m| ExtMetadata::from(m.bytes()));
Ok(Self::Extension(
ExtDType::new(
id,
Arc::new(
DType::try_from(fb_ext.scalars_dtype().ok_or_else(|| {
vortex_err!(
Ok(Self::Extension(ExtDType::new(
id,
Arc::new(
DType::try_from(fb_ext.scalars_dtype().ok_or_else(|| {
vortex_err!(
InvalidSerde: "scalars_dtype must be present on DType fbs message")
})?)
.map_err(|e| {
vortex_err!("failed to create DType from fbs message: {e}")
})?,
),
metadata,
})?)
.map_err(|e| vortex_err!("failed to create DType from fbs message: {e}"))?,
),
fb_ext.nullable().into(),
))
metadata,
)))
}
_ => Err(vortex_err!("Unknown DType variant")),
}
Expand Down Expand Up @@ -183,7 +178,7 @@ impl WriteFlatBuffer for DType {
)
.as_union_value()
}
Self::Extension(ext, n) => {
Self::Extension(ext) => {
let id = Some(fbb.create_string(ext.id().as_ref()));
let scalars_dtype = Some(ext.scalars_dtype().write_flatbuffer(fbb));
let metadata = ext.metadata().map(|m| fbb.create_vector(m.as_ref()));
Expand All @@ -193,7 +188,6 @@ impl WriteFlatBuffer for DType {
id,
scalars_dtype,
metadata,
nullable: (*n).into(),
},
)
.as_union_value()
Expand Down
4 changes: 1 addition & 3 deletions vortex-dtype/src/serde/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ impl TryFrom<&pb::DType> for DType {
).map_err(|e| vortex_err!("failed converting DType from proto message: {}", e))?),
e.metadata.as_ref().map(|m| ExtMetadata::from(m.as_ref())),
),
e.nullable.into(),
)),
}
}
Expand Down Expand Up @@ -88,11 +87,10 @@ impl From<&DType> for pb::DType {
element_type: Some(Box::new(l.as_ref().into())),
nullable: (*n).into(),
})),
DType::Extension(e, n) => DtypeType::Extension(Box::new(pb::Extension {
DType::Extension(e) => DtypeType::Extension(Box::new(pb::Extension {
id: e.id().as_ref().into(),
scalars_dtype: Some(Box::new(e.scalars_dtype().into())),
metadata: e.metadata().map(|m| m.as_ref().into()),
nullable: (*n).into(),
})),
}),
}
Expand Down
1 change: 0 additions & 1 deletion vortex-flatbuffers/flatbuffers/vortex-dtype/dtype.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ table Extension {
id: string;
scalars_dtype: DType;
metadata: [ubyte];
nullable: bool;
}

union Type {
Expand Down
Loading

0 comments on commit 44b49f6

Please sign in to comment.