Skip to content
This repository has been archived by the owner on Jun 17, 2024. It is now read-only.

Fix bitpacking #13

Merged
merged 1 commit into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/bitpacking.zig
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,23 @@ test "bitpack" {
BP.decode(3, &packed_ints, &output);
try std.testing.expectEqual(.{2} ** 1024, output);
}

test "bitpack range" {
    const std = @import("std");
    const fl = @import("./fastlanez.zig");
    const BP = BitPacking(fl.FastLanez(u8));

    // Bit width handed to encode/decode for every value.
    const W = 3;

    // Input is a repeating 0, 1, ..., 6 ramp, so every value fits in W bits
    // and neighbouring values differ (unlike a constant-fill input).
    var input: [1024]u8 = undefined;
    for (&input, 0..) |*slot, idx| {
        slot.* = @intCast(idx % 7);
    }

    // 1024 values * W bits / 8 bits per byte == 128 * W bytes of packed output.
    var encoded: [128 * W]u8 = undefined;
    BP.encode(W, &input, &encoded);

    // Round-trip: decoding the packed bytes must reproduce the input exactly.
    var decoded: [1024]u8 = undefined;
    BP.decode(W, &encoded, &decoded);
    try std.testing.expectEqual(input, decoded);
}
48 changes: 21 additions & 27 deletions src/fastlanez.zig
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ pub fn FastLanez(comptime Element: type) type {
/// The position in the output that we're writing to. Will finish equal to Width.
out_idx: comptime_int = 0,

shift_bits: comptime_int = 0,
mask_bits: comptime_int = Width,
bit_idx: comptime_int = 0,

/// Invoke to store the next vector.
/// Called T times, and writes W times. bit_idx tracks how many bits have been written into the result.
pub inline fn pack(comptime self: *Self, out: *PackedBytes(Width), word: MM1024, state: MM1024) MM1024 {
var tmp: MM1024 = undefined;
if (self.t == 0) {
Expand All @@ -140,24 +140,23 @@ pub fn FastLanez(comptime Element: type) type {
}
self.t += 1;

// If we didn't take all W bits last time, then we load the remainder
if (self.mask_bits < Width) {
tmp = or_(tmp, and_rshift(word, self.mask_bits, bitmask(self.shift_bits)));
}

// Update the number of mask bits
self.mask_bits = @min(T - self.shift_bits, Width);
const shift_bits = self.bit_idx % T;
const mask_bits = @min(T - shift_bits, Width - (self.bit_idx % Width));

// Pull the masked bits into the tmp register
tmp = or_(tmp, and_lshift(word, self.shift_bits, bitmask(self.mask_bits)));
self.shift_bits += Width;
tmp = or_(tmp, and_lshift(word, shift_bits, bitmask(mask_bits)));
self.bit_idx += mask_bits;

if (self.shift_bits >= T) {
// If we have a full 1024 bits, then store it and reset the tmp register
if (self.bit_idx % T == 0) {
// If we have a full T bits, then store it and reset the tmp register
store(out, self.out_idx, tmp);
tmp = @splat(0);
self.out_idx += 1;
self.shift_bits -= T;

// Put the remainder of the bits in the next word
if (mask_bits < Width) {
tmp = or_(tmp, and_rshift(word, mask_bits, bitmask(Width - mask_bits)));
self.bit_idx += (Width - mask_bits);
}
}

return tmp;
Expand All @@ -177,7 +176,7 @@ pub fn FastLanez(comptime Element: type) type {
t: comptime_int = 0,

input_idx: comptime_int = 0,
shift_bits: comptime_int = 0,
bit_idx: comptime_int = 0,

pub inline fn unpack(comptime self: *Self, input: *const PackedBytes(Width), state: MM1024) struct { MM1024, MM1024 } {
if (self.t > T) {
Expand All @@ -193,24 +192,19 @@ pub fn FastLanez(comptime Element: type) type {
tmp = state;
}

const mask_bits = @min(T - self.shift_bits, Width);
const shift_bits = self.bit_idx % T;
const mask_bits = @min(T - shift_bits, Width - (self.bit_idx % Width));

var next: MM1024 = undefined;
if (self.shift_bits == T) {
next = tmp;
} else {
next = and_rshift(tmp, self.shift_bits, bitmask(mask_bits));
}
var next: MM1024 = and_rshift(tmp, shift_bits, bitmask(mask_bits));

if (mask_bits != Width) {
if (mask_bits != Width and self.input_idx < Width) {
tmp = load(input, self.input_idx);
self.input_idx += 1;

next = or_(next, and_lshift(tmp, mask_bits, bitmask(Width - mask_bits)));

self.shift_bits = Width - mask_bits;
self.bit_idx += Width;
} else {
self.shift_bits += Width;
self.bit_idx += mask_bits;
}

return .{ next, tmp };
Expand Down