Skip to content

Commit

Permalink
add scalars_dtype to ExtDType
Browse files Browse the repository at this point in the history
  • Loading branch information
a10y committed Oct 9, 2024
1 parent f15b162 commit c6e7add
Show file tree
Hide file tree
Showing 19 changed files with 178 additions and 67 deletions.
32 changes: 19 additions & 13 deletions vortex-array/src/array/datetime/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#[cfg(test)]
mod test;

use std::sync::Arc;

use vortex_datetime_dtype::{TemporalMetadata, TimeUnit, DATE_ID, TIMESTAMP_ID, TIME_ID};
use vortex_dtype::{DType, ExtDType};
use vortex_error::{vortex_panic, VortexError};
Expand Down Expand Up @@ -68,26 +70,22 @@ impl TemporalArray {
///
/// If any other time unit is provided, it panics.
pub fn new_date(array: Array, time_unit: TimeUnit) -> Self {
let ext_dtype = match time_unit {
match time_unit {
TimeUnit::D => {
assert_width!(i32, array);

ExtDType::new(
DATE_ID.clone(),
Some(TemporalMetadata::Date(time_unit).into()),
)
}
TimeUnit::Ms => {
assert_width!(i64, array);

ExtDType::new(
DATE_ID.clone(),
Some(TemporalMetadata::Date(time_unit).into()),
)
}
_ => vortex_panic!("invalid TimeUnit {time_unit} for vortex.date"),
};

let ext_dtype = ExtDType::new(
DATE_ID.clone(),
Arc::new(array.dtype().clone()),
Some(TemporalMetadata::Date(time_unit).into()),
);

Self {
ext: ExtensionArray::new(ext_dtype, array),
temporal_metadata: TemporalMetadata::Date(time_unit),
Expand Down Expand Up @@ -123,7 +121,11 @@ impl TemporalArray {
let temporal_metadata = TemporalMetadata::Time(time_unit);
Self {
ext: ExtensionArray::new(
ExtDType::new(TIME_ID.clone(), Some(temporal_metadata.clone().into())),
ExtDType::new(
TIME_ID.clone(),
Arc::new(array.dtype().clone()),
Some(temporal_metadata.clone().into()),
),
array,
),
temporal_metadata,
Expand All @@ -145,7 +147,11 @@ impl TemporalArray {

Self {
ext: ExtensionArray::new(
ExtDType::new(TIMESTAMP_ID.clone(), Some(temporal_metadata.clone().into())),
ExtDType::new(
TIMESTAMP_ID.clone(),
Arc::new(array.dtype().clone()),
Some(temporal_metadata.clone().into()),
),
array,
),
temporal_metadata,
Expand Down
8 changes: 7 additions & 1 deletion vortex-datafusion/src/datatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ pub(crate) fn infer_data_type(dtype: &DType) -> DataType {
if is_temporal_ext_type(ext_dtype.id()) {
make_arrow_temporal_dtype(ext_dtype)
} else {
// TODO(aduffy): allow extension type authors to plug in their own to/from Arrow
// conversions.
vortex_panic!("Unsupported extension type \"{}\"", ext_dtype.id())
}
}
Expand Down Expand Up @@ -167,7 +169,11 @@ mod test {
#[should_panic]
fn test_dtype_conversion_panics() {
let _ = infer_data_type(&DType::Extension(
ExtDType::new(ExtID::from("my-fake-ext-dtype"), None),
ExtDType::new(
ExtID::from("my-fake-ext-dtype"),
Arc::new(PType::I32.into()),
None,
),
Nullability::NonNullable,
));
}
Expand Down
11 changes: 9 additions & 2 deletions vortex-datetime-dtype/src/arrow.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#![cfg(feature = "arrow")]

use std::sync::Arc;

use arrow_schema::{DataType, TimeUnit as ArrowTimeUnit};
use vortex_dtype::ExtDType;
use vortex_dtype::{ExtDType, PType};
use vortex_error::{vortex_bail, vortex_panic, VortexError, VortexExpect as _, VortexResult};

use crate::temporal::{TemporalMetadata, DATE_ID, TIMESTAMP_ID, TIME_ID};
Expand All @@ -17,32 +19,37 @@ pub fn make_temporal_ext_dtype(data_type: &DataType) -> ExtDType {
DataType::Timestamp(time_unit, time_zone) => {
let time_unit = TimeUnit::from(time_unit);
let tz = time_zone.clone().map(|s| s.to_string());

// PType is inferred for arrow based on the time units.
ExtDType::new(
TIMESTAMP_ID.clone(),
Arc::new(PType::I64.into()),
Some(TemporalMetadata::Timestamp(time_unit, tz).into()),
)
}
DataType::Time32(time_unit) => {
let time_unit = TimeUnit::from(time_unit);
ExtDType::new(
TIME_ID.clone(),
Arc::new(PType::I32.into()),
Some(TemporalMetadata::Time(time_unit).into()),
)
}
DataType::Time64(time_unit) => {
let time_unit = TimeUnit::from(time_unit);
ExtDType::new(
TIME_ID.clone(),
Arc::new(PType::I64.into()),
Some(TemporalMetadata::Time(time_unit).into()),
)
}
DataType::Date32 => ExtDType::new(
DATE_ID.clone(),
Arc::new(PType::I32.into()),
Some(TemporalMetadata::Date(TimeUnit::D).into()),
),
DataType::Date64 => ExtDType::new(
DATE_ID.clone(),
Arc::new(PType::I64.into()),
Some(TemporalMetadata::Date(TimeUnit::Ms).into()),
),
_ => unimplemented!("{data_type} conversion"),
Expand Down
12 changes: 9 additions & 3 deletions vortex-datetime-dtype/src/temporal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ impl From<TemporalMetadata> for ExtMetadata {

#[cfg(test)]
mod tests {
use vortex_dtype::{ExtDType, ExtMetadata};
use std::sync::Arc;

use vortex_dtype::{ExtDType, ExtMetadata, PType};

use crate::{TemporalMetadata, TimeUnit, TIMESTAMP_ID};

Expand All @@ -207,8 +209,12 @@ mod tests {
.as_slice()
);

let temporal_metadata =
TemporalMetadata::try_from(&ExtDType::new(TIMESTAMP_ID.clone(), Some(meta))).unwrap();
let temporal_metadata = TemporalMetadata::try_from(&ExtDType::new(
TIMESTAMP_ID.clone(),
Arc::new(PType::I64.into()),
Some(meta),
))
.unwrap();

assert_eq!(
temporal_metadata,
Expand Down
3 changes: 2 additions & 1 deletion vortex-dtype/src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,9 @@ impl Display for DType {
List(c, n) => write!(f, "list({}){}", c, n),
Extension(ext, n) => write!(
f,
"ext({}{}){}",
"ext({}, {}{}){}",
ext.id(),
ext.scalars_dtype(),
ext.metadata()
.map(|m| format!(", {:?}", m))
.unwrap_or_else(|| "".to_string()),
Expand Down
46 changes: 44 additions & 2 deletions vortex-dtype/src/extension.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::fmt::{Display, Formatter};
use std::sync::Arc;

use crate::DType;

/// Identifier of an extension type, stored as a reference-counted string (`Arc<str>`).
#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
pub struct ExtID(Arc<str>);
Expand Down Expand Up @@ -55,19 +57,59 @@ impl From<&[u8]> for ExtMetadata {
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ExtDType {
/// Identifier naming the extension type.
id: ExtID,
/// `DType` of the scalar values that the extension type adds meaning on top of.
scalars_dtype: Arc<DType>,
/// Optional metadata, allowing the extension type to be parameterized.
metadata: Option<ExtMetadata>,
}

impl ExtDType {
pub fn new(id: ExtID, metadata: Option<ExtMetadata>) -> Self {
Self { id, metadata }
/// Creates a new `ExtDType`.
///
/// Extension data types in Vortex allow library users to express additional semantic meaning
/// on top of a set of scalar values. Metadata can optionally be provided for the extension type
/// to allow for parameterized types.
///
/// The arguments are stored as given: `id` identifies the extension type, `scalars_dtype` is
/// the `DType` of the underlying scalar values, and `metadata` holds the optional parameters.
///
/// A simple example would be if one wanted to create a `vortex.temperature` extension type. The
/// canonical encoding for such values would be `f64`, and the metadata can contain an optional
/// temperature unit, allowing downstream users to be sure they properly account for Celsius
/// and Fahrenheit conversions.
///
/// ```
/// use std::sync::Arc;
/// use vortex_dtype::{DType, ExtDType, ExtID, ExtMetadata, Nullability, PType};
///
/// #[repr(u8)]
/// enum TemperatureUnit {
///     C = 0u8,
///     F = 1u8,
/// }
///
/// // Make a new extension type that encodes the unit for a set of nullable `f64`.
/// pub fn create_temperature_type(unit: TemperatureUnit) -> ExtDType {
///     ExtDType::new(
///         ExtID::new("vortex.temperature".into()),
///         Arc::new(DType::Primitive(PType::F64, Nullability::Nullable)),
///         Some(ExtMetadata::new([unit as u8].into()))
///     )
/// }
/// ```
pub fn new(id: ExtID, scalars_dtype: Arc<DType>, metadata: Option<ExtMetadata>) -> Self {
Self {
id,
scalars_dtype,
metadata,
}
}

/// Returns a reference to the identifier of this extension type.
#[inline]
pub fn id(&self) -> &ExtID {
&self.id
}

/// Returns the `DType` of the scalar values underlying this extension type.
#[inline]
pub fn scalars_dtype(&self) -> &DType {
    // `&Arc<DType>` deref-coerces to `&DType` at the return site.
    &self.scalars_dtype
}

#[inline]
pub fn metadata(&self) -> Option<&ExtMetadata> {
self.metadata.as_ref()
Expand Down
16 changes: 15 additions & 1 deletion vortex-dtype/src/serde/flatbuffers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,19 @@ impl TryFrom<fb::DType<'_>> for DType {
})?);
let metadata = fb_ext.metadata().map(|m| ExtMetadata::from(m.bytes()));
Ok(Self::Extension(
ExtDType::new(id, metadata),
ExtDType::new(
id,
Arc::new(
DType::try_from(fb_ext.scalars_dtype().ok_or_else(|| {
vortex_err!(
InvalidSerde: "scalars_dtype must be present on DType fbs message")
})?)
.map_err(|e| {
vortex_err!("failed to create DType from fbs message: {e}")
})?,
),
metadata,
),
fb_ext.nullable().into(),
))
}
Expand Down Expand Up @@ -173,11 +185,13 @@ impl WriteFlatBuffer for DType {
}
Self::Extension(ext, n) => {
let id = Some(fbb.create_string(ext.id().as_ref()));
let scalars_dtype = Some(ext.scalars_dtype().write_flatbuffer(fbb));
let metadata = ext.metadata().map(|m| fbb.create_vector(m.as_ref()));
fb::Extension::create(
fbb,
&fb::ExtensionArgs {
id,
scalars_dtype,
metadata,
nullable: (*n).into(),
},
Expand Down
10 changes: 8 additions & 2 deletions vortex-dtype/src/serde/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ impl TryFrom<&pb::DType> for DType {
DtypeType::Extension(e) => Ok(Self::Extension(
ExtDType::new(
ExtID::from(e.id.as_str()),
Arc::new(DType::try_from(e.scalars_dtype
.as_ref()
.ok_or_else(|| vortex_err!(InvalidSerde: "scalars_dtype must be provided in DType proto message"))?
.as_ref(),
).map_err(|e| vortex_err!("failed converting DType from proto message: {}", e))?),
e.metadata.as_ref().map(|m| ExtMetadata::from(m.as_ref())),
),
e.nullable.into(),
Expand Down Expand Up @@ -83,11 +88,12 @@ impl From<&DType> for pb::DType {
element_type: Some(Box::new(l.as_ref().into())),
nullable: (*n).into(),
})),
DType::Extension(e, n) => DtypeType::Extension(pb::Extension {
DType::Extension(e, n) => DtypeType::Extension(Box::new(pb::Extension {
id: e.id().as_ref().into(),
scalars_dtype: Some(Box::new(e.scalars_dtype().into())),
metadata: e.metadata().map(|m| m.as_ref().into()),
nullable: (*n).into(),
}),
})),
}),
}
}
Expand Down
1 change: 1 addition & 0 deletions vortex-flatbuffers/flatbuffers/vortex-dtype/dtype.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ table List {

// Serialized form of an extension DType.
table Extension {
// Identifier of the extension type.
id: string;
// DType of the scalar values underlying the extension type.
// NOTE(review): FlatBuffers derives field ids from declaration order unless explicit
// `id:` attributes are used -- confirm that declaring this field before `metadata` is
// compatible with any previously written data.
scalars_dtype: DType;
// Optional extension metadata bytes.
metadata: [ubyte];
// Whether the extension DType is nullable.
nullable: bool;
}
Expand Down
2 changes: 1 addition & 1 deletion vortex-flatbuffers/src/generated/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

// @generated

use crate::scalar::*;
use crate::dtype::*;
use crate::scalar::*;
use core::mem;
use core::cmp::Ordering;

Expand Down
Loading

0 comments on commit c6e7add

Please sign in to comment.