diff --git a/vortex-buffer/src/string.rs b/vortex-buffer/src/string.rs index 4e6ae2210b..de3c7748c3 100644 --- a/vortex-buffer/src/string.rs +++ b/vortex-buffer/src/string.rs @@ -66,3 +66,9 @@ impl AsRef for BufferString { self.as_str() } } + +impl AsRef<[u8]> for BufferString { + fn as_ref(&self) -> &[u8] { + self.as_str().as_bytes() + } +} diff --git a/vortex-scalar/src/binary.rs b/vortex-scalar/src/binary.rs index 10a6ee2f74..66654f4169 100644 --- a/vortex-scalar/src/binary.rs +++ b/vortex-scalar/src/binary.rs @@ -57,3 +57,36 @@ impl<'a> TryFrom<&'a Scalar> for Buffer { .ok_or_else(|| vortex_err!("Can't extract present value from null scalar")) } } + +impl TryFrom<&ScalarValue> for Buffer { + type Error = VortexError; + + fn try_from(value: &ScalarValue) -> Result { + Option::::try_from(value)? + .ok_or_else(|| vortex_err!("Can't extract present value from null scalar")) + } +} + +impl TryFrom for Buffer { + type Error = VortexError; + + fn try_from(value: ScalarValue) -> Result { + Buffer::try_from(&value) + } +} + +impl TryFrom<&ScalarValue> for Option { + type Error = VortexError; + + fn try_from(value: &ScalarValue) -> Result { + value.as_buffer() + } +} + +impl TryFrom for Option { + type Error = VortexError; + + fn try_from(value: ScalarValue) -> Result { + Option::::try_from(&value) + } +} diff --git a/vortex-scalar/src/bool.rs b/vortex-scalar/src/bool.rs index 263a8ec872..7450f11f56 100644 --- a/vortex-scalar/src/bool.rs +++ b/vortex-scalar/src/bool.rs @@ -73,16 +73,39 @@ impl From for Scalar { } } +impl TryFrom<&ScalarValue> for Option { + type Error = VortexError; + + fn try_from(value: &ScalarValue) -> VortexResult { + value.as_bool() + } +} + +impl TryFrom for Option { + type Error = VortexError; + + fn try_from(value: ScalarValue) -> VortexResult { + Option::::try_from(&value) + } +} + impl TryFrom<&ScalarValue> for bool { type Error = VortexError; fn try_from(value: &ScalarValue) -> VortexResult { - value - .as_bool()? + Option::::try_from(value)? .ok_or_else(|| vortex_err!("Can't extract present value from null scalar")) } } +impl TryFrom for bool { + type Error = VortexError; + + fn try_from(value: ScalarValue) -> VortexResult { + bool::try_from(&value) + } +} + impl From for ScalarValue { fn from(value: bool) -> Self { ScalarValue::Bool(value) diff --git a/vortex-scalar/src/primitive.rs b/vortex-scalar/src/primitive.rs index b7add0d85e..fedd92d7bc 100644 --- a/vortex-scalar/src/primitive.rs +++ b/vortex-scalar/src/primitive.rs @@ -154,19 +154,37 @@ macro_rules! primitive_scalar { impl TryFrom<&ScalarValue> for $T { type Error = VortexError; + fn try_from(value: &ScalarValue) -> Result { + Option::<$T>::try_from(value)? + .ok_or_else(|| vortex_err!("Can't extract present value from null scalar")) + } + } + + impl TryFrom for $T { + type Error = VortexError; + + fn try_from(value: ScalarValue) -> Result { + <$T>::try_from(&value) + } + } + + impl TryFrom<&ScalarValue> for Option<$T> { + type Error = VortexError; + fn try_from(value: &ScalarValue) -> Result { match value { - ScalarValue::Primitive(pvalue) => <$T>::try_from(*pvalue), + ScalarValue::Null => Ok(None), + ScalarValue::Primitive(pvalue) => Ok(Some(<$T>::try_from(*pvalue)?)), _ => vortex_bail!("expected primitive"), } } } - impl TryFrom for $T { + impl TryFrom for Option<$T> { type Error = VortexError; fn try_from(value: ScalarValue) -> Result { - <$T>::try_from(&value) + Option::<$T>::try_from(&value) } } }; diff --git a/vortex-scalar/src/utf8.rs b/vortex-scalar/src/utf8.rs index cf74c43235..4a99a46b31 100644 --- a/vortex-scalar/src/utf8.rs +++ b/vortex-scalar/src/utf8.rs @@ -88,3 +88,36 @@ impl From<&str> for Scalar { } } } + +impl TryFrom<&ScalarValue> for BufferString { + type Error = VortexError; + + fn try_from(value: &ScalarValue) -> Result { + Option::::try_from(value)? + .ok_or_else(|| vortex_err!("Can't extract present value from null scalar")) + } +} + +impl TryFrom for BufferString { + type Error = VortexError; + + fn try_from(value: ScalarValue) -> Result { + BufferString::try_from(&value) + } +} + +impl TryFrom<&ScalarValue> for Option { + type Error = VortexError; + + fn try_from(value: &ScalarValue) -> Result { + value.as_buffer_string() + } +} + +impl TryFrom for Option { + type Error = VortexError; + + fn try_from(value: ScalarValue) -> Result { + Option::::try_from(&value) + } +} diff --git a/vortex-scalar/src/value.rs b/vortex-scalar/src/value.rs index 89cbcfbb84..364048bbe6 100644 --- a/vortex-scalar/src/value.rs +++ b/vortex-scalar/src/value.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use half::f16; use vortex_buffer::{Buffer, BufferString}; use vortex_dtype::DType; -use vortex_error::{vortex_err, VortexResult}; +use vortex_error::{vortex_err, VortexError, VortexResult}; use crate::pvalue::PValue; @@ -94,6 +94,13 @@ impl ScalarValue { } } + pub fn as_null(&self) -> VortexResult> { + match self { + Self::Null => Ok(None), + _ => Err(vortex_err!("Expected a Null scalar, found {:?}", self)), + } + } + pub fn as_bool(&self) -> VortexResult> { match self { Self::Null => Ok(None), @@ -211,6 +218,24 @@ from_vec_for_scalar_value!(BufferString); from_vec_for_scalar_value!(bytes::Bytes); from_vec_for_scalar_value!(Buffer); +pub enum Empty {} + +impl TryFrom<&ScalarValue> for Option { + type Error = VortexError; + + fn try_from(value: &ScalarValue) -> Result { + value.as_null() + } +} + +impl TryFrom for Option { + type Error = VortexError; + + fn try_from(value: ScalarValue) -> Result { + Option::::try_from(&value) + } +} + #[cfg(test)] mod test { use vortex_dtype::{DType, Nullability, PType, StructDType}; diff --git a/vortex-serde/src/layouts/write/writer.rs b/vortex-serde/src/layouts/write/writer.rs index a6befcd7cb..9c9c0a930b 100644 --- a/vortex-serde/src/layouts/write/writer.rs +++ b/vortex-serde/src/layouts/write/writer.rs @@ -4,21 +4,24 @@ use std::{io, mem}; use flatbuffers::FlatBufferBuilder; use futures::{Stream, TryStreamExt}; -use vortex::array::{ChunkedArray, StructArray}; -use vortex::stats::{ArrayStatistics, Stat}; +use vortex::array::{ + BoolArray, ChunkedArray, NullArray, PrimitiveArray, StructArray, VarBinViewArray, +}; +use vortex::stats::{ArrayStatistics, Stat, Statistics}; use vortex::stream::ArrayStream; use vortex::validity::Validity; use vortex::{Array, ArrayDType as _, IntoArray}; use vortex_buffer::io_buf::IoBuf; -use vortex_dtype::{DType, Nullability, PType}; -use vortex_error::{vortex_bail, vortex_err, vortex_panic, VortexExpect as _, VortexResult}; +use vortex_buffer::{Buffer, BufferString}; +use vortex_dtype::{match_each_native_ptype, DType, Nullability}; +use vortex_error::{vortex_bail, vortex_err, VortexError, VortexExpect as _, VortexResult}; use vortex_flatbuffers::WriteFlatBuffer; -use vortex_scalar::{Scalar, ScalarValue}; +use vortex_scalar::{Empty, ScalarValue}; use crate::io::VortexWrite; use crate::layouts::write::footer::{Footer, Postscript}; use crate::layouts::write::layouts::Layout; -use crate::layouts::{EOF_SIZE, MAGIC_BYTES, METADATA_FIELD_NAMES, PRUNING_STATS, VERSION}; +use crate::layouts::{EOF_SIZE, MAGIC_BYTES, METADATA_FIELD_NAMES, VERSION}; use crate::stream_writer::ByteRange; use crate::MessageWriter; @@ -27,7 +30,7 @@ pub struct LayoutWriter { row_count: u64, dtype: Option, - column_chunks: Vec, + column_chunks: Vec>, } impl LayoutWriter { @@ -92,7 +95,7 @@ impl LayoutWriter { let accumulator = match self.column_chunks.get_mut(column_idx) { None => { self.column_chunks - .push(ColumnChunkAccumulator::new(size_hint, stream.dtype())); + .push(new_column_chunk_accumulator(size_hint, stream.dtype())); assert_eq!( self.column_chunks.len(), @@ -108,21 +111,13 @@ impl LayoutWriter { } Some(x) => x, }; - let mut n_rows_written = *accumulator - .row_offsets - .last() - .vortex_expect("row offsets cannot be empty by construction"); let mut byte_offsets = Vec::with_capacity(size_hint); byte_offsets.push(self.msgs.tell()); while let Some(chunk) = stream.try_next().await? { - for stat in PRUNING_STATS { - accumulator.push_stat(stat, chunk.statistics().compute(stat))?; - } - - n_rows_written += chunk.len() as u64; - accumulator.push_row_offset(n_rows_written); + accumulator.push_statistics(chunk.statistics())?; + accumulator.push_rows_written(chunk.len() as u64); self.msgs.write_batch(chunk).await?; byte_offsets.push(self.msgs.tell()); @@ -195,81 +190,136 @@ async fn write_fb_raw(mut writer: W, fb: F) Ok(writer) } -struct ColumnChunkAccumulator { - pub dtype: DType, +fn utf8_array_from_vec(x: Vec>) -> Array { + VarBinViewArray::from_iter(x, DType::Utf8(Nullability::Nullable)).into_array() +} + +fn binary_array_from_vec(x: Vec>) -> Array { + VarBinViewArray::from_iter(x, DType::Binary(Nullability::Nullable)).into_array() +} + +fn new_column_chunk_accumulator( + size_hint: usize, + dtype: &DType, +) -> Box { + match dtype { + // The nullability of the array is irrelevant because the array could be empty which has + // a Null minima and maxima. + DType::Bool(_) => Box::new(TypedColumnChunkAccumulator::::new(size_hint, |x| { + BoolArray::from_iter(x).into_array() + })), + DType::Null => Box::new(TypedColumnChunkAccumulator::::new(size_hint, |x| { + NullArray::new(x.len()).into_array() + })), + DType::Primitive(ptype, _nullability) => { + match_each_native_ptype!(ptype, |$P| { + Box::new(TypedColumnChunkAccumulator::<$P>::new(size_hint, |x| { + PrimitiveArray::from_nullable_vec(x).into_array() + })) + }) + } + DType::Utf8(_nullability) => Box::new(TypedColumnChunkAccumulator::::new( + size_hint, + utf8_array_from_vec, + )), + DType::Binary(_nullability) => Box::new(TypedColumnChunkAccumulator::::new( + size_hint, + binary_array_from_vec, + )), + DType::Struct(_struct_dtype, _nullability) => todo!(), + DType::List(_arc, _nullability) => todo!(), + DType::Extension(_ext_dtype, _nullability) => todo!(), + } +} + +trait ColumnChunkAccumulator { + fn push_rows_written(&mut self, n_rows: u64); + + fn push_batch_byte_offsets(&mut self, batch_byte_offsets: Vec); + + fn push_statistics(&mut self, statistics: &dyn Statistics) -> VortexResult<()>; + + fn n_rows_written(&self) -> u64; + + fn into_chunks_and_metadata(self: Box) -> VortexResult<(VecDeque, Array)>; +} + +struct TypedColumnChunkAccumulator { + // pub dtype: DType, pub row_offsets: Vec, pub batch_byte_offsets: Vec>, - pub minima: Vec, - pub maxima: Vec, - pub null_counts: Vec>, - pub true_counts: Vec>, + pub minima: Vec>, + pub maxima: Vec>, + pub null_counts: Vec, + pub true_counts: Vec, + pub to_array: fn(Vec>) -> Array, } -impl ColumnChunkAccumulator { - pub fn new(size_hint: usize, dtype: &DType) -> Self { +impl TypedColumnChunkAccumulator { + fn new(size_hint: usize, to_array: fn(Vec>) -> Array) -> Self { let mut row_offsets = Vec::with_capacity(size_hint + 1); row_offsets.push(0); Self { - dtype: dtype.as_nullable(), + // dtype: dtype.as_nullable(), row_offsets, batch_byte_offsets: Vec::new(), minima: Vec::with_capacity(size_hint), maxima: Vec::with_capacity(size_hint), null_counts: Vec::with_capacity(size_hint), true_counts: Vec::with_capacity(size_hint), + to_array, } } +} - fn push_row_offset(&mut self, row_offset: u64) { - self.row_offsets.push(row_offset); +impl ColumnChunkAccumulator for TypedColumnChunkAccumulator +where + Option: TryFrom, +{ + fn push_rows_written(&mut self, n_rows: u64) { + self.row_offsets.push(self.n_rows_written() + n_rows); } fn push_batch_byte_offsets(&mut self, batch_byte_offsets: Vec) { self.batch_byte_offsets.push(batch_byte_offsets); } - fn push_stat(&mut self, stat: Stat, value: Option) -> VortexResult<()> { - if matches!(stat, Stat::Min | Stat::Max) { - if let Some(ref value) = value { - if !value.value().is_instance_of(&self.dtype) { - vortex_bail!( - "Expected all min/max values to have dtype {}, got {}", - self.dtype, - value.dtype() - ); - } - } - } + fn push_statistics(&mut self, statistics: &dyn Statistics) -> VortexResult<()> { + let minimum = statistics + .compute(Stat::Min) + .vortex_expect("Every writable array must implement Stat::Min."); + self.minima + .push(Option::::try_from(minimum.into_value())?); + + let maximum = statistics + .compute(Stat::Max) + .vortex_expect("Every writable array must implement Stat::Max."); + self.maxima + .push(Option::::try_from(maximum.into_value())?); + + let null_count = statistics + .compute(Stat::NullCount) + .vortex_expect("Every writable array must implement Stat::NullCount."); + self.null_counts + .push(u64::try_from(null_count.into_value())?); + + let true_count = statistics + .compute(Stat::TrueCount) + .vortex_expect("Every writable array must implement Stat::Min."); + self.true_counts + .push(u64::try_from(true_count.into_value())?); - match stat { - Stat::Min => self.minima.push( - value - .map(|v| v.into_value()) - .unwrap_or_else(|| ScalarValue::Null), - ), - Stat::Max => self.maxima.push( - value - .map(|v| v.into_value()) - .unwrap_or_else(|| ScalarValue::Null), - ), - Stat::NullCount => self.null_counts.push(value.and_then(|v| { - v.into_value() - .as_pvalue() - .vortex_expect("null count is a primitive value") - .and_then(|v| v.as_u64()) - })), - Stat::TrueCount => self.true_counts.push(value.and_then(|v| { - v.into_value() - .as_pvalue() - .vortex_expect("true count is a primitive value") - .and_then(|v| v.as_u64()) - })), - _ => vortex_bail!("Unsupported pruning stat: {stat}"), - } Ok(()) } - fn into_chunks_and_metadata(mut self) -> VortexResult<(VecDeque, Array)> { + fn n_rows_written(&self) -> u64 { + *self + .row_offsets + .last() + .vortex_expect("row offsets cannot be empty by construction") + } + + fn into_chunks_and_metadata(mut self: Box) -> VortexResult<(VecDeque, Array)> { // we don't need the last row offset; that's just the total number of rows let length = self.row_offsets.len() - 1; self.row_offsets.truncate(length); @@ -293,58 +343,26 @@ impl ColumnChunkAccumulator { ); } - let mut names: Vec> = vec!["row_offset".into()]; - let mut fields = vec![mem::take(&mut self.row_offsets).into_array()]; - - for stat in PRUNING_STATS { - let values = match stat { - Stat::Min => mem::take(&mut self.minima), - Stat::Max => mem::take(&mut self.maxima), - Stat::NullCount => self - .null_counts - .iter() - .cloned() - .map(ScalarValue::from) - .collect(), - Stat::TrueCount => self - .true_counts - .iter() - .cloned() - .map(ScalarValue::from) - .collect(), - _ => vortex_bail!("Unsupported pruning stat: {}", stat), - }; - if values.len() != length { - vortex_bail!( - "Expected {} values for stat {}, found {}", - length, - stat, - values.len() - ); - } + let null_count_array = + PrimitiveArray::from_vec(mem::take(&mut self.null_counts), Validity::NonNullable) + .into_array(); - if values.iter().all(|v| v.is_null()) { - // no point in writing all nulls - continue; - }; + let true_count_array = + PrimitiveArray::from_vec(mem::take(&mut self.true_counts), Validity::NonNullable) + .into_array(); - let dtype = match stat { - Stat::Min | Stat::Max => self.dtype.clone(), - _ => DType::Primitive(PType::U64, Nullability::Nullable), - }; + let names: Vec> = METADATA_FIELD_NAMES + .iter() + .map(|x| Arc::from(x.to_string())) + .collect(); - names.push(format!("{stat}").to_lowercase().into()); - fields.push(Array::from_scalar_values(dtype, values)?); - } - for name in &names { - if !METADATA_FIELD_NAMES.contains(&name.as_ref()) { - vortex_panic!( - "Found unexpected metadata field name {}, expected one of {:?}", - name, - METADATA_FIELD_NAMES - ); - } - } + let fields = vec![ + mem::take(&mut self.row_offsets).into_array(), + (self.to_array)(mem::take(&mut self.minima)), + (self.to_array)(mem::take(&mut self.maxima)), + null_count_array, + true_count_array, + ]; Ok(( chunks,