diff --git a/encodings/datetime-parts/src/compute.rs b/encodings/datetime-parts/src/compute.rs index 5e57800580..81be7325ec 100644 --- a/encodings/datetime-parts/src/compute.rs +++ b/encodings/datetime-parts/src/compute.rs @@ -5,9 +5,9 @@ use vortex::compute::{slice, take, ArrayCompute, SliceFn, TakeFn}; use vortex::validity::ArrayValidity; use vortex::{Array, ArrayDType, IntoArray, IntoArrayVariant}; use vortex_datetime_dtype::{TemporalMetadata, TimeUnit}; -use vortex_dtype::{DType, PType}; +use vortex_dtype::DType; use vortex_error::{vortex_bail, VortexResult, VortexUnwrap as _}; -use vortex_scalar::Scalar; +use vortex_scalar::{Scalar, ScalarValue}; use crate::DateTimePartsArray; @@ -51,7 +51,7 @@ impl SliceFn for DateTimePartsArray { impl ScalarAtFn for DateTimePartsArray { fn scalar_at(&self, index: usize) -> VortexResult { - let DType::Extension(ext, nullability) = self.dtype().clone() else { + let DType::Extension(ext) = self.dtype().clone() else { vortex_bail!( "DateTimePartsArray must have extension dtype, found {}", self.dtype() @@ -63,10 +63,7 @@ impl ScalarAtFn for DateTimePartsArray { }; if !self.is_valid(index) { - return Ok(Scalar::extension( - ext, - Scalar::null(DType::Primitive(PType::I64, nullability)), - )); + return Ok(Scalar::extension(ext, ScalarValue::Null)); } let divisor = match time_unit { @@ -83,10 +80,7 @@ impl ScalarAtFn for DateTimePartsArray { let scalar = days * 86_400 * divisor + seconds * divisor + subseconds; - Ok(Scalar::extension( - ext, - Scalar::primitive(scalar, nullability), - )) + Ok(Scalar::extension(ext, scalar.into())) } fn scalar_at_unchecked(&self, index: usize) -> Scalar { @@ -98,7 +92,7 @@ impl ScalarAtFn for DateTimePartsArray { /// /// Enforces that the passed array is actually a [DateTimePartsArray] with proper metadata. pub fn decode_to_temporal(array: &DateTimePartsArray) -> VortexResult { - let DType::Extension(ext, _) = array.dtype().clone() else { + let DType::Extension(ext) = array.dtype().clone() else { vortex_bail!(ComputeError: "expected dtype to be DType::Extension variant") }; @@ -187,7 +181,7 @@ mod test { assert_eq!(validity, raw_millis.validity()); let date_times = DateTimePartsArray::try_new( - DType::Extension(temporal_array.ext_dtype().clone(), validity.nullability()), + DType::Extension(temporal_array.ext_dtype().clone()), days, seconds, subseconds, diff --git a/pyvortex/src/python_repr.rs b/pyvortex/src/python_repr.rs index 83e34899a2..30c646fc99 100644 --- a/pyvortex/src/python_repr.rs +++ b/pyvortex/src/python_repr.rs @@ -46,13 +46,18 @@ impl Display for DTypePythonRepr<'_> { n.python_repr() ), DType::List(c, n) => write!(f, "list({}, {})", c.python_repr(), n.python_repr()), - DType::Extension(ext, n) => { - write!(f, "ext(\"{}\", ", ext.id().python_repr())?; + DType::Extension(ext) => { + write!( + f, + "ext(\"{}\", {}, ", + ext.id().python_repr(), + ext.scalars_dtype().python_repr() + )?; match ext.metadata() { None => write!(f, "None")?, Some(metadata) => write!(f, "{}", metadata.python_repr())?, }; - write!(f, ", {})", n.python_repr()) + write!(f, ", {})", ext.scalars_dtype().nullability().python_repr()) } } } diff --git a/vortex-array/src/array/chunked/canonical.rs b/vortex-array/src/array/chunked/canonical.rs index 54fbeb2aec..7398a89957 100644 --- a/vortex-array/src/array/chunked/canonical.rs +++ b/vortex-array/src/array/chunked/canonical.rs @@ -71,7 +71,7 @@ pub(crate) fn try_canonicalize_chunks( // / \ // storage storage // - DType::Extension(ext_dtype, _) => { + DType::Extension(ext_dtype) => { // Recursively apply canonicalization and packing to the storage array backing // each chunk of the extension array. let storage_chunks: Vec = chunks diff --git a/vortex-array/src/array/constant/variants.rs b/vortex-array/src/array/constant/variants.rs index b17028446f..e2cc1f8a7c 100644 --- a/vortex-array/src/array/constant/variants.rs +++ b/vortex-array/src/array/constant/variants.rs @@ -3,8 +3,8 @@ use std::sync::Arc; use vortex_dtype::field::Field; use vortex_dtype::{DType, PType}; -use vortex_error::{vortex_panic, VortexError, VortexExpect as _, VortexResult}; -use vortex_scalar::{ExtScalar, Scalar, ScalarValue, StructScalar}; +use vortex_error::{VortexError, VortexExpect as _, VortexResult}; +use vortex_scalar::{Scalar, ScalarValue, StructScalar}; use crate::array::constant::ConstantArray; use crate::iter::{Accessor, AccessorRef}; @@ -203,22 +203,9 @@ impl ListArrayTrait for ConstantArray {} impl ExtensionArrayTrait for ConstantArray { fn storage_array(&self) -> Array { - let scalar_ext = ExtScalar::try_new(self.dtype(), self.scalar_value()) - .vortex_expect("Expected an extension scalar"); - - // FIXME(ngates): there's not enough information to get the storage array. - let n = self.dtype().nullability(); - let storage_dtype = match scalar_ext.value() { - ScalarValue::Bool(_) => DType::Binary(n), - ScalarValue::Primitive(pvalue) => DType::Primitive(pvalue.ptype(), n), - ScalarValue::Buffer(_) => DType::Binary(n), - ScalarValue::BufferString(_) => DType::Utf8(n), - ScalarValue::List(_) => vortex_panic!("List not supported"), - ScalarValue::Null => DType::Null, - }; - + let storage_dtype = self.ext_dtype().scalars_dtype().clone(); ConstantArray::new( - Scalar::new(storage_dtype, scalar_ext.value().clone()), + Scalar::new(storage_dtype, self.scalar_value().clone()), self.len(), ) .into_array() diff --git a/vortex-array/src/array/extension/compute.rs b/vortex-array/src/array/extension/compute.rs index 97c971013e..f550dc7293 100644 --- a/vortex-array/src/array/extension/compute.rs +++ b/vortex-array/src/array/extension/compute.rs @@ -60,14 +60,14 @@ impl ScalarAtFn for ExtensionArray { fn scalar_at(&self, index: usize) -> VortexResult { Ok(Scalar::extension( self.ext_dtype().clone(), - scalar_at(self.storage(), index)?, + scalar_at(self.storage(), index)?.into_value(), )) } fn scalar_at_unchecked(&self, index: usize) -> Scalar { Scalar::extension( self.ext_dtype().clone(), - scalar_at_unchecked(self.storage(), index), + scalar_at_unchecked(self.storage(), index).into_value(), ) } } diff --git a/vortex-array/src/array/extension/mod.rs b/vortex-array/src/array/extension/mod.rs index 66c6e5b0d0..cf4b736f5b 100644 --- a/vortex-array/src/array/extension/mod.rs +++ b/vortex-array/src/array/extension/mod.rs @@ -16,9 +16,7 @@ mod compute; impl_encoding!("vortex.ext", ids::EXTENSION, Extension); #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ExtensionMetadata { - storage_dtype: DType, -} +pub struct ExtensionMetadata; impl Display for ExtensionMetadata { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -28,12 +26,16 @@ impl Display for ExtensionMetadata { impl ExtensionArray { pub fn new(ext_dtype: ExtDType, storage: Array) -> Self { + assert_eq!( + ext_dtype.scalars_dtype(), + storage.dtype(), + "ExtensionArray: scalars_dtype must match storage array DType", + ); + Self::try_from_parts( - DType::Extension(ext_dtype, storage.dtype().nullability()), + DType::Extension(ext_dtype), storage.len(), - ExtensionMetadata { - storage_dtype: storage.dtype().clone(), - }, + ExtensionMetadata, [storage].into(), Default::default(), ) @@ -42,7 +44,7 @@ impl ExtensionArray { pub fn storage(&self) -> Array { self.as_ref() - .child(0, &self.metadata().storage_dtype, self.len()) + .child(0, self.ext_dtype().scalars_dtype(), self.len()) .vortex_expect("Missing storage array for ExtensionArray") } diff --git a/vortex-array/src/arrow/dtype.rs b/vortex-array/src/arrow/dtype.rs index b71554bb2b..a6c6bd1dd0 100644 --- a/vortex-array/src/arrow/dtype.rs +++ b/vortex-array/src/arrow/dtype.rs @@ -71,8 +71,7 @@ impl FromArrowType<&Field> for DType { | DataType::Time32(_) | DataType::Time64(_) | DataType::Timestamp(..) => Extension( - make_temporal_ext_dtype(field.data_type()), - field.is_nullable().into(), + make_temporal_ext_dtype(field.data_type()).with_scalars_nullability(nullability), ), DataType::List(e) | DataType::LargeList(e) => { List(Arc::new(Self::from_arrow(e.as_ref())), nullability) diff --git a/vortex-array/src/variants.rs b/vortex-array/src/variants.rs index ad01b69e65..1e4e0046f7 100644 --- a/vortex-array/src/variants.rs +++ b/vortex-array/src/variants.rs @@ -236,7 +236,7 @@ pub trait ListArrayTrait: ArrayTrait {} pub trait ExtensionArrayTrait: ArrayTrait { fn ext_dtype(&self) -> &ExtDType { - let DType::Extension(ext_dtype, _nullability) = self.dtype() else { + let DType::Extension(ext_dtype) = self.dtype() else { vortex_panic!("Expected ExtDType") }; ext_dtype diff --git a/vortex-datafusion/src/datatype.rs b/vortex-datafusion/src/datatype.rs index fc2f61aea4..9520c98016 100644 --- a/vortex-datafusion/src/datatype.rs +++ b/vortex-datafusion/src/datatype.rs @@ -90,7 +90,7 @@ pub(crate) fn infer_data_type(dtype: &DType) -> DataType { dtype.is_nullable(), ))) } - DType::Extension(ext_dtype, _) => { + DType::Extension(ext_dtype) => { // Try and match against the known extension DTypes. if is_temporal_ext_type(ext_dtype.id()) { make_arrow_temporal_dtype(ext_dtype) @@ -168,14 +168,11 @@ mod test { #[test] #[should_panic] fn test_dtype_conversion_panics() { - let _ = infer_data_type(&DType::Extension( - ExtDType::new( - ExtID::from("my-fake-ext-dtype"), - Arc::new(PType::I32.into()), - None, - ), - Nullability::NonNullable, - )); + let _ = infer_data_type(&DType::Extension(ExtDType::new( + ExtID::from("my-fake-ext-dtype"), + Arc::new(PType::I32.into()), + None, + ))); } #[test] diff --git a/vortex-dtype/src/dtype.rs b/vortex-dtype/src/dtype.rs index eda0b4ae89..9ea014767d 100644 --- a/vortex-dtype/src/dtype.rs +++ b/vortex-dtype/src/dtype.rs @@ -29,7 +29,7 @@ pub enum DType { Binary(Nullability), Struct(StructDType, Nullability), List(Arc, Nullability), - Extension(ExtDType, Nullability), + Extension(ExtDType), } impl DType { @@ -53,7 +53,7 @@ impl DType { Binary(n) => matches!(n, Nullable), Struct(_, n) => matches!(n, Nullable), List(_, n) => matches!(n, Nullable), - Extension(_, n) => matches!(n, Nullable), + Extension(ext_dtype) => ext_dtype.scalars_dtype().is_nullable(), } } @@ -74,7 +74,7 @@ impl DType { Binary(_) => Binary(nullability), Struct(st, _) => Struct(st.clone(), nullability), List(c, _) => List(c.clone(), nullability), - Extension(ext, _) => Extension(ext.clone(), nullability), + Extension(ext) => Extension(ext.with_scalars_nullability(nullability)), } } @@ -133,15 +133,16 @@ impl Display for DType { n ), List(c, n) => write!(f, "list({}){}", c, n), - Extension(ext, n) => write!( + Extension(ext) => write!( f, "ext({}, {}{}){}", ext.id(), - ext.scalars_dtype(), + ext.scalars_dtype() + .with_nullability(Nullability::NonNullable), ext.metadata() .map(|m| format!(", {:?}", m)) .unwrap_or_else(|| "".to_string()), - n + ext.scalars_dtype().nullability(), ), } } @@ -208,7 +209,7 @@ mod test { #[test] fn size_of() { - assert_eq!(mem::size_of::(), 40); + assert_eq!(mem::size_of::(), 48); } #[test] diff --git a/vortex-dtype/src/extension.rs b/vortex-dtype/src/extension.rs index b67a260f39..67a543d1c0 100644 --- a/vortex-dtype/src/extension.rs +++ b/vortex-dtype/src/extension.rs @@ -1,7 +1,7 @@ use std::fmt::{Display, Formatter}; use std::sync::Arc; -use crate::DType; +use crate::{DType, Nullability}; #[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)] #[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] @@ -110,6 +110,14 @@ impl ExtDType { self.scalars_dtype.as_ref() } + pub fn with_scalars_nullability(&self, nullability: Nullability) -> Self { + Self::new( + self.id.clone(), + Arc::new(self.scalars_dtype.with_nullability(nullability)), + self.metadata.clone(), + ) + } + #[inline] pub fn metadata(&self) -> Option<&ExtMetadata> { self.metadata.as_ref() diff --git a/vortex-dtype/src/serde/flatbuffers/mod.rs b/vortex-dtype/src/serde/flatbuffers/mod.rs index 2264cfd5f2..2006eaa6b6 100644 --- a/vortex-dtype/src/serde/flatbuffers/mod.rs +++ b/vortex-dtype/src/serde/flatbuffers/mod.rs @@ -86,22 +86,17 @@ impl TryFrom> for DType { vortex_err!("failed to parse extension id from flatbuffer") })?); let metadata = fb_ext.metadata().map(|m| ExtMetadata::from(m.bytes())); - Ok(Self::Extension( - ExtDType::new( - id, - Arc::new( - DType::try_from(fb_ext.scalars_dtype().ok_or_else(|| { - vortex_err!( + Ok(Self::Extension(ExtDType::new( + id, + Arc::new( + DType::try_from(fb_ext.scalars_dtype().ok_or_else(|| { + vortex_err!( InvalidSerde: "scalars_dtype must be present on DType fbs message") - })?) - .map_err(|e| { - vortex_err!("failed to create DType from fbs message: {e}") - })?, - ), - metadata, + })?) + .map_err(|e| vortex_err!("failed to create DType from fbs message: {e}"))?, ), - fb_ext.nullable().into(), - )) + metadata, + ))) } _ => Err(vortex_err!("Unknown DType variant")), } @@ -183,7 +178,7 @@ impl WriteFlatBuffer for DType { ) .as_union_value() } - Self::Extension(ext, n) => { + Self::Extension(ext) => { let id = Some(fbb.create_string(ext.id().as_ref())); let scalars_dtype = Some(ext.scalars_dtype().write_flatbuffer(fbb)); let metadata = ext.metadata().map(|m| fbb.create_vector(m.as_ref())); @@ -193,7 +188,6 @@ impl WriteFlatBuffer for DType { id, scalars_dtype, metadata, - nullable: (*n).into(), }, ) .as_union_value() diff --git a/vortex-dtype/src/serde/proto.rs b/vortex-dtype/src/serde/proto.rs index f062f800f5..1e5879dd64 100644 --- a/vortex-dtype/src/serde/proto.rs +++ b/vortex-dtype/src/serde/proto.rs @@ -55,7 +55,6 @@ impl TryFrom<&pb::DType> for DType { ).map_err(|e| vortex_err!("failed converting DType from proto message: {}", e))?), e.metadata.as_ref().map(|m| ExtMetadata::from(m.as_ref())), ), - e.nullable.into(), )), } } @@ -88,11 +87,10 @@ impl From<&DType> for pb::DType { element_type: Some(Box::new(l.as_ref().into())), nullable: (*n).into(), })), - DType::Extension(e, n) => DtypeType::Extension(Box::new(pb::Extension { + DType::Extension(e) => DtypeType::Extension(Box::new(pb::Extension { id: e.id().as_ref().into(), scalars_dtype: Some(Box::new(e.scalars_dtype().into())), metadata: e.metadata().map(|m| m.as_ref().into()), - nullable: (*n).into(), })), }), } diff --git a/vortex-flatbuffers/flatbuffers/vortex-dtype/dtype.fbs b/vortex-flatbuffers/flatbuffers/vortex-dtype/dtype.fbs index b8d618e8e1..6904727f01 100644 --- a/vortex-flatbuffers/flatbuffers/vortex-dtype/dtype.fbs +++ b/vortex-flatbuffers/flatbuffers/vortex-dtype/dtype.fbs @@ -54,7 +54,6 @@ table Extension { id: string; scalars_dtype: DType; metadata: [ubyte]; - nullable: bool; } union Type { diff --git a/vortex-flatbuffers/src/generated/dtype.rs b/vortex-flatbuffers/src/generated/dtype.rs index 264e0c09b9..e9af36b397 100644 --- a/vortex-flatbuffers/src/generated/dtype.rs +++ b/vortex-flatbuffers/src/generated/dtype.rs @@ -1130,7 +1130,6 @@ impl<'a> Extension<'a> { pub const VT_ID: flatbuffers::VOffsetT = 4; pub const VT_SCALARS_DTYPE: flatbuffers::VOffsetT = 6; pub const VT_METADATA: flatbuffers::VOffsetT = 8; - pub const VT_NULLABLE: flatbuffers::VOffsetT = 10; #[inline] pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self { @@ -1145,7 +1144,6 @@ impl<'a> Extension<'a> { if let Some(x) = args.metadata { builder.add_metadata(x); } if let Some(x) = args.scalars_dtype { builder.add_scalars_dtype(x); } if let Some(x) = args.id { builder.add_id(x); } - builder.add_nullable(args.nullable); builder.finish() } @@ -1171,13 +1169,6 @@ impl<'a> Extension<'a> { // which contains a valid value in this slot unsafe { self._tab.get::>>(Extension::VT_METADATA, None)} } - #[inline] - pub fn nullable(&self) -> bool { - // Safety: - // Created from valid Table for this object - // which contains a valid value in this slot - unsafe { self._tab.get::(Extension::VT_NULLABLE, Some(false)).unwrap()} - } } impl flatbuffers::Verifiable for Extension<'_> { @@ -1190,7 +1181,6 @@ impl flatbuffers::Verifiable for Extension<'_> { .visit_field::>("id", Self::VT_ID, false)? .visit_field::>("scalars_dtype", Self::VT_SCALARS_DTYPE, false)? .visit_field::>>("metadata", Self::VT_METADATA, false)? - .visit_field::("nullable", Self::VT_NULLABLE, false)? .finish(); Ok(()) } @@ -1199,7 +1189,6 @@ pub struct ExtensionArgs<'a> { pub id: Option>, pub scalars_dtype: Option>>, pub metadata: Option>>, - pub nullable: bool, } impl<'a> Default for ExtensionArgs<'a> { #[inline] @@ -1208,7 +1197,6 @@ impl<'a> Default for ExtensionArgs<'a> { id: None, scalars_dtype: None, metadata: None, - nullable: false, } } } @@ -1231,10 +1219,6 @@ impl<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> ExtensionBuilder<'a, 'b, A> { self.fbb_.push_slot_always::>(Extension::VT_METADATA, metadata); } #[inline] - pub fn add_nullable(&mut self, nullable: bool) { - self.fbb_.push_slot::(Extension::VT_NULLABLE, nullable, false); - } - #[inline] pub fn new(_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>) -> ExtensionBuilder<'a, 'b, A> { let start = _fbb.start_table(); ExtensionBuilder { @@ -1255,7 +1239,6 @@ impl core::fmt::Debug for Extension<'_> { ds.field("id", &self.id()); ds.field("scalars_dtype", &self.scalars_dtype()); ds.field("metadata", &self.metadata()); - ds.field("nullable", &self.nullable()); ds.finish() } } diff --git a/vortex-proto/proto/dtype.proto b/vortex-proto/proto/dtype.proto index eaa71f14a8..259e881fb2 100644 --- a/vortex-proto/proto/dtype.proto +++ b/vortex-proto/proto/dtype.proto @@ -56,7 +56,6 @@ message Extension { string id = 1; DType scalars_dtype = 2; optional bytes metadata = 3; - bool nullable = 4; } message DType { diff --git a/vortex-proto/src/generated/vortex.dtype.rs b/vortex-proto/src/generated/vortex.dtype.rs index cadc67536e..c961e6e4a0 100644 --- a/vortex-proto/src/generated/vortex.dtype.rs +++ b/vortex-proto/src/generated/vortex.dtype.rs @@ -56,8 +56,6 @@ pub struct Extension { pub scalars_dtype: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(bytes = "vec", optional, tag = "3")] pub metadata: ::core::option::Option<::prost::alloc::vec::Vec>, - #[prost(bool, tag = "4")] - pub nullable: bool, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct DType { diff --git a/vortex-scalar/src/arrow.rs b/vortex-scalar/src/arrow.rs index 32a3a46773..0b904f8ba9 100644 --- a/vortex-scalar/src/arrow.rs +++ b/vortex-scalar/src/arrow.rs @@ -66,7 +66,7 @@ impl TryFrom<&Scalar> for Arc { DType::List(..) => { todo!("list scalar conversion") } - DType::Extension(ext, _) => { + DType::Extension(ext) => { if is_temporal_ext_type(ext.id()) { let metadata = TemporalMetadata::try_from(ext)?; let pv = value.value.as_pvalue()?; diff --git a/vortex-scalar/src/datafusion.rs b/vortex-scalar/src/datafusion.rs index c5edb6ac43..f6f258fb1b 100644 --- a/vortex-scalar/src/datafusion.rs +++ b/vortex-scalar/src/datafusion.rs @@ -64,7 +64,7 @@ impl TryFrom for ScalarValue { DType::List(..) => { todo!("list scalar conversion") } - DType::Extension(ext, _) => { + DType::Extension(ext) => { if is_temporal_ext_type(ext.id()) { let metadata = TemporalMetadata::try_from(&ext)?; let pv = value.value.as_pvalue()?; @@ -147,9 +147,10 @@ impl From for Scalar { ScalarValue::Date32(v) | ScalarValue::Time32Second(v) | ScalarValue::Time32Millisecond(v) => v.map(|i| { - let ext_dtype = make_temporal_ext_dtype(&value.data_type()); + let ext_dtype = make_temporal_ext_dtype(&value.data_type()) + .with_scalars_nullability(Nullability::Nullable); Scalar::new( - DType::Extension(ext_dtype, Nullability::Nullable), + DType::Extension(ext_dtype), crate::ScalarValue::Primitive(PValue::I32(i)), ) }), @@ -162,7 +163,7 @@ impl From for Scalar { | ScalarValue::TimestampNanosecond(v, _) => v.map(|i| { let ext_dtype = make_temporal_ext_dtype(&value.data_type()); Scalar::new( - DType::Extension(ext_dtype, Nullability::Nullable), + DType::Extension(ext_dtype.with_scalars_nullability(Nullability::Nullable)), crate::ScalarValue::Primitive(PValue::I64(i)), ) }), diff --git a/vortex-scalar/src/display.rs b/vortex-scalar/src/display.rs index 31922a90f3..da732a1804 100644 --- a/vortex-scalar/src/display.rs +++ b/vortex-scalar/src/display.rs @@ -60,7 +60,7 @@ impl Display for Scalar { } DType::List(..) => todo!(), // Specialized handling for date/time/timestamp builtin extension types. - DType::Extension(dtype, _) if is_temporal_ext_type(dtype.id()) => { + DType::Extension(dtype) if is_temporal_ext_type(dtype.id()) => { let metadata = TemporalMetadata::try_from(dtype).map_err(|_| std::fmt::Error)?; match ExtScalar::try_from(self) .map_err(|_| std::fmt::Error)? @@ -246,14 +246,11 @@ mod tests { #[test] fn display_time() { fn dtype() -> DType { - DType::Extension( - ExtDType::new( - TIME_ID.clone(), - Arc::new(PType::I32.into()), - Some(ExtMetadata::from(TemporalMetadata::Time(TimeUnit::S))), - ), - Nullable, - ) + DType::Extension(ExtDType::new( + TIME_ID.clone(), + Arc::new(DType::Primitive(PType::I32, Nullable)), + Some(ExtMetadata::from(TemporalMetadata::Time(TimeUnit::S))), + )) } assert_eq!(format!("{}", Scalar::null(dtype())), "null"); @@ -273,14 +270,11 @@ mod tests { #[test] fn display_date() { fn dtype() -> DType { - DType::Extension( - ExtDType::new( - DATE_ID.clone(), - Arc::new(PType::I32.into()), - Some(ExtMetadata::from(TemporalMetadata::Date(TimeUnit::D))), - ), - Nullable, - ) + DType::Extension(ExtDType::new( + DATE_ID.clone(), + Arc::new(DType::Primitive(PType::I32, Nullable)), + Some(ExtMetadata::from(TemporalMetadata::Date(TimeUnit::D))), + )) } assert_eq!(format!("{}", Scalar::null(dtype())), "null"); @@ -313,17 +307,14 @@ mod tests { #[test] fn display_local_timestamp() { fn dtype() -> DType { - DType::Extension( - ExtDType::new( - TIMESTAMP_ID.clone(), - Arc::new(PType::I32.into()), - Some(ExtMetadata::from(TemporalMetadata::Timestamp( - TimeUnit::S, - None, - ))), - ), - Nullable, - ) + DType::Extension(ExtDType::new( + TIMESTAMP_ID.clone(), + Arc::new(DType::Primitive(PType::I32, Nullable)), + Some(ExtMetadata::from(TemporalMetadata::Timestamp( + TimeUnit::S, + None, + ))), + )) } assert_eq!(format!("{}", Scalar::null(dtype())), "null"); @@ -344,17 +335,14 @@ mod tests { #[test] fn display_zoned_timestamp() { fn dtype() -> DType { - DType::Extension( - ExtDType::new( - TIMESTAMP_ID.clone(), - Arc::new(PType::I64.into()), - Some(ExtMetadata::from(TemporalMetadata::Timestamp( - TimeUnit::S, - Some(String::from("Pacific/Guam")), - ))), - ), - Nullable, - ) + DType::Extension(ExtDType::new( + TIMESTAMP_ID.clone(), + Arc::new(DType::Primitive(PType::I64, Nullable)), + Some(ExtMetadata::from(TemporalMetadata::Timestamp( + TimeUnit::S, + Some(String::from("Pacific/Guam")), + ))), + )) } assert_eq!(format!("{}", Scalar::null(dtype())), "null"); diff --git a/vortex-scalar/src/extension.rs b/vortex-scalar/src/extension.rs index e345693d3f..90009ff249 100644 --- a/vortex-scalar/src/extension.rs +++ b/vortex-scalar/src/extension.rs @@ -44,10 +44,11 @@ impl<'a> TryFrom<&'a Scalar> for ExtScalar<'a> { } impl Scalar { - pub fn extension(ext_dtype: ExtDType, storage: Self) -> Self { + pub fn extension(ext_dtype: ExtDType, value: ScalarValue) -> Self { + // Ensure that the ext_dtype is compatible with our scalar value type instead. Self { - dtype: DType::Extension(ext_dtype, storage.dtype().nullability()), - value: storage.value, + dtype: DType::Extension(ext_dtype), + value, } } }