From 15f11b0781305c77fafc750a986bc285ea60322e Mon Sep 17 00:00:00 2001 From: Andrew Duffy Date: Tue, 3 Sep 2024 12:47:58 -0400 Subject: [PATCH] remove method that does not affect perf --- benches/micro.rs | 2 +- src/builder.rs | 23 +++++-------- src/lib.rs | 89 +++++++++++------------------------------------- 3 files changed, 29 insertions(+), 85 deletions(-) diff --git a/benches/micro.rs b/benches/micro.rs index b4bc3a7..d55402e 100644 --- a/benches/micro.rs +++ b/benches/micro.rs @@ -11,7 +11,7 @@ fn one_megabyte(seed: &[u8]) -> Vec { fn bench_compress(c: &mut Criterion) { let mut group = c.benchmark_group("compress-overhead"); // Reusable memory to hold outputs - let mut output_buf: Vec = Vec::with_capacity(12); + let mut output_buf: Vec = Vec::with_capacity(8 * 1024 * 1024); // We create a symbol table that requires probing the hash table to perform // decompression. diff --git a/src/builder.rs b/src/builder.rs index 32a9cbf..ed24874 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -8,7 +8,7 @@ use std::cmp::Ordering; use std::collections::BinaryHeap; use crate::{ - advance_8byte_word, compare_masked, extract_u64, lossy_pht::LossyPHT, Code, Compressor, Symbol, + advance_8byte_word, compare_masked, lossy_pht::LossyPHT, Code, Compressor, Symbol, FSST_CODE_BASE, FSST_CODE_MASK, }; @@ -709,20 +709,13 @@ impl CompressorBuilder { // Load the last `remaining_byte`s of data into a final world. We then replicate the loop above, // but shift data out of this word rather than advancing an input pointer and potentially reading // unowned memory - let mut last_word = unsafe { - match remaining_bytes { - 0 => 0, - 1 => extract_u64::<1>(in_ptr), - 2 => extract_u64::<2>(in_ptr), - 3 => extract_u64::<3>(in_ptr), - 4 => extract_u64::<4>(in_ptr), - 5 => extract_u64::<5>(in_ptr), - 6 => extract_u64::<6>(in_ptr), - 7 => extract_u64::<7>(in_ptr), - 8 => extract_u64::<8>(in_ptr), - _ => unreachable!("remaining bytes must be <= 8"), - } - }; + let mut bytes = [0u8; 8]; + unsafe { + // SAFETY: it is safe to read up to remaining_bytes from in_ptr, and remaining_bytes + // will be <= 8 bytes. + std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes); + } + let mut last_word = u64::from_le_bytes(bytes); let mut remaining_bytes = remaining_bytes; diff --git a/src/lib.rs b/src/lib.rs index 65c0522..4f00b47 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -367,7 +367,7 @@ impl Compressor { /// # Safety /// /// `out_ptr` must never be NULL or otherwise point to invalid memory. - // #[inline] + #[inline(never)] pub unsafe fn compress_word(&self, word: u64, out_ptr: *mut u8) -> (usize, usize) { // Speculatively write the first byte of `word` at offset 1. This is necessary if it is an escape, and // if it isn't, it will be overwritten anyway. @@ -482,39 +482,27 @@ impl Compressor { let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) }; assert!( - in_ptr == in_end || remaining_bytes.is_positive(), - "in_ptr exceeded in_end, should not be possible" + out_ptr < out_end || remaining_bytes == 0, + "output buffer sized too small" ); + let remaining_bytes = remaining_bytes as usize; // Load the last `remaining_byte`s of data into a final world. We then replicate the loop above, // but shift data out of this word rather than advancing an input pointer and potentially reading // unowned memory. - let mut last_word = unsafe { - match remaining_bytes { - 0 => 0, - 1 => extract_u64::<1>(in_ptr), - 2 => extract_u64::<2>(in_ptr), - 3 => extract_u64::<3>(in_ptr), - 4 => extract_u64::<4>(in_ptr), - 5 => extract_u64::<5>(in_ptr), - 6 => extract_u64::<6>(in_ptr), - 7 => extract_u64::<7>(in_ptr), - 8 => extract_u64::<8>(in_ptr), - _ => unreachable!("remaining bytes must be <= 8"), - } - }; + let mut bytes = [0u8; 8]; + std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes); + let mut last_word = u64::from_le_bytes(bytes); while in_ptr < in_end && out_ptr < out_end { - unsafe { - // Load a full 8-byte word of data from in_ptr. - // SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though. - let (advance_in, advance_out) = self.compress_word(last_word, out_ptr); - in_ptr = in_ptr.byte_add(advance_in); - out_ptr = out_ptr.byte_add(advance_out); + // Load a full 8-byte word of data from in_ptr. + // SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though. + let (advance_in, advance_out) = self.compress_word(last_word, out_ptr); + in_ptr = in_ptr.byte_add(advance_in); + out_ptr = out_ptr.byte_add(advance_out); - last_word = advance_8byte_word(last_word, advance_in); - } + last_word = advance_8byte_word(last_word, advance_in); } // in_ptr should have exceeded in_end @@ -522,15 +510,13 @@ impl Compressor { // Count the number of bytes written // SAFETY: assertion - unsafe { - let bytes_written = out_ptr.offset_from(values.as_ptr()); - assert!( - bytes_written.is_positive(), - "out_ptr ended before it started, not possible" - ); - - values.set_len(bytes_written as usize); - } + let bytes_written = out_ptr.offset_from(values.as_ptr()); + assert!( + bytes_written.is_positive(), + "out_ptr ended before it started, not possible" + ); + + values.set_len(bytes_written as usize); } /// Use the symbol table to compress the plaintext into a sequence of codes and escapes. @@ -588,38 +574,3 @@ pub(crate) fn compare_masked(left: u64, right: u64, ignored_bits: u16) -> bool { let mask = u64::MAX >> ignored_bits; (left & mask) == right } - -/// This is a function that will get monomorphized based on the value of `N` to do -/// a load of `N` values from the pointer in a minimum number of instructions into -/// an output `u64`. -#[inline] -pub(crate) unsafe fn extract_u64(ptr: *const u8) -> u64 { - match N { - 1 => std::ptr::read(ptr) as u64, - 2 => std::ptr::read_unaligned(ptr as *const u16) as u64, - 3 => { - let low = std::ptr::read(ptr) as u64; - let high = std::ptr::read_unaligned(ptr.byte_add(1) as *const u16) as u64; - high << 8 | low - } - 4 => std::ptr::read_unaligned(ptr as *const u32) as u64, - 5 => { - let low = std::ptr::read_unaligned(ptr as *const u32) as u64; - let high = ptr.byte_add(4).read() as u64; - high << 32 | low - } - 6 => { - let low = std::ptr::read_unaligned(ptr as *const u32) as u64; - let high = std::ptr::read_unaligned(ptr.byte_add(4) as *const u16) as u64; - high << 32 | low - } - 7 => { - let low = std::ptr::read_unaligned(ptr as *const u32) as u64; - let mid = std::ptr::read_unaligned(ptr.byte_add(4) as *const u16) as u64; - let high = std::ptr::read(ptr.byte_add(6)) as u64; - (high << 48) | (mid << 32) | low - } - 8 => std::ptr::read_unaligned(ptr as *const u64), - _ => unreachable!("N must be <= 8"), - } -}