diff --git a/zlib-rs/src/inflate.rs b/zlib-rs/src/inflate.rs index fcd30dd..df08e7a 100644 --- a/zlib-rs/src/inflate.rs +++ b/zlib-rs/src/inflate.rs @@ -517,6 +517,335 @@ const INFLATE_FAST_MIN_HAVE: usize = 15; const INFLATE_FAST_MIN_LEFT: usize = 260; impl State<'_> { + // This logic is split into its own function for two reasons + // + // - We get to load state to the stack; doing this in all cases is expensive, but doing it just + // for Len and related states is very helpful. + // - The `-Cllvm-args=-enable-dfa-jump-thread` llvm arg is able to optimize this function, but + // not the entirity of `dispatch`. We get a massive boost from that pass. + // + // It unfortunately does duplicate the code for some of the states; deduplicating it by having + // more of the states call this function is slower. + fn len_and_friends(&mut self) -> Option { + let avail_in = self.bit_reader.bytes_remaining(); + let avail_out = self.writer.remaining(); + + if avail_in >= INFLATE_FAST_MIN_HAVE && avail_out >= INFLATE_FAST_MIN_LEFT { + inflate_fast_help(self, 0); + return None; + } + + let mut mode; + let mut writer; + let mut bit_reader; + + macro_rules! load { + () => { + mode = self.mode; + writer = core::mem::replace(&mut self.writer, Writer::new(&mut [])); + bit_reader = self.bit_reader; + }; + } + + macro_rules! restore { + () => { + self.mode = mode; + self.writer = writer; + self.bit_reader = bit_reader; + }; + } + + load!(); + + let len_table = match self.len_table.codes { + Codes::Fixed => &self::inffixed_tbl::LENFIX[..], + Codes::Codes => &self.codes_codes, + Codes::Len => &self.len_codes, + Codes::Dist => &self.dist_codes, + }; + + let dist_table = match self.dist_table.codes { + Codes::Fixed => &self::inffixed_tbl::DISTFIX[..], + Codes::Codes => &self.codes_codes, + Codes::Len => &self.len_codes, + Codes::Dist => &self.dist_codes, + }; + + 'top: loop { + match mode { + Mode::Len => { + let avail_in = bit_reader.bytes_remaining(); + let avail_out = writer.remaining(); + + // INFLATE_FAST_MIN_LEFT is important. It makes sure there is at least 32 bytes of free + // space available. This means for many SIMD operations we don't need to process a + // remainder; we just copy blindly, and a later operation will overwrite the extra copied + // bytes + if avail_in >= INFLATE_FAST_MIN_HAVE && avail_out >= INFLATE_FAST_MIN_LEFT { + restore!(); + inflate_fast_help(self, 0); + return None; + } + + self.back = 0; + + // get a literal, length, or end-of-block code + let mut here; + loop { + let bits = bit_reader.bits(self.len_table.bits); + here = len_table[bits as usize]; + + if here.bits <= bit_reader.bits_in_buffer() { + break; + } + + if let Err(return_code) = bit_reader.pull_byte() { + restore!(); + return Some(return_code); + }; + } + + if here.op != 0 && here.op & 0xf0 == 0 { + let last = here; + loop { + let bits = bit_reader.bits((last.bits + last.op) as usize) as u16; + here = len_table[(last.val + (bits >> last.bits)) as usize]; + if last.bits + here.bits <= bit_reader.bits_in_buffer() { + break; + } + + if let Err(return_code) = bit_reader.pull_byte() { + restore!(); + return Some(return_code); + }; + } + + bit_reader.drop_bits(last.bits); + self.back += last.bits as usize; + } + + bit_reader.drop_bits(here.bits); + self.back += here.bits as usize; + self.length = here.val as usize; + + if here.op == 0 { + mode = Mode::Lit; + continue 'top; + } else if here.op & 32 != 0 { + // end of block + + // eprintln!("inflate: end of block"); + + self.back = usize::MAX; + mode = Mode::Type; + + restore!(); + return None; + } else if here.op & 64 != 0 { + mode = Mode::Bad; + { + restore!(); + let this = &mut *self; + let msg: &'static str = "invalid literal/length code\0"; + #[cfg(all(feature = "std", test))] + dbg!(msg); + this.error_message = Some(msg); + return Some(this.inflate_leave(ReturnCode::DataError)); + } + } else { + // length code + self.extra = (here.op & MAX_BITS) as usize; + mode = Mode::LenExt; + continue 'top; + } + } + Mode::Lit => { + if writer.is_full() { + restore!(); + #[cfg(all(test, feature = "std"))] + eprintln!("Ok: writer is full ({} bytes)", self.writer.capacity()); + return Some(self.inflate_leave(ReturnCode::Ok)); + } + + writer.push(self.length as u8); + + mode = Mode::Len; + + continue 'top; + } + Mode::LenExt => { + let extra = self.extra; + + // get extra bits, if any + if extra != 0 { + match bit_reader.need_bits(extra) { + Err(return_code) => { + restore!(); + return Some(self.inflate_leave(return_code)); + } + Ok(v) => v, + }; + self.length += bit_reader.bits(extra) as usize; + bit_reader.drop_bits(extra as u8); + self.back += extra; + } + + // eprintln!("inflate: length {}", state.length); + + self.was = self.length; + mode = Mode::Dist; + + continue 'top; + } + Mode::Dist => { + // get distance code + let mut here; + loop { + let bits = bit_reader.bits(self.dist_table.bits) as usize; + here = dist_table[bits]; + if here.bits <= bit_reader.bits_in_buffer() { + break; + } + + if let Err(return_code) = bit_reader.pull_byte() { + restore!(); + return Some(return_code); + }; + } + + if here.op & 0xf0 == 0 { + let last = here; + + loop { + let bits = bit_reader.bits((last.bits + last.op) as usize); + here = dist_table[last.val as usize + ((bits as usize) >> last.bits)]; + + if last.bits + here.bits <= bit_reader.bits_in_buffer() { + break; + } + + if let Err(return_code) = bit_reader.pull_byte() { + restore!(); + return Some(return_code); + }; + } + + bit_reader.drop_bits(last.bits); + self.back += last.bits as usize; + } + + bit_reader.drop_bits(here.bits); + + if here.op & 64 != 0 { + restore!(); + self.mode = Mode::Bad; + return Some(self.bad("invalid distance code\0")); + } + + self.offset = here.val as usize; + + self.extra = (here.op & MAX_BITS) as usize; + mode = Mode::DistExt; + + continue 'top; + } + Mode::DistExt => { + let extra = self.extra; + + if extra > 0 { + match bit_reader.need_bits(extra) { + Err(return_code) => { + restore!(); + return Some(self.inflate_leave(return_code)); + } + Ok(v) => v, + }; + self.offset += bit_reader.bits(extra) as usize; + bit_reader.drop_bits(extra as u8); + self.back += extra; + } + + if INFLATE_STRICT && self.offset > self.dmax { + restore!(); + self.mode = Mode::Bad; + return Some(self.bad("invalid distance code too far back\0")); + } + + // eprintln!("inflate: distance {}", state.offset); + + mode = Mode::Match; + + continue 'top; + } + Mode::Match => { + if writer.is_full() { + restore!(); + #[cfg(all(feature = "std", test))] + eprintln!( + "BufError: writer is full ({} bytes)", + self.writer.capacity() + ); + return Some(self.inflate_leave(ReturnCode::Ok)); + } + + let left = writer.remaining(); + let copy = writer.len(); + + let copy = if self.offset > copy { + // copy from window to output + + let mut copy = self.offset - copy; + + if copy > self.window.have() { + if self.flags.contains(Flags::SANE) { + restore!(); + self.mode = Mode::Bad; + return Some(self.bad("invalid distance too far back\0")); + } + + // TODO INFLATE_ALLOW_INVALID_DISTANCE_TOOFAR_ARRR + panic!("INFLATE_ALLOW_INVALID_DISTANCE_TOOFAR_ARRR") + } + + let wnext = self.window.next(); + let wsize = self.window.size(); + + let from = if copy > wnext { + copy -= wnext; + wsize - copy + } else { + wnext - copy + }; + + copy = Ord::min(copy, self.length); + copy = Ord::min(copy, left); + + writer.extend_from_window(&self.window, from..from + copy); + + copy + } else { + let copy = Ord::min(self.length, left); + writer.copy_match(self.offset, copy); + + copy + }; + + self.length -= copy; + + if self.length == 0 { + mode = Mode::Len; + continue 'top; + } else { + // otherwise it seems to recurse? + // self.match_() + continue 'top; + } + } + _ => unsafe { core::hint::unreachable_unchecked() }, + } + } + } + fn dispatch(&mut self) -> ReturnCode { 'label: loop { match self.mode { @@ -1054,84 +1383,15 @@ impl State<'_> { continue 'label; } - Mode::Len => { - let avail_in = self.bit_reader.bytes_remaining(); - let avail_out = self.writer.remaining(); - - // INFLATE_FAST_MIN_LEFT is important. It makes sure there is at least 32 bytes of free - // space available. This means for many SIMD operations we don't need to process a - // remainder; we just copy blindly, and a later operation will overwrite the extra copied - // bytes - if avail_in >= INFLATE_FAST_MIN_HAVE && avail_out >= INFLATE_FAST_MIN_LEFT { - inflate_fast_help(self, 0); - continue 'label; - } - - self.back = 0; - - // get a literal, length, or end-of-block code - let mut here; - loop { - let bits = self.bit_reader.bits(self.len_table.bits); - here = self.len_table_get(bits as usize); - - if here.bits <= self.bit_reader.bits_in_buffer() { - break; - } - - pull_byte!(self); - } - - if here.op != 0 && here.op & 0xf0 == 0 { - let last = here; - loop { - let bits = self.bit_reader.bits((last.bits + last.op) as usize) as u16; - here = self.len_table_get((last.val + (bits >> last.bits)) as usize); - if last.bits + here.bits <= self.bit_reader.bits_in_buffer() { - break; - } - - pull_byte!(self); - } - - self.bit_reader.drop_bits(last.bits); - self.back += last.bits as usize; - } - - self.bit_reader.drop_bits(here.bits); - self.back += here.bits as usize; - self.length = here.val as usize; - - if here.op == 0 { - self.mode = Mode::Lit; - - continue 'label; - } else if here.op & 32 != 0 { - // end of block - - // eprintln!("inflate: end of block"); - - self.back = usize::MAX; - self.mode = Mode::Type; - - continue 'label; - } else if here.op & 64 != 0 { - self.mode = Mode::Bad; - - break 'label self.bad("invalid literal/length code\0"); - } else { - // length code - self.extra = (here.op & MAX_BITS) as usize; - self.mode = Mode::LenExt; - - continue 'label; - } - } Mode::Len_ => { self.mode = Mode::Len; continue 'label; } + Mode::Len => match self.len_and_friends() { + Some(return_code) => break 'label return_code, + None => continue 'label, + }, Mode::LenExt => { let extra = self.extra;