Skip to content

Commit

Permalink
first happy fse roundtrip
Browse files Browse the repository at this point in the history
  • Loading branch information
KillingSpark committed Oct 13, 2024
1 parent 73f7797 commit e02f5ba
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 16 deletions.
3 changes: 3 additions & 0 deletions src/fse/fse_decoder.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::dbg;

use crate::decoding::bit_reader::BitReader;
use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError};
use alloc::vec::Vec;
Expand Down Expand Up @@ -197,6 +199,7 @@ impl<'t> FSEDecoder<'t> {

/// Advance the internal state to decode the next symbol in the bitstream.
pub fn update_state(&mut self, bits: &mut BitReaderReversed<'_>) {
dbg!(self.state);
let num_bits = self.state.num_bits;
let add = bits.get_bits(num_bits);
let base_line = self.state.base_line;
Expand Down
60 changes: 44 additions & 16 deletions src/fse/fse_encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
};

pub struct FSEEncoder {
table: FSETable,
pub(super) table: FSETable,
writer: BitWriter,
}

Expand All @@ -19,7 +19,14 @@ impl FSEEncoder {
}

pub fn encode(&mut self, data: &[u8]) -> Vec<u8> {
// TODO encode
let mut state = &self.table.states[data[data.len() - 1] as usize].states[0];
for x in data[0..data.len() - 1].iter().rev().copied() {
let next = self.table.next_state(x, state.index);
let diff = state.index - next.baseline;
self.writer.write_bits(diff as u64, next.num_bits as usize);
state = next;
}
self.writer.write_bits(state.index as u64, self.acc_log() as usize);

let mut writer = BitWriter::new();
core::mem::swap(&mut self.writer, &mut writer);
Expand All @@ -31,6 +38,18 @@ impl FSEEncoder {
}
writer.dump()
}

/// Returns the `probability` stored on each entry of the encoder table's
/// `states`, in the same order as the table.
pub(super) fn probabilities(&self) -> Vec<i32> {
    let symbols = &self.table.states;
    let mut probs = Vec::with_capacity(symbols.len());
    for symbol in symbols.iter() {
        probs.push(symbol.probability);
    }
    probs
}

/// Returns the base-2 logarithm (rounded down) of the encoder table's
/// `table_size`, narrowed to `u8`.
pub(super) fn acc_log(&self) -> u8 {
    let log = self.table.table_size.ilog2();
    log as u8
}
}

