Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get compress performance to match paper algorithm 4 #3

Merged
merged 20 commits into from
Aug 15, 2024
Merged
7 changes: 0 additions & 7 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,2 @@
/target
.idea/


# Added by cargo
#
# already existing elements were commented out

#/target
13 changes: 11 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 14 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
[package]
name = "fsst-rs"
version = "0.0.1"
description = "Pure-Rust implementation of Fast Static Symbol Tables algorithm for string compression"
authors = ["SpiralDB Developers <[email protected]>"]
license = "Apache-2.0"
repository = "https://github.com/spiraldb/fsst"
edition = "2021"

[lints.rust]
Expand All @@ -22,7 +26,16 @@ use_debug = { level = "deny" }
criterion = "0.5"
lz4 = "1"

[[example]]
name = "round_trip"
bench = false
test = false

[[bench]]
name = "compress"
harness = false
bench = true

[[test]]
name = "correctness"
test = true
bench = false
36 changes: 7 additions & 29 deletions benches/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
//!
//! Also contains LZ4 baseline.
#![allow(missing_docs)]
use core::str;
use std::io::{Cursor, Read, Write};

use criterion::{black_box, criterion_group, criterion_main, Criterion};
use lz4::liblz4::BlockChecksum;
use lz4::{BlockSize, ContentChecksum};

use fsst_rs::{train, Code};
use fsst_rs::{train, ESCAPE_CODE};

