Skip to content

Commit

Permalink
add check program
Browse files Browse the repository at this point in the history
Signed-off-by: Kosuke Morimoto <[email protected]>
  • Loading branch information
kmrmt committed Sep 12, 2023
1 parent 83ab327 commit 09e9263
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 34 deletions.
177 changes: 177 additions & 0 deletions cmd/tools/cli/benchmark/core/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
//
// Copyright (C) 2019-2023 vdaas.org vald team <[email protected]>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// You may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

// Package ngt provides implementation of Go API for https://github.com/yahoojapan/NGT
package main

import (
"context"
"fmt"
"os"
"os/signal"
"runtime"
"strings"
"sync"
"syscall"

"github.com/vdaas/vald/internal/core/algorithm/ngt"
"github.com/vdaas/vald/internal/log"
"gonum.org/v1/hdf5"
)

func main() {
vectors, _, _ := load("sift-128-euclidean.hdf5")
n, _ := ngt.New(
ngt.WithDimension(len(vectors[0])),
ngt.WithDefaultPoolSize(8),
ngt.WithObjectType(ngt.Float),
ngt.WithDistanceType(ngt.L2),
)
pid := os.Getpid()

log.Infof("# of vectors: %v", len(vectors))
output := func(header string) {
status := fmt.Sprintf("/proc/%d/status", pid)
buf, err := os.ReadFile(status)
if err != nil {
log.Fatal(err)
}
var vmpeak, vmrss, vmhwm string
for _, line := range strings.Split(string(buf), "\n") {
switch {
case strings.HasPrefix(line, "VmPeak"):
vmpeak = strings.Fields(line)[1]
case strings.HasPrefix(line, "VmHWM"):
vmhwm = strings.Fields(line)[1]
case strings.HasPrefix(line, "VmRSS"):
vmrss = strings.Fields(line)[1]
}
}

var m runtime.MemStats
runtime.ReadMemStats(&m)
log.Infof("%v\t%v\t%v\t%v\t%v\t%v\t%v\t%v\t%v", header, vmpeak, vmhwm, vmrss, m.Alloc/1024, m.TotalAlloc/1024, m.HeapAlloc/1024, m.HeapSys/1024, m.HeapInuse/1024)
}
log.Info(" operation\tVmPeak\tVmHWM\tVmRSS\tAlloc\tTotalAlloc\tHeapAlloc\tHeapSys\tHeapInuse")
output(" start")
defer output(" end")
ctx, cancel := context.WithCancel(context.Background())
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
for {
select {
case <-ctx.Done():
wg.Done()
return
default:
ids := make([]uint, len(vectors))
for i, vector := range vectors {
id, err := n.Insert(vector)
if err != nil {
log.Fatal(err)
}
ids[i] = id
}
output(" insert")

if err := n.CreateIndex(8); err != nil {
log.Fatal(err)
}
output("create index")

for _, id := range ids {
if err := n.Remove(id); err != nil {
log.Fatal(err)
}
}
output(" remove")
}
}
}()

ch := make(chan os.Signal)
signal.Notify(ch, os.Interrupt, syscall.SIGTERM)

<-ch
cancel()

wg.Wait()
}

