Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: port in more from the C++ code #24

Merged
merged 14 commits into from
Sep 3, 2024
21 changes: 20 additions & 1 deletion benches/compress.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Benchmarks for FSST compression, decompression, and symbol table training.
#![allow(missing_docs)]
use core::str;
use std::{fs::File, io::Read};

use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};

Expand Down Expand Up @@ -52,5 +53,23 @@ fn bench_fsst(c: &mut Criterion) {
});
}

criterion_group!(compress_bench, bench_fsst);
fn bench_tpch_comments(c: &mut Criterion) {
let mut group = c.benchmark_group("tpch");

// Load the entire file into memory
let mut file = File::open("/Users/aduffy/code/cwi-fsst/build/comments").unwrap();
let mut text = String::new();
file.read_to_string(&mut text).unwrap();

let lines: Vec<&str> = text.lines().collect();
let lines_sliced: Vec<&[u8]> = lines.iter().map(|s| s.as_bytes()).collect();

group.bench_function("compress-comments", |b| {
b.iter(|| {
std::hint::black_box(Compressor::train_bulk(&lines_sliced));
});
});
}

criterion_group!(compress_bench, bench_fsst, bench_tpch_comments);
criterion_main!(compress_bench);
20 changes: 10 additions & 10 deletions examples/file_compressor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,20 @@ fn main() {
let f = File::open(input_path).unwrap();
let size_bytes = f.metadata().unwrap().size() as usize;

const CHUNK_SIZE: usize = 16 * 1024 * 1024;
const CHUNK_SIZE: usize = 16 * 1024;

let mut chunk_idx = 1;
// let mut chunk_idx = 1;
let mut pos = 0;
let mut chunk = vec![0u8; CHUNK_SIZE];
while pos + CHUNK_SIZE < size_bytes {
f.read_exact_at(&mut chunk, pos as u64).unwrap();
// Compress the chunk, don't write it anywhere.
let compact = compressor.compress(&chunk);
let compression_ratio = (CHUNK_SIZE as f64) / (compact.len() as f64);
println!("compressed chunk {chunk_idx} with ratio {compression_ratio}");
let _ = std::hint::black_box(compressor.compress(&chunk));
// let compression_ratio = (CHUNK_SIZE as f64) / (compact.len() as f64);
// println!("compressed chunk {chunk_idx} with ratio {compression_ratio}");

pos += CHUNK_SIZE;
chunk_idx += 1;
// chunk_idx += 1;
}

// Read last chunk with a new custom-sized buffer.
Expand All @@ -64,9 +64,9 @@ fn main() {
chunk = vec![0u8; size_bytes - pos];
f.read_exact_at(&mut chunk, pos as u64).unwrap();
// Compress the chunk, don't write it anywhere.
let compact = compressor.compress(&chunk[0..amount]);
let compression_ratio = (amount as f64) / (compact.len() as f64);
println!("compressed chunk {chunk_idx} with ratio {compression_ratio}");
let _ = std::hint::black_box(compressor.compress(&chunk[0..amount]));
// let compression_ratio = (amount as f64) / (compact.len() as f64);
// println!("compressed chunk {chunk_idx} with ratio {compression_ratio}");
}
println!("done");
println!("done compressing");
}
228 changes: 200 additions & 28 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
use std::cmp::Ordering;
use std::collections::BinaryHeap;

use crate::{CodeMeta, Compressor, Symbol, ESCAPE_CODE, MAX_CODE};
use crate::{advance_8byte_word, extract_u64, CodeMeta, Compressor, Symbol, MAX_CODE};

/// Bitmap that only works for values up to 512
#[derive(Clone, Copy, Debug, Default)]
Expand Down Expand Up @@ -210,6 +210,70 @@ const MAX_GENERATIONS: usize = 5;
#[cfg(miri)]
const MAX_GENERATIONS: usize = 2;

const FSST_SAMPLETARGET: usize = 1 << 14;
const FSST_SAMPLEMAX: usize = 1 << 15;
const FSST_SAMPLELINE: usize = 512;

