From 8108ce7633d378612dd5100789e2c1b117b85ecd Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Tue, 19 Mar 2024 10:03:57 +0000 Subject: [PATCH] Fix bitpacking (#13) --- src/bitpacking.zig | 20 +++++++++++++++++++ src/fastlanez.zig | 48 ++++++++++++++++++++-------------------------- 2 files changed, 41 insertions(+), 27 deletions(-) diff --git a/src/bitpacking.zig b/src/bitpacking.zig index b782acd..f38225e 100644 --- a/src/bitpacking.zig +++ b/src/bitpacking.zig @@ -54,3 +54,23 @@ test "bitpack" { BP.decode(3, &packed_ints, &output); try std.testing.expectEqual(.{2} ** 1024, output); } + +test "bitpack range" { + const std = @import("std"); + const fl = @import("./fastlanez.zig"); + const BP = BitPacking(fl.FastLanez(u8)); + + const W = 3; + + var ints: [1024]u8 = undefined; + for (0..1024) |i| { + ints[i] = @intCast(i % 7); + } + + var packed_ints: [128 * W]u8 = undefined; + BP.encode(W, &ints, &packed_ints); + + var output: [1024]u8 = undefined; + BP.decode(W, &packed_ints, &output); + try std.testing.expectEqual(ints, output); +} diff --git a/src/fastlanez.zig b/src/fastlanez.zig index 4a8758e..6620901 100644 --- a/src/fastlanez.zig +++ b/src/fastlanez.zig @@ -123,10 +123,10 @@ pub fn FastLanez(comptime Element: type) type { /// The position in the output that we're writing to. Will finish equal to Width. out_idx: comptime_int = 0, - shift_bits: comptime_int = 0, - mask_bits: comptime_int = Width, + bit_idx: comptime_int = 0, /// Invoke to store the next vector. + /// Called T times, and writes W times. bit_idx tracks how many bits have been written into the result. pub inline fn pack(comptime self: *Self, out: *PackedBytes(Width), word: MM1024, state: MM1024) MM1024 { var tmp: MM1024 = undefined; if (self.t == 0) { @@ -140,24 +140,23 @@ pub fn FastLanez(comptime Element: type) type { } self.t += 1; - // If we didn't take all W bits last time, then we load the remainder - if (self.mask_bits < Width) { - tmp = or_(tmp, and_rshift(word, self.mask_bits, bitmask(self.shift_bits))); - } - - // Update the number of mask bits - self.mask_bits = @min(T - self.shift_bits, Width); + const shift_bits = self.bit_idx % T; + const mask_bits = @min(T - shift_bits, Width - (self.bit_idx % Width)); - // Pull the masked bits into the tmp register - tmp = or_(tmp, and_lshift(word, self.shift_bits, bitmask(self.mask_bits))); - self.shift_bits += Width; + tmp = or_(tmp, and_lshift(word, shift_bits, bitmask(mask_bits))); + self.bit_idx += mask_bits; - if (self.shift_bits >= T) { - // If we have a full 1024 bits, then store it and reset the tmp register + if (self.bit_idx % T == 0) { + // If we have a full T bits, then store it and reset the tmp register store(out, self.out_idx, tmp); tmp = @splat(0); self.out_idx += 1; - self.shift_bits -= T; + + // Put the remainder of the bits in the next word + if (mask_bits < Width) { + tmp = or_(tmp, and_rshift(word, mask_bits, bitmask(Width - mask_bits))); + self.bit_idx += (Width - mask_bits); + } } return tmp; @@ -177,7 +176,7 @@ pub fn FastLanez(comptime Element: type) type { t: comptime_int = 0, input_idx: comptime_int = 0, - shift_bits: comptime_int = 0, + bit_idx: comptime_int = 0, pub inline fn unpack(comptime self: *Self, input: *const PackedBytes(Width), state: MM1024) struct { MM1024, MM1024 } { if (self.t > T) { @@ -193,24 +192,19 @@ pub fn FastLanez(comptime Element: type) type { tmp = state; } - const mask_bits = @min(T - self.shift_bits, Width); + const shift_bits = self.bit_idx % T; + const mask_bits = @min(T - shift_bits, Width - (self.bit_idx % Width)); - var next: MM1024 = undefined; - if (self.shift_bits == T) { - next = tmp; - } else { - next = and_rshift(tmp, self.shift_bits, bitmask(mask_bits)); - } + var next: MM1024 = and_rshift(tmp, shift_bits, bitmask(mask_bits)); - if (mask_bits != Width) { + if (mask_bits != Width and self.input_idx < Width) { tmp = load(input, self.input_idx); self.input_idx += 1; next = or_(next, and_lshift(tmp, mask_bits, bitmask(Width - mask_bits))); - - self.shift_bits = Width - mask_bits; + self.bit_idx += Width; } else { - self.shift_bits += Width; + self.bit_idx += mask_bits; } return .{ next, tmp };