Skip to content

Commit

Permalink
remove method that does not affect perf
Browse files Browse the repository at this point in the history
  • Loading branch information
a10y committed Sep 3, 2024
1 parent 9c9ad25 commit 15f11b0
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 85 deletions.
2 changes: 1 addition & 1 deletion benches/micro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ fn one_megabyte(seed: &[u8]) -> Vec<u8> {
fn bench_compress(c: &mut Criterion) {
let mut group = c.benchmark_group("compress-overhead");
// Reusable memory to hold outputs
let mut output_buf: Vec<u8> = Vec::with_capacity(12);
let mut output_buf: Vec<u8> = Vec::with_capacity(8 * 1024 * 1024);

// We create a symbol table that requires probing the hash table to perform
// decompression.
Expand Down
23 changes: 8 additions & 15 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::cmp::Ordering;
use std::collections::BinaryHeap;

use crate::{
advance_8byte_word, compare_masked, extract_u64, lossy_pht::LossyPHT, Code, Compressor, Symbol,
advance_8byte_word, compare_masked, lossy_pht::LossyPHT, Code, Compressor, Symbol,
FSST_CODE_BASE, FSST_CODE_MASK,
};

Expand Down Expand Up @@ -709,20 +709,13 @@ impl CompressorBuilder {
// Load the last `remaining_bytes` bytes of data into a final word. We then replicate the loop above,
// but shift data out of this word rather than advancing an input pointer and potentially reading
// unowned memory.
let mut last_word = unsafe {
match remaining_bytes {
0 => 0,
1 => extract_u64::<1>(in_ptr),
2 => extract_u64::<2>(in_ptr),
3 => extract_u64::<3>(in_ptr),
4 => extract_u64::<4>(in_ptr),
5 => extract_u64::<5>(in_ptr),
6 => extract_u64::<6>(in_ptr),
7 => extract_u64::<7>(in_ptr),
8 => extract_u64::<8>(in_ptr),
_ => unreachable!("remaining bytes must be <= 8"),
}
};
let mut bytes = [0u8; 8];
unsafe {
// SAFETY: it is safe to read up to remaining_bytes from in_ptr, and remaining_bytes
// will be <= 8 bytes.
std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes);
}
let mut last_word = u64::from_le_bytes(bytes);

let mut remaining_bytes = remaining_bytes;

Expand Down
89 changes: 20 additions & 69 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ impl Compressor {
/// # Safety
///
/// `out_ptr` must never be NULL or otherwise point to invalid memory.
// #[inline]
#[inline(never)]
pub unsafe fn compress_word(&self, word: u64, out_ptr: *mut u8) -> (usize, usize) {
// Speculatively write the first byte of `word` at offset 1. This is necessary if it is an escape, and
// if it isn't, it will be overwritten anyway.
Expand Down Expand Up @@ -482,55 +482,41 @@ impl Compressor {

let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) };
assert!(
in_ptr == in_end || remaining_bytes.is_positive(),
"in_ptr exceeded in_end, should not be possible"
out_ptr < out_end || remaining_bytes == 0,
"output buffer sized too small"
);

let remaining_bytes = remaining_bytes as usize;

// Load the last `remaining_bytes` bytes of data into a final word. We then replicate the loop above,
// but shift data out of this word rather than advancing an input pointer and potentially reading
// unowned memory.
let mut last_word = unsafe {
match remaining_bytes {
0 => 0,
1 => extract_u64::<1>(in_ptr),
2 => extract_u64::<2>(in_ptr),
3 => extract_u64::<3>(in_ptr),
4 => extract_u64::<4>(in_ptr),
5 => extract_u64::<5>(in_ptr),
6 => extract_u64::<6>(in_ptr),
7 => extract_u64::<7>(in_ptr),
8 => extract_u64::<8>(in_ptr),
_ => unreachable!("remaining bytes must be <= 8"),
}
};
let mut bytes = [0u8; 8];
std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes);
let mut last_word = u64::from_le_bytes(bytes);

while in_ptr < in_end && out_ptr < out_end {
unsafe {
// Load a full 8-byte word of data from in_ptr.
// SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though.
let (advance_in, advance_out) = self.compress_word(last_word, out_ptr);
in_ptr = in_ptr.byte_add(advance_in);
out_ptr = out_ptr.byte_add(advance_out);
// Load a full 8-byte word of data from in_ptr.
// SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though.
let (advance_in, advance_out) = self.compress_word(last_word, out_ptr);
in_ptr = in_ptr.byte_add(advance_in);
out_ptr = out_ptr.byte_add(advance_out);

last_word = advance_8byte_word(last_word, advance_in);
}
last_word = advance_8byte_word(last_word, advance_in);
}

// in_ptr should have exceeded in_end
assert!(in_ptr >= in_end, "exhausted output buffer before exhausting input, there is a bug in SymbolTable::compress()");

// Count the number of bytes written
// SAFETY: assertion
unsafe {
let bytes_written = out_ptr.offset_from(values.as_ptr());
assert!(
bytes_written.is_positive(),
"out_ptr ended before it started, not possible"
);

values.set_len(bytes_written as usize);
}
let bytes_written = out_ptr.offset_from(values.as_ptr());
assert!(
bytes_written.is_positive(),
"out_ptr ended before it started, not possible"
);

values.set_len(bytes_written as usize);
}

/// Use the symbol table to compress the plaintext into a sequence of codes and escapes.
Expand Down Expand Up @@ -588,38 +574,3 @@ pub(crate) fn compare_masked(left: u64, right: u64, ignored_bits: u16) -> bool {
let mask = u64::MAX >> ignored_bits;
(left & mask) == right
}

/// Loads the first `N` bytes at `ptr` into the low-order bytes of a `u64`.
///
/// The input is always interpreted as little-endian, independent of the target's
/// native byte order: input byte `i` ends up in bits `8 * i .. 8 * i + 8` of the
/// result, and the unused high bytes are zero. `N == 0` reads nothing and yields `0`.
///
/// `N` is a const generic, so the function is monomorphized once per length and the
/// copy below always has a compile-time-known size.
///
/// # Safety
///
/// `ptr` must be valid for reads of `N` bytes. No alignment is required. When
/// `N == 0` the pointer is never dereferenced.
///
/// # Panics
///
/// Panics if `N > 8`, since more than eight bytes cannot fit in a `u64`.
#[inline]
pub(crate) unsafe fn extract_u64<const N: usize>(ptr: *const u8) -> u64 {
    // Guard the fixed-size buffer below; this check is resolved per monomorphization.
    assert!(N <= 8, "N must be <= 8");

    if N == 0 {
        // Nothing to read; do not touch `ptr` at all.
        return 0;
    }

    // Start from all-zero so bytes past `N` contribute nothing to the result.
    let mut bytes = [0u8; 8];
    // SAFETY: the caller guarantees `ptr` is readable for `N` bytes, `bytes` is a
    // distinct local buffer of 8 bytes, and `N <= 8` was asserted above, so the
    // copy stays in bounds on both sides and the regions cannot overlap.
    unsafe {
        std::ptr::copy_nonoverlapping(ptr, bytes.as_mut_ptr(), N);
    }

    // Explicit little-endian decoding keeps the result identical on every target.
    u64::from_le_bytes(bytes)
}

0 comments on commit 15f11b0

Please sign in to comment.