diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..5d5906c --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file + +version: 2 +updates: + - package-ecosystem: "gomod" # See documentation for possible values + directory: "/" # Location of package manifests + schedule: + interval: "daily" diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 0000000..77d4898 --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,30 @@ +# This workflow will build a golang project +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go + +name: Go + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "1.23" + + - name: Build + run: go build -v ./... + + - name: Test + run: go test -v ./... + + - name: Coverage + run: go test -v -cover ./... 
diff --git a/.github/workflows/gocover.yaml b/.github/workflows/gocover.yaml new file mode 100644 index 0000000..37e9b5c --- /dev/null +++ b/.github/workflows/gocover.yaml @@ -0,0 +1,53 @@ +name: Go coverage badge # The name of the workflow that will appear on Github + +on: + push: + branches: [main] + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + go: [1.23] + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go }} + + - name: Build + run: go install + + - name: Test + run: | + go test -v -cover ./... -coverprofile coverage.out -coverpkg ./... + go tool cover -func coverage.out -o coverage.out # Replaces coverage.out with the analysis of coverage.out + + - name: Go Coverage Badge + uses: tj-actions/coverage-badge-go@v2 + if: ${{ runner.os == 'Linux' && matrix.go == '1.23' }} # Runs this on only one of the ci builds. 
+ with: + green: 80 + filename: coverage.out + link: https://github.com/keilerkonzept/bitknn/actions/workflows/gocover.yaml + + - uses: stefanzweifel/git-auto-commit-action@v5 + id: auto-commit-action + with: + commit_message: Apply Code Coverage Badge + skip_fetch: true + skip_checkout: true + file_pattern: ./README.md + + - name: Push Changes + if: steps.auto-commit-action.outputs.changes_detected == 'true' + uses: ad-m/github-push-action@master + with: + github_token: ${{ github.token }} + branch: ${{ github.ref }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f47cb20 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +*.out diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..9f00194 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 KEILERKONZEPT UG (haftungsbeschränkt) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/README.md b/README.md new file mode 100644 index 0000000..fa0ea3e --- /dev/null +++ b/README.md @@ -0,0 +1,79 @@ +# bitknn + +[![Go Reference](https://pkg.go.dev/badge/github.com/keilerkonzept/bitknn.svg)](https://pkg.go.dev/github.com/keilerkonzept/bitknn) +[![Go Report Card](https://goreportcard.com/badge/github.com/keilerkonzept/bitknn)](https://goreportcard.com/report/github.com/keilerkonzept/bitknn) + + +```go +import "github.com/keilerkonzept/bitknn" +``` + +`bitknn` is a fast k-nearest neighbors (k-NN) library for `uint64`s, using Hamming distance to measure similarity. + +If you need to classify **binary feature vectors that fit into `uint64`s**, this library might be useful. It is fast mainly because we can use cheap bitwise ops (XOR + POPCNT) to calculate distances between `uint64` values. For smaller datasets, the performance of the [neighbor heap](heap.go) is also relevant, so that part has been tuned as well. + +You can optionally weigh class votes by distance, or specify different vote values per data point. + + +**Contents** +- [Usage](#usage) +- [Options](#options) +- [License](#license) + +## Usage + +```go +package main + +import ( + "fmt" + "github.com/keilerkonzept/bitknn" +) + +func main() { + // feature vectors packed into uint64s + data := []uint64{0b101010, 0b111000, 0b000111} + // class labels + labels := []int{0, 1, 1} + + model := bitknn.Fit(data, labels, 2, bitknn.WithLinearDecay()) + + // one vote counter per class + votes := make([]float64, 2) + model.Predict1(0b101011, votes) + + fmt.Println("Votes:", votes) +} +``` + +## Options + +- `WithLinearDecay()`: Apply linear distance weighting (`1 / (1 + dist)`). +- `WithQuadraticDecay()`: Apply quadratic distance weighting (`1 / (1 + dist^2)`). +- `WithDistanceWeightFunc(f func(dist int) float64)`: Use a custom distance weighting function. +- `WithValues(values []float64)`: Assign specific vote values for each data point.
+ +## Benchmarks + +``` +goos: darwin +goarch: arm64 +pkg: github.com/keilerkonzept/bitknn +cpu: Apple M1 Pro +``` + +| op | N | k | iters | ns/op | B/op | allocs/op | +|------------|---------|-----|---------|--------------|------|-----------| +| `Predict1` | 100 | 3 | 8308794 | 121.4 ns/op | 0 | 0 | +| `Predict1` | 100 | 10 | 4707778 | 269.7 ns/op | 0 | 0 | +| `Predict1` | 100 | 100 | 2255380 | 549.2 ns/op | 0 | 0 | +| `Predict1` | 1000 | 3 | 1693364 | 659.3 ns/op | 0 | 0 | +| `Predict1` | 1000 | 10 | 1220426 | 1005 ns/op | 0 | 0 | +| `Predict1` | 1000 | 100 | 345151 | 3560 ns/op | 0 | 0 | +| `Predict1` | 1000000 | 3 | 2076 | 566647 ns/op | 0 | 0 | +| `Predict1` | 1000000 | 10 | 2112 | 568787 ns/op | 0 | 0 | +| `Predict1` | 1000000 | 100 | 2066 | 587827 ns/op | 0 | 0 | + +## License + +MIT License diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..82c48b3 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module github.com/keilerkonzept/bitknn + +go 1.23.0 + +require github.com/google/go-cmp v0.6.0 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..5a8d551 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= diff --git a/heap.go b/heap.go new file mode 100644 index 0000000..f4d5b70 --- /dev/null +++ b/heap.go @@ -0,0 +1,78 @@ +package bitknn + +import "unsafe" + +// neighborHeap is a max-heap that stores distances and their corresponding indices. +// The heap is used to keep track of nearest neighbors. 
+type neighborHeap struct { + distances []int + lastDistance *int + indices []int + lastIndex *int + len int +} + +const unsafeSizeofInt = unsafe.Sizeof(int(0)) + +func makeNeighborHeap(distances, indices []int) neighborHeap { + return neighborHeap{ + distances: distances, + lastDistance: (*int)(unsafe.Add(unsafe.Pointer(unsafe.SliceData(distances)), unsafeSizeofInt*uintptr(len(distances)-1))), + indices: indices, + lastIndex: (*int)(unsafe.Add(unsafe.Pointer(unsafe.SliceData(indices)), unsafeSizeofInt*uintptr(len(indices)-1))), + } +} + +func (me *neighborHeap) swap(i, j int) { + me.distances[i], me.distances[j] = me.distances[j], me.distances[i] + me.indices[i], me.indices[j] = me.indices[j], me.indices[i] +} + +func (me *neighborHeap) less(i, j int) bool { + return me.distances[i] > me.distances[j] +} + +func (me *neighborHeap) pushpop(value int, index int) { + n := me.len + *me.lastDistance = value + *me.lastIndex = index + me.up(n) + me.swap(0, n) + + // me.down(0, n) + i := 0 + for { + l := 2*i + 1 // Left child + if l >= n || l < 0 { // If no left child, break + break + } + j := l + if r := l + 1; r < n && me.less(r, l) { // If right child exists and is smaller, select right child + j = r + } + if !me.less(j, i) { // If parent is smaller than selected child, break + break + } + me.swap(i, j) // Swap parent with child + i = j // Continue pushing down + } +} + +func (me *neighborHeap) push(value int, index int) { + n := me.len + me.distances[n] = value + me.indices[n] = index + me.len = n + 1 + me.up(n) +} + +func (me *neighborHeap) up(i int) { + for { + p := (i - 1) / 2 // Parent index + if p == i || !me.less(i, p) { // If parent is larger or i is root, stop + break + } + me.swap(p, i) // Swap child with parent + i = p // Continue moving up + } +} diff --git a/heap_test.go b/heap_test.go new file mode 100644 index 0000000..c6a1e85 --- /dev/null +++ b/heap_test.go @@ -0,0 +1,97 @@ +package bitknn + +import ( + "testing" +) + +func TestMakeNeighborHeap(t 
*testing.T) { + distances := []int{10, 20, 30} + indices := []int{1, 2, 3} + heap := makeNeighborHeap(distances, indices) + + // Check if lastDistance and lastIndex are pointing to the correct elements + if *heap.lastDistance != 30 { + t.Errorf("Expected lastDistance to be 30, got %d", *heap.lastDistance) + } + if *heap.lastIndex != 3 { + t.Errorf("Expected lastIndex to be 3, got %d", *heap.lastIndex) + } +} + +func TestNeighborHeapSwap(t *testing.T) { + heap := neighborHeap{ + distances: []int{10, 20, 30}, + indices: []int{1, 2, 3}, + } + + heap.swap(0, 2) + + if heap.distances[0] != 30 || heap.distances[2] != 10 { + t.Errorf("Swap failed on distances, got %v", heap.distances) + } + if heap.indices[0] != 3 || heap.indices[2] != 1 { + t.Errorf("Swap failed on indices, got %v", heap.indices) + } +} + +func TestNeighborHeapLess(t *testing.T) { + heap := neighborHeap{ + distances: []int{10, 20, 30}, + indices: []int{1, 2, 3}, + } + + if !heap.less(2, 0) { + t.Errorf("Expected less(2, 0) to be true, got false") + } + + if heap.less(0, 2) { + t.Errorf("Expected less(0, 2) to be false, got true") + } +} + +func TestNeighborHeapPushPop(t *testing.T) { + distances := []int{30, 20, 10, 0} + indices := []int{1, 2, 3, 0} + heap := makeNeighborHeap(distances, indices) + heap.len = 3 + + heap.pushpop(25, 4) + + // Check if heap is reordered correctly + expectedDistances := []int{25, 20, 10, + 30, + } + expectedIndices := []int{4, 2, 3, + 1, + } + for i := range expectedDistances { + if heap.distances[i] != expectedDistances[i] { + t.Errorf("Expected distance at %d to be %d, got %d", i, expectedDistances[i], heap.distances[i]) + } + if heap.indices[i] != expectedIndices[i] { + t.Errorf("Expected index at %d to be %d, got %d", i, expectedIndices[i], heap.indices[i]) + } + } +} + +func TestNeighborHeapPush(t *testing.T) { + heap := makeNeighborHeap( + make([]int, 4), + make([]int, 4), + ) + + heap.push(10, 3) + heap.push(15, 5) + heap.push(25, 6) + heap.pushpop(9, 3) + 
heap.pushpop(7, 2) + heap.pushpop(8, 1) + heap.pushpop(6, 0) + + if heap.distances[0] != 8 { + t.Errorf("Expected root distance to be 8, got %d", heap.distances[0]) + } + if heap.indices[0] != 1 { + t.Errorf("Expected root index to be 1, got %d", heap.indices[0]) + } +} diff --git a/knn.go b/knn.go new file mode 100644 index 0000000..5659cbe --- /dev/null +++ b/knn.go @@ -0,0 +1,168 @@ +package bitknn + +// The k-NN model state. +type Model struct { + // Number of nearest neighbors to consider. + k int + // Input data points. + data []uint64 + // Class labels for each data point. + labels []int + // Optional vote values for each data point. + values []float64 + + distanceWeightFuncAny bool + distanceWeightFunc func(int) float64 + distanceWeightLinear bool + distanceWeightQuadratic bool + + distances []int // len = k+1 + indices []int // len = k+1 + votes []float64 // len = k+1 +} + +// The type of options for the k-NN model. +type Option func(*Model) + +func linearDecay(dist int) float64 { return 1 / float64(1+dist) } +func quadraticDecay(dist int) float64 { return 1 / float64(1+dist*dist) } + +// Set linear decay as the distance weight function. +func WithLinearDecay() Option { + return func(m *Model) { + m.distanceWeightFuncAny = true + m.distanceWeightFunc = nil + m.distanceWeightLinear = true + m.distanceWeightQuadratic = false + } +} + +// Set quadratic decay as the distance weight function. +func WithQuadraticDecay() Option { + return func(m *Model) { + m.distanceWeightFuncAny = true + m.distanceWeightFunc = nil + m.distanceWeightLinear = false + m.distanceWeightQuadratic = true + + } +} + +// Set a custom distance weight function. +func WithDistanceWeightFunc(f func(distance int) float64) Option { + return func(m *Model) { + m.distanceWeightFuncAny = f != nil + m.distanceWeightFunc = f + m.distanceWeightLinear = false + m.distanceWeightQuadratic = false + } +} + +// Set vote values for each data point.
+func WithValues(values []float64) Option { + return func(m *Model) { + m.values = values + } +} + +// Build a k-NN model from the given data and options. +func Fit(data []uint64, labels []int, k int, opts ...Option) *Model { + m := &Model{ + k: k, + data: data, + labels: labels, + + distances: make([]int, k+1), + indices: make([]int, k+1), + votes: make([]float64, k+1), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// Predicts the label of a single input point, using the scratch space pre-allocated with the [Model] for the neighbor heap. +func (me *Model) Predict1(x uint64, votes []float64) { + me.Predict1Into(x, me.distances, me.indices, votes) +} + +// Predicts the label of a single input point, using the given slices for the neighbor heap. +func (me *Model) Predict1Into(x uint64, distances []int, indices []int, votes []float64) { + k := Nearest(me.data, me.k, x, distances, indices) + + clear(votes) + if me.values == nil { + me.countWeightedVotes(k, indices, distances, votes) + return + } + + me.sumWeightedValues(k, indices, distances, votes) +} + +func (me *Model) sumWeightedValues(k int, indices []int, distances []int, votes []float64) { + w := me.distanceWeightFunc + switch { + case !me.distanceWeightFuncAny: + for i := range k { + index := indices[i] + label := me.labels[index] + v := votes[label] + me.values[index] + votes[label] = v + } + case me.distanceWeightLinear: + for i := range k { + index := indices[i] + label := me.labels[index] + v := votes[label] + me.values[index]*linearDecay(distances[i]) + votes[label] = v + } + case me.distanceWeightQuadratic: + for i := range k { + index := indices[i] + label := me.labels[index] + v := votes[label] + me.values[index]*quadraticDecay(distances[i]) + votes[label] = v + } + case w != nil: + for i := range k { + index := indices[i] + label := me.labels[index] + v := votes[label] + me.values[index]*w(distances[i]) + votes[label] = v + } + } +} + +func (me *Model) countWeightedVotes(k int, indices 
[]int, distances []int, votes []float64) { + w := me.distanceWeightFunc + switch { + case !me.distanceWeightFuncAny: + for i := range k { + index := indices[i] + label := me.labels[index] + votes[label]++ + } + case me.distanceWeightLinear: + for i := range k { + index := indices[i] + label := me.labels[index] + v := votes[label] + linearDecay(distances[i]) + votes[label] = v + } + case me.distanceWeightQuadratic: + for i := range k { + index := indices[i] + label := me.labels[index] + v := votes[label] + quadraticDecay(distances[i]) + votes[label] = v + } + case w != nil: + for i := range k { + index := indices[i] + label := me.labels[index] + v := votes[label] + w(distances[i]) + votes[label] = v + } + } +} diff --git a/knn_bench_test.go b/knn_bench_test.go new file mode 100644 index 0000000..514ccd8 --- /dev/null +++ b/knn_bench_test.go @@ -0,0 +1,28 @@ +package bitknn_test + +import ( + "fmt" + "math/rand/v2" + "testing" + + "github.com/keilerkonzept/bitknn" +) + +func BenchmarkPredict1(b *testing.B) { + votes := make([]float64, 256) + for _, dataSize := range []int{100, 1000, 1_000_000} { + for _, k := range []int{3, 10, 100} { + b.Run(fmt.Sprintf("N=%d_k=%d", dataSize, k), func(b *testing.B) { + data := randomData(dataSize) + labels := randomLabels(dataSize) + model := bitknn.Fit(data, labels, k) + query := rand.Uint64() + + b.ResetTimer() + for n := 0; n < b.N; n++ { + model.Predict1(query, votes) + } + }) + } + } +} diff --git a/knn_test.go b/knn_test.go new file mode 100644 index 0000000..594f874 --- /dev/null +++ b/knn_test.go @@ -0,0 +1,158 @@ +package bitknn_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/keilerkonzept/bitknn" +) + +func TestPredict1(t *testing.T) { + data := []uint64{0b0000, 0b1111, 0b0011, 0b0101} + labels := []int{0, 1, 1, 0} + k := 2 + + model := bitknn.Fit(data, labels, k) + + x := uint64(0b0010) + votes := make([]float64, k) + model.Predict1(x, votes) + + expectedVotes := []float64{1, 1} + if diff := 
cmp.Diff(expectedVotes, votes); diff != "" { + t.Error(diff) + } +} + +func TestPredict1WithValues(t *testing.T) { + data := []uint64{0b0000, 0b1111, 0b0011, 0b0101} + labels := []int{0, 1, 1, 0} + values := []float64{1.0, 2.0, 3.0, 4.0} + k := 2 + + model := bitknn.Fit(data, labels, k, bitknn.WithValues(values)) + + x := uint64(0b0010) + votes := make([]float64, k) + model.Predict1(x, votes) + + expectedVotes := []float64{1, 3} + if diff := cmp.Diff(expectedVotes, votes); diff != "" { + t.Error(diff) + } +} + +func TestPredict1WithLinearDecay(t *testing.T) { + data := []uint64{0b0000, 0b1111, 0b0011, 0b0101} + labels := []int{0, 1, 1, 0} + k := 3 + + model := bitknn.Fit(data, labels, k, bitknn.WithLinearDecay()) + + x := uint64(0b0001) + votes := make([]float64, 2) + model.Predict1(x, votes) + + expectedVotes := []float64{1, 0.5} + if diff := cmp.Diff(expectedVotes, votes); diff != "" { + t.Error(diff) + } +} + +func TestPredict1WithValuesAndLinearDecay(t *testing.T) { + data := []uint64{0b0000, 0b1111, 0b0011, 0b0101} + labels := []int{0, 1, 1, 0} + values := []float64{1.0, 2.0, 3.0, 3.0} + k := 3 + + model := bitknn.Fit(data, labels, k, bitknn.WithValues(values), bitknn.WithLinearDecay()) + + x := uint64(0b0000) + votes := make([]float64, 2) + model.Predict1(x, votes) + + expectedVotes := []float64{2, 1} + if diff := cmp.Diff(expectedVotes, votes); diff != "" { + t.Error(diff) + } +} + +func TestPredict1WithValuesAndQuadraticDecay(t *testing.T) { + data := []uint64{0b0000, 0b1111, 0b0011, 0b0101} + labels := []int{0, 1, 1, 0} + values := []float64{1.0, 2.0, 4.0, 5.0} + k := 3 + + model := bitknn.Fit(data, labels, k, bitknn.WithValues(values), bitknn.WithQuadraticDecay()) + x := uint64(0b0000) + votes := make([]float64, 2) + model.Predict1(x, votes) + + expectedVotes := []float64{2, 0.8} + if diff := cmp.Diff(expectedVotes, votes); diff != "" { + t.Error(diff) + } +} + +func TestPredict1WithQuadraticDecay(t *testing.T) { + data := []uint64{0b0000, 0b1111, 0b0011, 
0b0101} + labels := []int{0, 1, 1, 0} + k := 3 + + model := bitknn.Fit(data, labels, k, bitknn.WithQuadraticDecay()) + x := uint64(0b0000) + votes := make([]float64, 2) + model.Predict1(x, votes) + + expectedVotes := []float64{1.2, 0.2} + if diff := cmp.Diff(expectedVotes, votes); diff != "" { + t.Error(diff) + } +} + +func TestPredict1WithValuesAndCustomDecay(t *testing.T) { + data := []uint64{0b0000, 0b1111, 0b0011, 0b0101} + labels := []int{0, 1, 1, 0} + values := []float64{1.0, 2.0, 3.0, 3.0} + k := 3 + + model := bitknn.Fit(data, labels, k, bitknn.WithValues(values), bitknn.WithDistanceWeightFunc(func(d int) float64 { + if d <= 2 { + return 1.0 + } else { + return 0.0 + } + })) + + x := uint64(0b0000) + votes := make([]float64, 2) + model.Predict1(x, votes) + + expectedVotes := []float64{4, 3} + if diff := cmp.Diff(expectedVotes, votes); diff != "" { + t.Error(diff) + } +} + +func TestPredict1WithCustomDecay(t *testing.T) { + data := []uint64{0b0000, 0b1111, 0b0011, 0b0101} + labels := []int{0, 1, 2, 0} + k := 3 + + model := bitknn.Fit(data, labels, k, bitknn.WithDistanceWeightFunc(func(d int) float64 { + if d <= 2 { + return 1.0 + } else { + return 0.0 + } + })) + + x := uint64(0b0000) + votes := make([]float64, 3) + model.Predict1(x, votes) + + expectedVotes := []float64{2, 0, 1} + if diff := cmp.Diff(expectedVotes, votes); diff != "" { + t.Error(diff) + } +} diff --git a/nearest.go b/nearest.go new file mode 100644 index 0000000..b691555 --- /dev/null +++ b/nearest.go @@ -0,0 +1,31 @@ +package bitknn + +import ( + "math/bits" +) + +// Nearest finds the nearest neighbors of the given point `x` by Hamming distance in `data`. +// The neighbor's distances and indices (in `data`) are written to the slices `distances` and `indices`. +// The two slices should be pre-allocated to length `k+1`. 
+// pre: +// +// len(distances) = len(indices) = k+1 >= 1 +func Nearest(data []uint64, k int, x uint64, distances, indices []int) int { + heap := makeNeighborHeap(distances, indices) + + var maxDist int + for i := range data { + dist := bits.OnesCount64(x ^ data[i]) + if i < k { + heap.push(dist, i) + maxDist = distances[0] + continue + } + if dist >= maxDist { + continue + } + heap.pushpop(dist, i) + maxDist = distances[0] + } + return min(len(data), k) +} diff --git a/nearest_test.go b/nearest_test.go new file mode 100644 index 0000000..b569132 --- /dev/null +++ b/nearest_test.go @@ -0,0 +1,48 @@ +package bitknn_test + +import ( + "math/rand/v2" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/keilerkonzept/bitknn" +) + +func TestNearest(t *testing.T) { + data := []uint64{0b0000, 0b1111, 0b0011, 0b0101} + k := 2 + x := uint64(0b0001) + distances := make([]int, k+1) + indices := make([]int, k+1) + + count := bitknn.Nearest(data, k, x, distances, indices) + distances = distances[:count] + indices = indices[:count] + + if count != k { + t.Errorf("Expected count %d, got %d", k, count) + } + + expectedDistances := []int{1, 1} + if diff := cmp.Diff(expectedDistances, distances); diff != "" { + t.Error(diff) + } + expectedIndices := []int{2, 0} + if diff := cmp.Diff(expectedIndices, indices); diff != "" { + t.Error(diff) + } +} + +func BenchmarkNearest(b *testing.B) { + dataSize := 10_000 + k := 5 + query := rand.Uint64() + data := randomData(dataSize) + distances := make([]int, k+1) + indices := make([]int, k+1) + + b.ResetTimer() + for n := 0; n < b.N; n++ { + bitknn.Nearest(data, k, query, distances, indices) + } +} diff --git a/utils_test.go b/utils_test.go new file mode 100644 index 0000000..c53b569 --- /dev/null +++ b/utils_test.go @@ -0,0 +1,29 @@ +package bitknn_test + +import "math/rand/v2" + +var randSource = rand.New(rand.NewPCG(0xB0, 0xA4)) + +func randomData(size int) []uint64 { + data := make([]uint64, size) + for i := range data { + data[i] =
randSource.Uint64() + } + return data +} + +func randomLabels(size int) []int { + labels := make([]int, size) + for i := range labels { + labels[i] = int(randSource.Uint32N(256)) + } + return labels +} + +func randomValues(size int) []float64 { + values := make([]float64, size) + for i := range values { + values[i] = randSource.Float64() + } + return values +}