Skip to content

Commit

Permalink
doc,pack,test
Browse files Browse the repository at this point in the history
  • Loading branch information
sgreben committed Oct 10, 2024
1 parent b87084a commit 669d832
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 12 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ The [`lsh.Fit`/`lsh.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitkn

If your vectors are longer than 64 bits, you can still use `bitknn` if you [pack](https://pkg.go.dev/github.com/keilerkonzept/bitknn/pack) them into `[]uint64`. The [`pack` package](https://pkg.go.dev/github.com/keilerkonzept/bitknn/pack) defines helper functions to pack `string`s and `[]byte`s into `[]uint64`s.

> It's faster to use a `[][]uint64` allocated using a flat backing slice, laid out in one contiguous memory block. If you already have a non-contiguous `[][]uint64`, you can use [`pack.ReallocateFlat`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/pack#ReallocateFlat) to re-allocate the dataset using a flat 1d backing slice.
The exact k-NN model in `bitknn` and the approximate-NN model in `lsh` each have a `Wide` variant that accepts slice-valued data points:

```go
Expand Down
2 changes: 1 addition & 1 deletion lsh/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func Test_Model_NoHash_IsExact(t *testing.T) {
var h0 lsh.ConstantHash
id := func(a uint64) uint64 { return a }
rapid.Check(t, func(t *rapid.T) {
k := rapid.IntRange(3, 1001).Draw(t, "k")
k := rapid.IntRange(1, 1001).Draw(t, "k")
data := rapid.SliceOfNDistinct(rapid.Uint64(), 3, 1000, id).Draw(t, "data")
labels := rapid.SliceOfN(rapid.IntRange(0, 3), len(data), len(data)).Draw(t, "labels")
values := rapid.SliceOfN(rapid.Float64(), len(data), len(data)).Draw(t, "values")
Expand Down
2 changes: 1 addition & 1 deletion lsh/model_wide_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
func Test_WideModel_64bit_Equal_To_Narrow(t *testing.T) {
id := func(a uint64) uint64 { return a }
rapid.Check(t, func(t *rapid.T) {
k := rapid.IntRange(3, 1001).Draw(t, "k")
k := rapid.IntRange(1, 1001).Draw(t, "k")
data := rapid.SliceOfNDistinct(rapid.Uint64(), 3, 1000, id).Draw(t, "data")
dataWide := make([][]uint64, len(data))
for i := range data {
Expand Down
29 changes: 21 additions & 8 deletions pack/bytes.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
// Package pack provides helpers to pack bytes and strings into []uint64 slices.
package pack

// Bytes packs a byte slice into a uint64 slice.
// If the length of the byte slice is not a multiple of 8, it will pad the remaining bytes with zeroes.
func Bytes(data []byte) []uint64 {
// BytesInto packs a byte slice into the given pre-allocated uint64 slice.
// The output slice should have length >=[BytesPackedLength](data).
func BytesInto(data []byte, out []uint64) {
n := len(data)
dims := (n + 7) / 8 // round up division

out := make([]uint64, dims)

i := 0
j := 0
for ; i+8 <= n; i += 8 {
out[i/8] = uint64(data[i]) |
out[j] = uint64(data[i]) |
uint64(data[i+1])<<8 |
uint64(data[i+2])<<16 |
uint64(data[i+3])<<24 |
uint64(data[i+4])<<32 |
uint64(data[i+5])<<40 |
uint64(data[i+6])<<48 |
uint64(data[i+7])<<56
j++
}

if i < n {
Expand Down Expand Up @@ -46,9 +45,23 @@ func Bytes(data []byte) []uint64 {
case 1:
packed |= uint64(data[i])
}
out[i/8] = packed
out[j] = packed
}
}

// BytesPackedLength return the packed length of the given byte slice.
func BytesPackedLength(data []byte) int {
return (len(data) + 7) / 8
}

// Bytes packs a byte slice into a uint64 slice.
// If the length of the byte slice is not a multiple of 8, it will pad the remaining bytes with zeroes.
func Bytes(data []byte) []uint64 {
n := len(data)
dims := (n + 7) / 8 // round up division

out := make([]uint64, dims)
BytesInto(data, out)
return out
}

Expand Down
2 changes: 1 addition & 1 deletion pack/bytes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func TestPackBytes(t *testing.T) {

// Property 1: Length of packed []uint64 should be (len(bytes) + 7) / 8
packed := pack.Bytes(bytesInput)
expectedLength := (len(bytesInput) + 7) / 8
expectedLength := pack.BytesPackedLength(bytesInput)
if len(packed) != expectedLength {
t.Fatalf("Expected packed length: %d, got: %d", expectedLength, len(packed))
}
Expand Down
16 changes: 16 additions & 0 deletions pack/compact.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package pack

// ReallocateFlat re-allocates the given 2d slice with a flat backing slice.
func ReallocateFlat[T any](d [][]T) {
n := 0
for _, d := range d {
n += len(d)
}
flat := make([]T, n)
j := 0
for i, row := range d {
copy(flat[j:], row)
d[i] = flat[j : j+len(row)]
j += len(row)
}
}
27 changes: 27 additions & 0 deletions pack/compact_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package pack_test

import (
"reflect"
"slices"
"testing"

"github.com/keilerkonzept/bitknn/pack"
"pgregory.net/rapid"
)

func TestPackReallocateFlat(t *testing.T) {
rapid.Check(t, func(t *rapid.T) {
dims := rapid.IntRange(3, 100).Draw(t, "dims")
n := rapid.IntRange(0, 1000).Draw(t, "n")
data := rapid.SliceOfN(rapid.SliceOfN(rapid.Uint64(), dims, dims), n, n).Draw(t, "data")

dataCopy := make([][]uint64, len(data))
for i := range dataCopy {
dataCopy[i] = slices.Clone(data[i])
}
pack.ReallocateFlat(data)
if !reflect.DeepEqual(data, dataCopy) {
t.Fatalf("Original: %v, Packed: %v", dataCopy, data)
}
})
}
12 changes: 12 additions & 0 deletions pack/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,18 @@ func String(data string) []uint64 {
return Bytes(b)
}

// StringPackedLength return the packed length of the given byte slice.
func StringPackedLength(data string) int {
return (len(data) + 7) / 8
}

// String packs a string into the given pre-allocated uint64 slice.
// The output slice should have length >=[StringPackedLength](data).
func StringInto(data string, out []uint64) {
b := unsafe.Slice(unsafe.StringData(data), len(data))
BytesInto(b, out)
}

// StringInv unpacks a []uint64 slice as packed by [String],
func StringInv(data []uint64, originalLengthBytes int) string {
b := BytesInv(data, originalLengthBytes)
Expand Down
18 changes: 17 additions & 1 deletion pack/string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ func TestPackString(t *testing.T) {
data := rapid.String().Draw(t, "data")

// Property 1: Length of packed []uint64 should be (len(data) + 7) / 8
expectedLength := pack.StringPackedLength(data)
packed := pack.String(data)
expectedLength := (len(data) + 7) / 8
if len(packed) != expectedLength {
t.Fatalf("Expected packed length: %d, got: %d", expectedLength, len(packed))
}
Expand All @@ -25,3 +25,19 @@ func TestPackString(t *testing.T) {
}
})
}

func TestPackStringInto(t *testing.T) {
rapid.Check(t, func(t *rapid.T) {
data := rapid.String().Draw(t, "data")

n := pack.StringPackedLength(data)
packed := make([]uint64, n)
pack.StringInto(data, packed)

// Property 2: Roundtrip
unpacked := pack.StringInv(packed, len(data))
if data != unpacked {
t.Fatalf("Original string: %v, Unpacked string: %v", data, unpacked)
}
})
}

0 comments on commit 669d832

Please sign in to comment.