diff --git a/encodings/fsst/src/array.rs b/encodings/fsst/src/array.rs
index 618acb021c..2051874652 100644
--- a/encodings/fsst/src/array.rs
+++ b/encodings/fsst/src/array.rs
@@ -98,10 +98,13 @@ impl FSSTArray {
     }

     /// Build a [`Decompressor`][fsst::Decompressor] that can be used to decompress values from
-    /// this array.
+    /// this array, and pass it to the given function.
     ///
-    /// This is private to the crate to avoid leaking `fsst` as part of the public API.
-    pub(crate) fn decompressor(&self) -> Decompressor {
+    /// This is private to the crate to avoid leaking `fsst-rs` types as part of the public API.
+    pub(crate) fn with_decompressor<F, R>(&self, apply: F) -> R
+    where
+        F: FnOnce(Decompressor) -> R,
+    {
         // canonicalize the symbols child array, so we can view it contiguously
         let symbols_array = self
             .symbols()
@@ -119,21 +122,13 @@ impl FSSTArray {
             .unwrap();
         let symbol_lengths = symbol_lengths_array.maybe_null_slice::<u8>();

-        // SAFETY: we transmute to remove lifetime restrictions.
-        // Without this, the compiler complains that `symbol_lengths is tied to the lifetime of
-        // the `symbol_lengths_array` local variable, but it's actually tied to the lifetime of
-        // the `symbols` child array of self. We can't represent this in the type system right now,
-        // so we transmute to kill the lifetime complaints.
-        // This is fine because the returned `Decompressor`'s lifetime is tied to the lifetime
-        // of these same arrays.
-        let symbol_lengths = unsafe { std::mem::transmute::<&[u8], &[u8]>(symbol_lengths) };
-
         // Transmute the 64-bit symbol values into fsst `Symbol`s.
         // SAFETY: Symbol is guaranteed to be 8 bytes, guaranteed by the compiler.
         let symbols = unsafe { std::mem::transmute::<&[u64], &[Symbol]>(symbols) };

         // Build a new decompressor that uses these symbols.
-        Decompressor::new(symbols, symbol_lengths)
+        let decompressor = Decompressor::new(symbols, symbol_lengths);
+        apply(decompressor)
     }
 }

diff --git a/encodings/fsst/src/canonical.rs b/encodings/fsst/src/canonical.rs
index 9cbfcdfe58..4e2bf21f73 100644
--- a/encodings/fsst/src/canonical.rs
+++ b/encodings/fsst/src/canonical.rs
@@ -11,46 +11,46 @@ use crate::FSSTArray;

 impl IntoCanonical for FSSTArray {
     fn into_canonical(self) -> VortexResult<Canonical> {
-        let decompressor = self.decompressor();
-
-        // Note: the maximum amount of decompressed space for an FSST array is 8 * n_elements,
-        // as each code can expand into a symbol of 1-8 bytes.
-        let max_items = self.len();
-        let max_bytes = self.codes().nbytes() * size_of::<Symbol>();
-
-        // Create the target Arrow binary array
-        // TODO(aduffy): switch to BinaryView when PR https://github.com/spiraldb/vortex/pull/476 merges
-        let mut builder = GenericByteBuilder::<BinaryType>::with_capacity(max_items, max_bytes);
-
-        // TODO(aduffy): add decompression functions that support writing directly into and output buffer.
-        let codes_array = self.codes().into_canonical()?.into_varbin()?;
-
-        // TODO(aduffy): make this loop faster.
-        for idx in 0..self.len() {
-            if !codes_array.is_valid(idx) {
-                builder.append_null()
-            } else {
-                let compressed = codes_array.bytes_at(idx)?;
-                let value = decompressor.decompress(compressed.as_slice());
-                builder.append_value(value)
+        self.with_decompressor(|decompressor| {
+            // Note: the maximum amount of decompressed space for an FSST array is 8 * n_elements,
+            // as each code can expand into a symbol of 1-8 bytes.
+            let max_items = self.len();
+            let max_bytes = self.codes().nbytes() * size_of::<Symbol>();
+
+            // Create the target Arrow binary array
+            // TODO(aduffy): switch to BinaryView when PR https://github.com/spiraldb/vortex/pull/476 merges
+            let mut builder = GenericByteBuilder::<BinaryType>::with_capacity(max_items, max_bytes);
+
+            // TODO(aduffy): add decompression functions that support writing directly into and output buffer.
+            let codes_array = self.codes().into_canonical()?.into_varbin()?;
+
+            // TODO(aduffy): make this loop faster.
+            for idx in 0..self.len() {
+                if !codes_array.is_valid(idx) {
+                    builder.append_null()
+                } else {
+                    let compressed = codes_array.bytes_at(idx)?;
+                    let value = decompressor.decompress(compressed.as_slice());
+                    builder.append_value(value)
+                }
             }
-        }

-        let arrow_array = builder.finish();
+            let arrow_array = builder.finish();

-        // Force the DTYpe
-        let canonical_varbin = VarBinArray::try_from(&vortex::Array::from_arrow(
-            &arrow_array,
-            self.dtype().is_nullable(),
-        ))?;
+            // Force the DTYpe
+            let canonical_varbin = VarBinArray::try_from(&vortex::Array::from_arrow(
+                &arrow_array,
+                self.dtype().is_nullable(),
+            ))?;

-        let forced_dtype = VarBinArray::try_new(
-            canonical_varbin.offsets(),
-            canonical_varbin.bytes(),
-            self.dtype().clone(),
-            canonical_varbin.validity(),
-        )?;
+            let forced_dtype = VarBinArray::try_new(
+                canonical_varbin.offsets(),
+                canonical_varbin.bytes(),
+                self.dtype().clone(),
+                canonical_varbin.validity(),
+            )?;

-        Ok(Canonical::VarBin(forced_dtype))
+            Ok(Canonical::VarBin(forced_dtype))
+        })
     }
 }
diff --git a/encodings/fsst/src/compress.rs b/encodings/fsst/src/compress.rs
index 3ef33e038d..96213e28c8 100644
--- a/encodings/fsst/src/compress.rs
+++ b/encodings/fsst/src/compress.rs
@@ -10,8 +10,7 @@ use vortex_dtype::DType;

 use crate::FSSTArray;

-/// Compress an array using FSST. If a compressor is provided, use the existing compressor, else
-/// it will train a new compressor directly from the `strings`.
+/// Compress an array using FSST.
 ///
 /// # Panics
 ///
diff --git a/encodings/fsst/src/compute.rs b/encodings/fsst/src/compute.rs
index 792d8165f4..5159932ec6 100644
--- a/encodings/fsst/src/compute.rs
+++ b/encodings/fsst/src/compute.rs
@@ -64,10 +64,11 @@ impl ScalarAtFn for FSSTArray {
         let compressed = scalar_at_unchecked(&self.codes(), index);
         let binary_datum = compressed.value().as_buffer().unwrap().unwrap();

-        let decompressor = self.decompressor();
-        let decoded_buffer: Buffer = decompressor.decompress(binary_datum.as_slice()).into();
+        self.with_decompressor(|decompressor| {
+            let decoded_buffer: Buffer = decompressor.decompress(binary_datum.as_slice()).into();

-        varbin_scalar(decoded_buffer, self.dtype())
+            varbin_scalar(decoded_buffer, self.dtype())
+        })
     }
 }
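
For reviewers unfamiliar with the pattern: the diff replaces a method that returned a borrow-carrying `Decompressor` (which forced the lifetime-erasing `transmute` deleted above) with a closure-scoped accessor. Below is a minimal, self-contained sketch of that pattern under stated assumptions; `Owner`, `View`, and `with_view` are illustrative stand-ins, not types or functions from vortex or fsst-rs.

```rust
/// Illustrative stand-in for `FSSTArray`: owns the backing data.
struct Owner {
    symbols: Vec<u64>,
    lengths: Vec<u8>,
}

/// Illustrative stand-in for `fsst::Decompressor`: borrows data derived inside the accessor.
struct View<'a> {
    symbols: &'a [u64],
    lengths: &'a [u8],
}

impl Owner {
    /// Build a `View` over locally derived slices and hand it to `apply`.
    /// The view never escapes this method, so it may freely borrow from locals
    /// without any lifetime-erasing `transmute`.
    fn with_view<F, R>(&self, apply: F) -> R
    where
        F: for<'a> FnOnce(View<'a>) -> R,
    {
        // In the real code these slices come from canonicalizing child arrays
        // (`symbols_array`, `symbol_lengths_array`); here they just borrow fields.
        let symbols: &[u64] = &self.symbols;
        let lengths: &[u8] = &self.lengths;

        apply(View { symbols, lengths })
    }
}

fn main() {
    let owner = Owner {
        symbols: vec![0x41; 4],
        lengths: vec![1; 4],
    };

    // Callers do all of their work inside the closure, mirroring how
    // `into_canonical` and `scalar_at` now wrap their bodies in `with_decompressor`.
    let total = owner.with_view(|view| view.symbols.len() + view.lengths.len());
    assert_eq!(total, 8);
}
```

Because the borrowed value is consumed inside the method rather than returned from it, the borrow of the local (`symbol_lengths_array` in the real code) lives exactly as long as it is used, which is what allows the `unsafe` transmute block to be deleted.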