// Create a sample from a set of strings in the input
//
// SAFETY: sample_buf must be >= FSST_SAMPLEMAX bytes long. Providing something less may cause unexpected failures.
fn make_sample<'a, 'b: 'a>(sample_buf: &'a mut Vec<u8>, str_in: &Vec<&'b [u8]>) -> Vec<&'a [u8]> {
debug_assert!(
sample_buf.capacity() >= FSST_SAMPLEMAX,
"sample_buf.len() < FSST_SAMPLEMAX"
);

let mut sample: Vec<&[u8]> = Vec::new();

let tot_size: usize = str_in.iter().map(|s| s.len()).sum();
if tot_size < FSST_SAMPLETARGET {
return str_in.clone();
}

let mut sample_rnd = fsst_hash(4637947);
let sample_lim = FSST_SAMPLETARGET;
let mut sample_buf_offset: usize = 0;

while sample_buf_offset < sample_lim {
sample_rnd = fsst_hash(sample_rnd);
let mut line_nr = sample_rnd % str_in.len();

// Find the first non-empty chunk starting at line_nr, wrapping around if
// necessary.
//
// TODO: this will loop infinitely if there are no non-empty lines in the sample
while str_in[line_nr].len() == 0 {
if line_nr == str_in.len() {
line_nr = 0;
}
}

let line = str_in[line_nr];
let chunks = 1 + ((line.len() - 1) / FSST_SAMPLELINE);
sample_rnd = fsst_hash(sample_rnd);
let chunk = FSST_SAMPLELINE * (sample_rnd % chunks);

let len = FSST_SAMPLELINE.min(line.len() - chunk);
// println!("extending sample with chunk str_in[{line_nr}][{chunk}...len={len}]");

sample_buf.extend_from_slice(&str_in[line_nr][chunk..chunk + len]);

// SAFETY: this is the data we just placed into `sample_buf` in the line above.
let slice =
unsafe { std::slice::from_raw_parts(sample_buf.as_ptr().add(sample_buf_offset), len) };

sample.push(slice);

sample_buf_offset += len;
}

sample
}

fn fsst_hash(value: usize) -> usize {
(value * 2971215073) ^ (value >> 15)
}