#[derive(Debug)]
Expand All @@ -51,6 +70,7 @@ impl FSETable {
pub(super) struct SymbolStates {
/// Sorted by baseline
pub(super) states: Vec<State>,
pub(super) probability: i32,
}

impl SymbolStates {
Expand Down Expand Up @@ -78,7 +98,7 @@ impl State {
}
}

fn build_table_from_data(data: &[u8]) -> FSETable {
pub fn build_table_from_data(data: &[u8]) -> FSETable {
let mut counts = [0; 256];
for x in data {
counts[*x as usize] += 1;
Expand Down Expand Up @@ -121,20 +141,27 @@ fn build_table_from_counts(counts: &[usize]) -> FSETable {
}

pub(super) fn build_table_from_probabilities(probs: &[i32], acc_log: u8) -> FSETable {
let mut states =
core::array::from_fn::<SymbolStates, 256, _>(|_| SymbolStates { states: Vec::new() });

let mut states = core::array::from_fn::<SymbolStates, 256, _>(|_| SymbolStates {
states: Vec::new(),
probability: 0,
});

// distribute -1 symbols
let mut negative_idx = (1 << acc_log) - 1;
for (symbol, _prob) in probs.iter().copied().enumerate().filter(|prob| prob.1 == -1) {
for (symbol, _prob) in probs
.iter()
.copied()
.enumerate()
.filter(|prob| prob.1 == -1)
{
dbg!(symbol, negative_idx);
states[symbol].states.push(State {
num_bits: acc_log,
baseline: 0,
last_index: (1 << acc_log) - 1,
index: negative_idx,
});
states[symbol].probability = -1;
negative_idx -= 1;
}

Expand All @@ -144,6 +171,7 @@ pub(super) fn build_table_from_probabilities(probs: &[i32], acc_log: u8) -> FSET
if prob <= 0 {
continue;
}
states[symbol].probability = prob;
let states = &mut states[symbol].states;
for _ in 0..prob {
states.push(State {
Expand All @@ -167,35 +195,35 @@ pub(super) fn build_table_from_probabilities(probs: &[i32], acc_log: u8) -> FSET
}
let prob = prob as u32;
let state = &mut states[symbol];
state.states.sort_by(|l,r| l.index.cmp(&r.index));
state.states.sort_by(|l, r| l.index.cmp(&r.index));

let prob_log = if prob.is_power_of_two() {
prob.ilog2()
} else {
prob.ilog2() + 1
prob.ilog2() + 1
};
let rounded_up = 1u32 << prob_log;
let double_states = rounded_up - prob;
let single_states = prob - double_states;
let num_bits = acc_log - prob_log as u8;
let mut baseline = (single_states as usize * (1 << (num_bits))) % (1 << acc_log);
let mut baseline = (single_states as usize * (1 << (num_bits))) % (1 << acc_log);
for (idx, state) in state.states.iter_mut().enumerate() {
if (idx as u32) < double_states {
let num_bits = num_bits + 1;
state.baseline = baseline;
state.num_bits = num_bits;
state.last_index= baseline + ((1 << num_bits) - 1);
state.last_index = baseline + ((1 << num_bits) - 1);

baseline += 1 << num_bits;
baseline %= 1 << acc_log;
baseline %= 1 << acc_log;
} else {
state.baseline = baseline;
state.num_bits = num_bits;
state.last_index= baseline + ((1 << num_bits) - 1);
state.last_index = baseline + ((1 << num_bits) - 1);
baseline += 1 << num_bits;
}
}
state.states.sort_by(|l,r| l.baseline.cmp(&r.baseline));
state.states.sort_by(|l, r| l.baseline.cmp(&r.baseline));
}

FSETable {
Expand Down
50 changes: 50 additions & 0 deletions src/fse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
mod fse_decoder;

use std::{dbg, eprintln, vec::Vec};

pub use fse_decoder::*;
use fse_encoder::FSEEncoder;

use crate::decoding::bit_reader_reverse::BitReaderReversed;
pub mod fse_encoder;

#[test]
Expand All @@ -24,10 +29,55 @@ fn tables_equal() {
dec_table.build_from_probabilities(6, probs).unwrap();
let enc_table = fse_encoder::build_table_from_probabilities(probs, 6);

check_tables(&dec_table, &enc_table);
}

/// Asserts that every decoder table entry has a matching encoder state.
///
/// For each index `idx` of `dec_table.decode`, the encoder states belonging to
/// that entry's symbol are searched for the one whose `index` equals `idx`.
/// Its `baseline` and `num_bits` must equal the decoder entry's `base_line`
/// and `num_bits`.
///
/// # Panics
///
/// Panics if no encoder state with the required index exists for the symbol,
/// or if `baseline` / `num_bits` differ. Each failure message names the table
/// index and symbol so the diverging entry can be located directly.
fn check_tables(dec_table: &fse_decoder::FSETable, enc_table: &fse_encoder::FSETable) {
    for (idx, dec_state) in dec_table.decode.iter().enumerate() {
        let symbol = dec_state.symbol as usize;
        let enc_state = enc_table.states[symbol]
            .states
            .iter()
            .find(|state| state.index == idx)
            .unwrap_or_else(|| {
                panic!("no encoder state with index {idx} for symbol {symbol}")
            });
        assert_eq!(
            enc_state.baseline, dec_state.base_line as usize,
            "baseline mismatch at table index {idx} (symbol {symbol})"
        );
        assert_eq!(
            enc_state.num_bits, dec_state.num_bits,
            "num_bits mismatch at table index {idx} (symbol {symbol})"
        );
    }
}

/// Round-trips the byte sequence `0, 1, ..., 63` through `round_trip`.
#[test]
fn roundtrip() {
    let data: Vec<u8> = (0..64).collect();
    round_trip(&data);
}

pub fn round_trip(data: &[u8]) {
let mut encoder: FSEEncoder = FSEEncoder::new(fse_encoder::build_table_from_data(data));
let mut dec_table = FSETable::new(255);
dec_table.build_from_probabilities(encoder.acc_log(), &encoder.probabilities()).unwrap();
let mut decoder = FSEDecoder::new(&dec_table);

check_tables(&dec_table, &encoder.table);

let encoded = encoder.encode(data);

let mut br = BitReaderReversed::new(&encoded);
let mut skipped_bits = 0;
loop {
let val = br.get_bits(1);
skipped_bits += 1;
if val == 1 || skipped_bits > 8 {
break;
}
}
if skipped_bits > 8 {
//if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data
panic!("Corrupted end marker");
}
decoder.init_state(&mut br).unwrap();
let mut decoded = alloc::vec::Vec::new();

for x in data {
let w = decoder.decode_symbol();
assert_eq!(w, *x);
decoded.push(w);
decoder.update_state(&mut br);
}

assert_eq!(&decoded, data);
}

0 comments on commit e02f5ba

Please sign in to comment.