Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get compress performance to match paper algorithm 4 #3

Merged
merged 20 commits into from
Aug 15, 2024
Merged
22 changes: 22 additions & 0 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[target.aarch64-apple-darwin]
rustflags = [
"-C",
"link-arg=-undefined",
"-C",
"link-arg=dynamic_lookup",
"-Z",
"verbose-internals",
"-Z",
"track-diagnostics",
]
[target.x86_64-apple-darwin]
rustflags = [
"-C",
"link-arg=-undefined",
"-C",
"link-arg=dynamic_lookup",
"-Z",
"verbose-internals",
"-Z",
"track-diagnostics",
]
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@
# already existing elements were commented out

#/target

# compiler debug reports
rustc-ice*
13 changes: 11 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 14 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
[package]
name = "fsst-rs"
version = "0.0.1"
description = "Pure-Rust implementation of Fast Static Symbol Tables algorithm for string compression"
authors = ["SpiralDB Developers <[email protected]>"]
license = "Apache-2.0"
repository = "https://github.com/spiraldb/fsst"
edition = "2021"

[lints.rust]
Expand All @@ -22,7 +26,16 @@ use_debug = { level = "deny" }
criterion = "0.5"
lz4 = "1"

[[example]]
name = "round_trip"
bench = false
test = false

[[bench]]
name = "compress"
harness = false
bench = true

[[test]]
name = "correctness"
test = true
bench = false
13 changes: 7 additions & 6 deletions benches/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
//!
//! Also contains LZ4 baseline.
#![allow(missing_docs)]
use core::str;
use std::io::{Cursor, Read, Write};

use criterion::{black_box, criterion_group, criterion_main, Criterion};
use lz4::liblz4::BlockChecksum;
use lz4::{BlockSize, ContentChecksum};

use fsst_rs::{train, Code};
use fsst_rs::{train, ESCAPE_CODE};