const CORPUS: &str = include_str!("dracula.txt");
const TEST: &str = "I found my smattering of German very useful here";
Expand All @@ -26,17 +27,17 @@ fn bench_fsst(c: &mut Criterion) {
let plaintext = TEST.as_bytes();

let compressed = table.compress(plaintext);
let escape_count = compressed
.iter()
.filter(|b| **b == Code::ESCAPE_CODE)
.count();
let escape_count = compressed.iter().filter(|b| **b == ESCAPE_CODE).count();
let ratio = (plaintext.len() as f64) / (compressed.len() as f64);
println!(
"Escapes = {escape_count}/{}, compression_ratio = {ratio}",
compressed.len()
);

assert_eq!(table.decompress(&compressed), TEST.as_bytes());
let decompressed = table.decompress(&compressed);
let decompressed = str::from_utf8(&decompressed).unwrap();
println!("DECODED: {}", decompressed);
assert_eq!(decompressed, TEST);

group.bench_function("compress-single", |b| {
b.iter(|| black_box(table.compress(black_box(plaintext))));
Expand All @@ -50,29 +51,6 @@ fn bench_fsst(c: &mut Criterion) {
fn bench_lz4(c: &mut Criterion) {
let mut group = c.benchmark_group("lz4");

// {
// let compressed = Vec::with_capacity(10_000);
// let mut encoder = lz4::EncoderBuilder::new()
// .block_size(BlockSize::Max64KB)
// .build(compressed)
// .unwrap();
//
// encoder.write_all(TEST.as_bytes()).unwrap();
// let (compressed, result) = encoder.finish();
// result.unwrap();
//
// let ratio = (TEST.as_bytes().len() as f64) / (compressed.len() as f64);
// println!("LZ4 compress_ratio = {ratio}");
//
// // ensure decodes cleanly
// let cursor = Cursor::new(compressed);
// let mut decoder = lz4::Decoder::new(cursor).unwrap();
// let mut output = String::new();
//
// decoder.read_to_string(&mut output).unwrap();
// assert_eq!(output.as_str(), TEST);
// }

group.bench_function("compress-single", |b| {
let mut compressed = Vec::with_capacity(100_000_000);
let mut encoder = lz4::EncoderBuilder::new()
Expand Down
19 changes: 19 additions & 0 deletions examples/round_trip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//! Simple example where we show round-tripping a string through the static symbol table.

use core::str;

fn main() {
// Train on a sample.
let sample = "the quick brown fox jumped over the lazy dog";
let trained = fsst_rs::train(sample.as_bytes());
let compressed = trained.compress(sample.as_bytes());
println!("compressed: {} => {}", sample.len(), compressed.len());
// decompress now
let decode = trained.decompress(&compressed);
let output = str::from_utf8(&decode).unwrap();
println!(
"decoded to the original: len={} text='{}'",
decode.len(),
output
);
}
3 changes: 1 addition & 2 deletions rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
[toolchain]
channel = "nightly-2024-06-19"
channel = "stable"
lwwmanning marked this conversation as resolved.
Show resolved Hide resolved
components = ["rust-src", "rustfmt", "clippy"]
profile = "minimal"

67 changes: 36 additions & 31 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
use std::cmp::Ordering;
use std::collections::BinaryHeap;

use crate::{Code, Symbol, SymbolTable};
use crate::find_longest::FindLongestSymbol;
use crate::{Symbol, SymbolTable, MAX_CODE};

#[derive(Debug, Clone)]
struct Counter {
Expand All @@ -21,29 +22,29 @@ struct Counter {
impl Counter {
fn new() -> Self {
Self {
counts1: vec![0; Code::CODE_MAX as usize],
counts2: vec![vec![0; Code::CODE_MAX as usize]; Code::CODE_MAX as usize],
counts1: vec![0; MAX_CODE as usize],
counts2: vec![vec![0; MAX_CODE as usize]; MAX_CODE as usize],
}
}

#[inline]
fn record_count1(&mut self, code1: Code) {
self.counts1[code1.0 as usize] += 1;
fn record_count1(&mut self, code1: u16) {
self.counts1[code1 as usize] += 1;
}

#[inline]
fn record_count2(&mut self, code1: Code, code2: Code) {
self.counts2[code1.0 as usize][code2.0 as usize] += 1;
fn record_count2(&mut self, code1: u16, code2: u16) {
self.counts2[code1 as usize][code2 as usize] += 1;
}

#[inline]
fn count1(&self, code: Code) -> usize {
self.counts1[code.0 as usize]
fn count1(&self, code: u16) -> usize {
self.counts1[code as usize]
}

#[inline]
fn count2(&self, code1: Code, code2: Code) -> usize {
self.counts2[code1.0 as usize][code2.0 as usize]
fn count2(&self, code1: u16, code2: u16) -> usize {
self.counts2[code1 as usize][code2 as usize]
}
}

Expand All @@ -65,6 +66,9 @@ pub fn train(corpus: impl AsRef<[u8]>) -> SymbolTable {
let mut table = SymbolTable::default();
// TODO(aduffy): handle truncating/sampling if corpus > requires sample size.
let sample = corpus.as_ref();
if sample.is_empty() {
return table;
}
for _generation in 0..MAX_GENERATIONS {
let counter = table.compress_count(sample);
table = table.optimize(counter);
Expand All @@ -81,13 +85,13 @@ impl SymbolTable {
let len = sample.len();
let mut prev_code = self.find_longest_symbol(sample);
counter.record_count1(prev_code);
let mut pos = self.symbols[prev_code.0 as usize].len();
let mut pos = self.symbols[prev_code as usize].len();

while pos < len {
let code = self.find_longest_symbol(&sample[pos..len]);
counter.record_count1(code);
counter.record_count2(prev_code, code);
pos += self.symbols[code.0 as usize].len();
pos += self.symbols[code as usize].len();
prev_code = code;
}

Expand All @@ -100,17 +104,15 @@ impl SymbolTable {
let mut res = SymbolTable::default();
let mut pqueue = BinaryHeap::new();
for code1 in 0..511 {
let code1 = Code::from_u16(code1);
let symbol1 = self.symbols[code1.0 as usize];
let symbol1 = self.symbols[code1 as usize];
let gain = counters.count1(code1) * symbol1.len();
pqueue.push(Candidate {
symbol: symbol1,
gain,
});

for code2 in 0..511 {
let code2 = Code::from_u16(code2);
let symbol2 = &self.symbols[code2.0 as usize];
let symbol2 = &self.symbols[code2 as usize];
// If either symbol is zero-length, or if merging would yield a symbol of
// length greater than 8, skip.
if symbol1.len() + symbol2.len() >= 8 || symbol1.is_empty() || symbol2.is_empty() {
Expand All @@ -133,10 +135,13 @@ impl SymbolTable {
}

// Pop the 255 best symbols.
pqueue
.iter()
.take(255)
.for_each(|candidate| res.insert(candidate.symbol));
let mut n_symbols = 0;
while !pqueue.is_empty() && n_symbols < 255 {
let candidate = pqueue.pop().unwrap();
if res.insert(candidate.symbol) {
n_symbols += 1;
}
}

res
}
Expand Down Expand Up @@ -181,7 +186,7 @@ impl Ord for Candidate {

#[cfg(test)]
mod test {
use crate::{train, Code};
use crate::{train, ESCAPE_CODE};

#[test]
fn test_builder() {
Expand All @@ -193,26 +198,26 @@ mod test {
let compressed = table.compress(text.as_bytes());

// Ensure that the compressed string has no escape bytes
assert!(compressed.iter().all(|b| *b != Code::ESCAPE_CODE));
assert!(compressed.iter().all(|b| *b != ESCAPE_CODE));

// Ensure that we can compress a string with no values seen at training time.
// Ensure that we can compress a string with no values seen at training time, with escape bytes
let compressed = table.compress("xyz123".as_bytes());
assert_eq!(
compressed,
vec![
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'x',
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'y',
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'z',
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'1',
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'2',
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'3',
]
)
);
}
}
5 changes: 5 additions & 0 deletions src/find_longest/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mod naive;

pub trait FindLongestSymbol {
fn find_longest_symbol(&self, text: &[u8]) -> u16;
}
28 changes: 28 additions & 0 deletions src/find_longest/naive.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use crate::find_longest::FindLongestSymbol;
use crate::SymbolTable;

// Find the code that maps to a symbol with longest-match to a piece of text.
//
// This is the naive algorithm that just scans the whole table and is very slow.

impl FindLongestSymbol for SymbolTable {
a10y marked this conversation as resolved.
Show resolved Hide resolved
// NOTE(aduffy): if you don't disable inlining, this function won't show up in profiles.
#[inline(never)]
fn find_longest_symbol(&self, text: &[u8]) -> u16 {
debug_assert!(!text.is_empty(), "text must not be empty");

// Find the code that best maps to the provided text table here.
// Start with the code corresponding to the escape of the first character in the text
let mut best_code = text[0] as u16;
let mut best_overlap = 1;
for code in 256..511 {
let symbol = &self.symbols[code as usize];
if symbol.is_prefix(text) && symbol.len() > best_overlap {
best_code = code;
best_overlap = symbol.len();
}
}

best_code
}
}
Loading
Loading