// load function loads training and test vector from hdf file. The size of ids is same to the number of training data.
// Each id, which is an element of ids, will be set a random number.
func load(path string) (train, test [][]float32, err error) {
var f *hdf5.File
f, err = hdf5.OpenFile(path, hdf5.F_ACC_RDONLY)
if err != nil {
return nil, nil, err
}
defer f.Close()

// readFn function reads vectors of the hierarchy with the given the name.
readFn := func(name string) ([][]float32, error) {
// Opens and returns a named Dataset.
// The returned dataset must be closed by the user when it is no longer needed.
d, err := f.OpenDataset(name)
if err != nil {
return nil, err
}
defer d.Close()

// Space returns an identifier for a copy of the dataspace for a dataset.
sp := d.Space()
defer sp.Close()

// SimpleExtentDims returns dataspace dimension size and maximum size.
dims, _, _ := sp.SimpleExtentDims()
row, dim := int(dims[0]), int(dims[1])

// Gets the stored vector. All are represented as one-dimensional arrays.
// The type of the slice depends on your dataset.
// For fashion-mnist-784-euclidean.hdf5, the datatype is float32.
vec := make([]float32, sp.SimpleExtentNPoints())
if err := d.Read(&vec); err != nil {
return nil, err
}

// Converts a one-dimensional array to a two-dimensional array.
// Use the `dim` variable as a separator.
vecs := make([][]float32, row)
for i := 0; i < row; i++ {
vecs[i] = make([]float32, dim)
for j := 0; j < dim; j++ {
vecs[i][j] = float32(vec[i*dim+j])
}
}

return vecs, nil
}

// Gets vector of `train` hierarchy.
train, err = readFn("train")
if err != nil {
return nil, nil, err
}

// Gets vector of `test` hierarchy.
test, err = readFn("test")
if err != nil {
return nil, nil, err
}

return
}
84 changes: 50 additions & 34 deletions internal/core/algorithm/ngt/ngt_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
package ngt

import (
"fmt"
"os"
"runtime"
"strings"
"testing"

"gonum.org/v1/hdf5"
Expand All @@ -26,7 +30,7 @@ import (
var (
vectors [][]float32
n NGT
ids []uint
pid int
)

func init() {
Expand All @@ -37,48 +41,60 @@ func init() {
WithObjectType(Float),
WithDistanceType(L2),
)
pid = os.Getpid()
}

func RunNGT1(b *testing.B) error {
b.Helper()

ids = make([]uint, len(vectors))
b.ResetTimer()
for i, vector := range vectors {
id, err := n.Insert(vector)
func BenchmarkNGT(b *testing.B) {
b.Logf("# of vectors: %v", len(vectors))
output := func(header string) {
status := fmt.Sprintf("/proc/%d/status", pid)
buf, err := os.ReadFile(status)
if err != nil {
return err
b.Fatal(err)
}
var vmpeak, vmrss, vmhwm string
for _, line := range strings.Split(string(buf), "\n") {
switch {
case strings.HasPrefix(line, "VmPeak"):
vmpeak = strings.Fields(line)[1]
case strings.HasPrefix(line, "VmHWM"):
vmhwm = strings.Fields(line)[1]
case strings.HasPrefix(line, "VmRSS"):
vmrss = strings.Fields(line)[1]
}
}
ids[i] = id
}

if err := n.CreateIndex(8); err != nil {
return err
var m runtime.MemStats
runtime.ReadMemStats(&m)
b.Logf("%v\t%v\t%v\t%v\t%v\t%v\t%v\t%v\t%v", header, vmpeak, vmhwm, vmrss, m.Alloc/1024, m.TotalAlloc/1024, m.HeapAlloc/1024, m.HeapSys/1024, m.HeapInuse/1024)
}
return nil
}

func RunNGT2(b *testing.B) error {
b.Helper()

b.Logf(" operation\tVmPeak\tVmHWM\tVmRSS\tAlloc\tTotalAlloc\tHeapAlloc\tHeapSys\tHeapInuse")
b.ResetTimer()
for _, id := range ids {
if err := n.Remove(id); err != nil {
return err
}
}
return nil
}
output(" start")
defer output(" end")
for N := 0; N < b.N; N++ {
for i := 0; i < 3; i++ {
ids := make([]uint, len(vectors))
for idx, vector := range vectors {
id, err := n.Insert(vector)
if err != nil {
b.Fatal(err)
}
ids[idx] = id
}
output(" insert")

func BenchmarkNGT(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := RunNGT1(b); err != nil {
b.Fatal(err)
}
if err := n.CreateIndex(8); err != nil {
b.Fatal(err)
}
output("create index")

if err := RunNGT2(b); err != nil {
b.Fatal(err)
for _, id := range ids {
if err := n.Remove(id); err != nil {
b.Fatal(err)
}
}
output(" remove")
}
}
}
Expand Down

0 comments on commit 09e9263

Please sign in to comment.