const CORPUS: &str = include_str!("dracula.txt");
const TEST: &str = "I found my smattering of German very useful here";
Expand All @@ -26,17 +27,17 @@ fn bench_fsst(c: &mut Criterion) {
let plaintext = TEST.as_bytes();

let compressed = table.compress(plaintext);
let escape_count = compressed
.iter()
.filter(|b| **b == Code::ESCAPE_CODE)
.count();
let escape_count = compressed.iter().filter(|b| **b == ESCAPE_CODE).count();
let ratio = (plaintext.len() as f64) / (compressed.len() as f64);
println!(
"Escapes = {escape_count}/{}, compression_ratio = {ratio}",
compressed.len()
);

assert_eq!(table.decompress(&compressed), TEST.as_bytes());
let decompressed = table.decompress(&compressed);
let decompressed = str::from_utf8(&decompressed).unwrap();
println!("DECODED: {}", decompressed);
assert_eq!(decompressed, TEST);

group.bench_function("compress-single", |b| {
b.iter(|| black_box(table.compress(black_box(plaintext))));
Expand Down
19 changes: 19 additions & 0 deletions examples/round_trip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//! Simple example where we show round-tripping a string through the static symbol table.

use core::str;

fn main() {
// Train on a sample.
let sample = "the quick brown fox jumped over the lazy dog";
let trained = fsst_rs::train(sample.as_bytes());
let compressed = trained.compress(sample.as_bytes());
println!("compressed: {} => {}", sample.len(), compressed.len());
// decompress now
let decode = trained.decompress(&compressed);
let output = str::from_utf8(&decode).unwrap();
println!(
"decoded to the original: len={} text='{}'",
decode.len(),
output
);
}
4 changes: 2 additions & 2 deletions rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[toolchain]
channel = "nightly-2024-06-19"
# channel = "stable"
channel = "nightly-2024-08-14"
components = ["rust-src", "rustfmt", "clippy"]
profile = "minimal"

80 changes: 50 additions & 30 deletions src/builder.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
// Copyright 2024 Spiral, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Functions and types used for building a [`SymbolTable`] from a corpus of text.
//!
//! This module implements the logic from Algorithm 3 of the [FSST Paper].
Expand All @@ -7,7 +21,8 @@
use std::cmp::Ordering;
use std::collections::BinaryHeap;

use crate::{Code, Symbol, SymbolTable};
use crate::find_longest::FindLongestSymbol;
use crate::{CodeMeta, Symbol, SymbolTable, MAX_CODE};

#[derive(Debug, Clone)]
struct Counter {
Expand All @@ -21,29 +36,29 @@ struct Counter {
impl Counter {
fn new() -> Self {
Self {
counts1: vec![0; Code::CODE_MAX as usize],
counts2: vec![vec![0; Code::CODE_MAX as usize]; Code::CODE_MAX as usize],
counts1: vec![0; MAX_CODE as usize],
counts2: vec![vec![0; MAX_CODE as usize]; MAX_CODE as usize],
}
}

#[inline]
fn record_count1(&mut self, code1: Code) {
self.counts1[code1.0 as usize] += 1;
fn record_count1(&mut self, code1: u16) {
self.counts1[code1 as usize] += 1;
}

#[inline]
fn record_count2(&mut self, code1: Code, code2: Code) {
self.counts2[code1.0 as usize][code2.0 as usize] += 1;
fn record_count2(&mut self, code1: u16, code2: u16) {
self.counts2[code1 as usize][code2 as usize] += 1;
}

#[inline]
fn count1(&self, code: Code) -> usize {
self.counts1[code.0 as usize]
fn count1(&self, code: u16) -> usize {
self.counts1[code as usize]
}

#[inline]
fn count2(&self, code1: Code, code2: Code) -> usize {
self.counts2[code1.0 as usize][code2.0 as usize]
fn count2(&self, code1: u16, code2: u16) -> usize {
self.counts2[code1 as usize][code2 as usize]
}
}

Expand All @@ -65,6 +80,9 @@ pub fn train(corpus: impl AsRef<[u8]>) -> SymbolTable {
let mut table = SymbolTable::default();
// TODO(aduffy): handle truncating/sampling if corpus > requires sample size.
let sample = corpus.as_ref();
if sample.is_empty() {
return table;
}
for _generation in 0..MAX_GENERATIONS {
let counter = table.compress_count(sample);
table = table.optimize(counter);
Expand All @@ -81,13 +99,13 @@ impl SymbolTable {
let len = sample.len();
let mut prev_code = self.find_longest_symbol(sample);
counter.record_count1(prev_code);
let mut pos = self.symbols[prev_code.0 as usize].len();
let mut pos = self.symbols[prev_code as usize].len();

while pos < len {
let code = self.find_longest_symbol(&sample[pos..len]);
counter.record_count1(code);
counter.record_count2(prev_code, code);
pos += self.symbols[code.0 as usize].len();
pos += self.symbols[code as usize].len();
prev_code = code;
}

Expand All @@ -100,17 +118,15 @@ impl SymbolTable {
let mut res = SymbolTable::default();
let mut pqueue = BinaryHeap::new();
for code1 in 0..511 {
let code1 = Code::from_u16(code1);
let symbol1 = self.symbols[code1.0 as usize];
let symbol1 = self.symbols[code1 as usize];
let gain = counters.count1(code1) * symbol1.len();
pqueue.push(Candidate {
symbol: symbol1,
gain,
});

for code2 in 0..511 {
let code2 = Code::from_u16(code2);
let symbol2 = &self.symbols[code2.0 as usize];
let symbol2 = &self.symbols[code2 as usize];
// If either symbol is zero-length, or if merging would yield a symbol of
// length greater than 8, skip.
if symbol1.len() + symbol2.len() >= 8 || symbol1.is_empty() || symbol2.is_empty() {
Expand All @@ -133,10 +149,13 @@ impl SymbolTable {
}

// Pop the 255 best symbols.
pqueue
.iter()
.take(255)
.for_each(|candidate| res.insert(candidate.symbol));
let mut n_symbols = 0;
while !pqueue.is_empty() && n_symbols < 255 {
let candidate = pqueue.pop().unwrap();
if res.insert(candidate.symbol) {
n_symbols += 1;
}
}

res
}
Expand Down Expand Up @@ -181,7 +200,7 @@ impl Ord for Candidate {

#[cfg(test)]
mod test {
use crate::{train, Code};
use crate::{train, ESCAPE_CODE};

#[test]
fn test_builder() {
Expand All @@ -191,28 +210,29 @@ mod test {

// Use the table to compress a string, see the values
let compressed = table.compress(text.as_bytes());
assert_eq!(compressed, vec![0u8, 1u8, 2u8]);

// Ensure that the compressed string has no escape bytes
assert!(compressed.iter().all(|b| *b != Code::ESCAPE_CODE));
assert!(compressed.iter().all(|b| *b != ESCAPE_CODE));

// Ensure that we can compress a string with no values seen at training time.
let compressed = table.compress("xyz123".as_bytes());
assert_eq!(
compressed,
vec![
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'x',
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'y',
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'z',
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'1',
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'2',
Code::ESCAPE_CODE,
ESCAPE_CODE,
b'3',
]
)
);
}
}
21 changes: 21 additions & 0 deletions src/find_longest/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright 2024 Spiral, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::CodeMeta;

mod naive;

pub trait FindLongestSymbol {
fn find_longest_symbol(&self, text: &[u8]) -> u16;
}
42 changes: 42 additions & 0 deletions src/find_longest/naive.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright 2024 Spiral, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::find_longest::FindLongestSymbol;
use crate::{CodeMeta, SymbolTable};

// Find the code that maps to a symbol with longest-match to a piece of text.
//
// This is the naive algorithm that just scans the whole table and is very slow.

impl FindLongestSymbol for SymbolTable {
a10y marked this conversation as resolved.
Show resolved Hide resolved
// NOTE(aduffy): if you don't disable inlining, this function won't show up in profiles.
#[inline(never)]
fn find_longest_symbol(&self, text: &[u8]) -> u16 {
debug_assert!(!text.is_empty(), "text must not be empty");

// Find the code that best maps to the provided text table here.
// Start with the code corresponding to the escape of the first character in the text
let mut best_code = text[0] as u16;
let mut best_overlap = 1;
for code in 256..511 {
let symbol = &self.symbols[code as usize];
if symbol.is_prefix(text) && symbol.len() > best_overlap {
best_code = code;
best_overlap = symbol.len();
}
}

best_code
}
}
Loading
Loading