impl Compressor {
/// Clear all set items from the compressor.
///
Expand Down Expand Up @@ -248,6 +312,10 @@ impl Compressor {
return compressor;
}

// Make the sample for each iteration.
//
// The sample is just a vector of slices, so we don't actually have to move anything around.

let mut counter = Counter::new();
for _generation in 0..(MAX_GENERATIONS - 1) {
compressor.compress_count(sample, &mut counter);
Expand All @@ -260,42 +328,147 @@ impl Compressor {

compressor
}

/// Train on a collection of samples.
pub fn train_bulk(values: &Vec<&[u8]>) -> Self {
let mut sample_memory = Vec::with_capacity(FSST_SAMPLEMAX);
let sample = make_sample(&mut sample_memory, values);

let mut counters = Counter::new();
let mut compressor = Compressor::default();

for sample_frac in [8usize, 38, 68, 98, 128] {
for i in 0..sample.len() {
if sample_frac < 128 {
if fsst_hash(i) & 127 > sample_frac {
continue;
}
}

compressor.compress_count(sample[i], &mut counters);
}

compressor.optimize(&counters, sample_frac == 128);
counters.clear();
}

compressor
}
}

impl Compressor {
/// Compress the text using the current symbol table. Count the code occurrences
/// and code-pair occurrences to allow us to calculate apparent gain.
///
/// NOTE: this is largely an unfortunate amount of copy-paste from `compress`, just to make sure
/// we can do all the counting in a single pass.
fn compress_count(&self, sample: &[u8], counter: &mut Counter) {
let compressed = self.compress(sample);
let len = compressed.len();

if len == 0 {
if sample.is_empty() {
return;
}

fn next_code(pos: usize, compressed: &[u8]) -> (u16, usize) {
if compressed[pos] == ESCAPE_CODE {
(compressed[pos + 1] as u16, 2)
} else {
(256 + compressed[pos] as u16, 1)
}
}
// Output space.
let mut out_buf = [0u8, 0u8];

let mut in_ptr = sample.as_ptr();
let out_ptr = out_buf.as_mut_ptr();

// SAFETY: `end` will point just after the end of the `plaintext` slice.
let in_end = unsafe { in_ptr.byte_add(sample.len()) };
let in_end_sub8 = in_end as usize - 8;

let mut prev_code: u16 = MAX_CODE;

while (in_ptr as usize) < in_end_sub8 {
// SAFETY: pointer ranges are checked in the loop condition
unsafe {
// Load a full 8-byte word of data from in_ptr.
// SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though.
let word: u64 = std::ptr::read_unaligned(in_ptr as *const u64);
let (advance_in, advance_out) = self.compress_word(word, out_ptr);
match advance_out {
1 => {
// Record a true symbol
let code_u16 = out_ptr.read() as u16 + 256u16;
counter.record_count1(code_u16);
if prev_code != MAX_CODE {
counter.record_count2(prev_code, code_u16);
}
prev_code = code_u16;
}
2 => {
// Record an escape.
let escape_code = out_ptr.byte_offset(1).read() as u16;
counter.record_count1(escape_code);
if prev_code != MAX_CODE {
counter.record_count2(prev_code, escape_code);
}
prev_code = escape_code;
}
_ => unreachable!("advance_out will only be 1 or 2 bytes"),
}

// Get first code, record count
let (code, pos) = next_code(0, &compressed);
counter.record_count1(code);
in_ptr = in_ptr.byte_add(advance_in);
};
}

let mut pos = pos;
let mut prev_code = code;
let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) };
debug_assert!(
remaining_bytes.is_positive(),
"in_ptr exceeded in_end, should not be possible"
);
let remaining_bytes = remaining_bytes as usize;

// Load the last `remaining_byte`s of data into a final world. We then replicate the loop above,
// but shift data out of this word rather than advancing an input pointer and potentially reading
// unowned memory.
let mut last_word = unsafe {
match remaining_bytes {
0 => 0,
1 => extract_u64::<1>(in_ptr),
2 => extract_u64::<2>(in_ptr),
3 => extract_u64::<3>(in_ptr),
4 => extract_u64::<4>(in_ptr),
5 => extract_u64::<5>(in_ptr),
6 => extract_u64::<6>(in_ptr),
7 => extract_u64::<7>(in_ptr),
8 => extract_u64::<8>(in_ptr),
_ => unreachable!("remaining bytes must be <= 8"),
}
};

while pos < len {
let (code, advance) = next_code(pos, &compressed);
pos += advance;
while in_ptr < in_end {
unsafe {
// Load a full 8-byte word of data from in_ptr.
// SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though.
let (advance_in, advance_out) = self.compress_word(last_word, out_ptr);

match advance_out {
1 => {
// Record a true symbol
let code_u16 = out_buf[0] as u16 + 256u16;
counter.record_count1(code_u16);
if prev_code != MAX_CODE {
counter.record_count2(prev_code, code_u16);
}
prev_code = code_u16;
}
2 => {
// Record an escape.
let escape_code = out_buf[1] as u16;
counter.record_count1(escape_code);
if prev_code != MAX_CODE {
counter.record_count2(prev_code, escape_code);
}
prev_code = escape_code;
}
_ => unreachable!("advance_out will only be 1 or 2 bytes"),
}

counter.record_count1(code);
counter.record_count2(prev_code, code);
in_ptr = in_ptr.byte_add(advance_in);

prev_code = code;
last_word = advance_8byte_word(last_word, advance_in);
}
}
}

Expand All @@ -308,17 +481,14 @@ impl Compressor {
let symbol1 = self.symbols[code1 as usize];
let symbol1_len = symbol1.len();
let count = counters.count1(code1);
// If count is zero, we can skip the whole inner loop.
if count == 0 {
continue;
}

let mut gain = count * symbol1_len;
// NOTE: use heuristic from C++ implementation to boost the gain of single-byte symbols.
// This helps to reduce exception counts.
if code1 < 256 {
gain *= 8;
}

if gain > 0 {
pqueue.push(Candidate {
symbol: symbol1,
Expand All @@ -327,7 +497,7 @@ impl Compressor {
}

for code2 in counters.second_codes(code1) {
let symbol2 = &self.symbols[code2 as usize];
let symbol2 = self.symbols[code2 as usize];

// If merging would yield a symbol of length greater than 8, skip.
if symbol1_len + symbol2.len() > 8 {
Expand Down Expand Up @@ -419,6 +589,8 @@ impl Ord for Candidate {
mod test {
use crate::{builder::CodesBitmap, Compressor, ESCAPE_CODE};

use super::{make_sample, FSST_SAMPLEMAX};

#[test]
fn test_builder() {
// Train a Compressor on the toy string
Expand Down
Loading
Loading