Skip to content

Commit

Permalink
refactor: Use generics in function of bidirectional stream (#2803)
Browse files Browse the repository at this point in the history
* refactor: Use generics in function of bidirectional stream

* Pass dereferenced value of to `r` to `l.loaderFunc`

* fix: marshal/unmarshal error

* fix: Add boolean value as the return value of the callback

* fix: delete redudancy variable

---------

Co-authored-by: Kiichiro YUKAWA <[email protected]>
Co-authored-by: Yusuke Kato <[email protected]>
  • Loading branch information
3 people authored Feb 6, 2025
1 parent b4b19cf commit be30998
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 41 deletions.
17 changes: 10 additions & 7 deletions internal/net/grpc/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ type (
// It receives messages from the stream, calls the function with the received message, and sends the returned message to the stream.
// It limits the number of concurrent calls to the function with the concurrency integer.
// It records errors and returns them as a single error.
func BidirectionalStream[Q any, R any](
func BidirectionalStream[Q, R any](
ctx context.Context,
stream ServerStream,
concurrency int,
Expand Down Expand Up @@ -167,8 +167,8 @@ func BidirectionalStream[Q any, R any](
}

// BidirectionalStreamClient is gRPC client stream.
func BidirectionalStreamClient(
stream ClientStream, dataProvider, newData func() any, f func(any, error),
func BidirectionalStreamClient[S, R any](
stream ClientStream, sendDataProvider func() *S, callBack func(*R, error) bool,
) (err error) {
if stream == nil {
return errors.ErrGRPCClientStreamNotFound
Expand All @@ -183,13 +183,16 @@ func BidirectionalStreamClient(
case <-ctx.Done():
return ctx.Err()
default:
res := newData()
res := new(R)
err = stream.RecvMsg(res)
if err == io.EOF || errors.Is(err, io.EOF) {
cancel()
return nil
}
f(res, err)
if !callBack(res, err) {
cancel()
return nil
}
}
}
}))
Expand All @@ -208,7 +211,7 @@ func BidirectionalStreamClient(
case <-ctx.Done():
return eg.Wait()
default:
data := dataProvider()
data := sendDataProvider()
if data == nil {
err = stream.CloseSend()
cancel()
Expand All @@ -218,7 +221,7 @@ func BidirectionalStreamClient(
return eg.Wait()
}

err = stream.SendMsg(data)
err = stream.SendMsg(*data)
if err != nil {
return err
}
Expand Down
22 changes: 12 additions & 10 deletions pkg/tools/cli/loadtest/service/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (

func insertRequestProvider(
dataset assets.Dataset, batchSize int,
) (f func() any, size int, err error) {
) (f func() *any, size int, err error) {
switch {
case batchSize == 1:
f, size = objectVectorProvider(dataset)
Expand All @@ -42,47 +42,49 @@ func insertRequestProvider(
return f, size, nil
}

func objectVectorProvider(dataset assets.Dataset) (func() any, int) {
func objectVectorProvider(dataset assets.Dataset) (func() *any, int) {
idx := int32(-1)
size := dataset.TrainSize()
return func() (ret any) {
return func() (ret *any) {
if i := int(atomic.AddInt32(&idx, 1)); i < size {
v, err := dataset.Train(i)
if err != nil {
return nil
}
ret = &payload.Insert_Request{
obj := any(&payload.Insert_Request{
Vector: &payload.Object_Vector{
Id: fuid.String(),
Vector: v.([]float32),
},
}
})
ret = &obj
}
return ret
}, size
}

func objectVectorsProvider(dataset assets.Dataset, n int) (func() any, int) {
func objectVectorsProvider(dataset assets.Dataset, n int) (func() *any, int) {
provider, s := objectVectorProvider(dataset)
size := s / n
if s%n != 0 {
size = size + 1
}
return func() (ret any) {
return func() (ret *any) {
r := make([]*payload.Insert_Request, 0, n)
for i := 0; i < n; i++ {
d := provider()
if d == nil {
break
}
r = append(r, d.(*payload.Insert_Request))
r = append(r, (*d).(*payload.Insert_Request))
}
if len(r) == 0 {
return nil
}
return &payload.Insert_MultiRequest{
obj := any(&payload.Insert_MultiRequest{
Requests: r,
}
})
return &obj
}, size
}

Expand Down
40 changes: 20 additions & 20 deletions pkg/tools/cli/loadtest/service/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ type loader struct {
dataset string
progressDuration time.Duration
loaderFunc loadFunc
dataProvider func() any
sendDataProvider func() *any
dataSize int
operation config.Operation
}
Expand Down Expand Up @@ -97,9 +97,9 @@ func (l *loader) Prepare(context.Context) (err error) {

switch l.operation {
case config.Insert, config.StreamInsert:
l.dataProvider, l.dataSize, err = insertRequestProvider(dataset, l.batchSize)
l.sendDataProvider, l.dataSize, err = insertRequestProvider(dataset, l.batchSize)
case config.Search, config.StreamSearch:
l.dataProvider, l.dataSize, err = searchRequestProvider(dataset)
l.sendDataProvider, l.dataSize, err = searchRequestProvider(dataset)
}
if err != nil {
return err
Expand Down Expand Up @@ -135,7 +135,7 @@ func (l *loader) Do(ctx context.Context) <-chan error {
log.Infof("progress %d requests, %f[vps], error: %d", pgCnt, vps(int(pgCnt)*l.batchSize, start, time.Now()), errCnt)
}

f := func(i any, err error) {
f := func(i *any, err error) {
atomic.AddInt32(&pgCnt, 1)
if err != nil {
atomic.AddInt32(&errCnt, 1)
Expand Down Expand Up @@ -184,23 +184,12 @@ func (l *loader) Do(ctx context.Context) <-chan error {
}

func (l *loader) do(
ctx context.Context, f func(any, error), notify func(context.Context, error),
ctx context.Context, f func(*any, error), notify func(context.Context, error),
) (err error) {
eg, egctx := errgroup.New(ctx)

switch l.operation {
case config.StreamInsert, config.StreamSearch:
var newData func() any
switch l.operation {
case config.StreamInsert:
newData = func() any {
return new(payload.Empty)
}
case config.StreamSearch:
newData = func() any {
return new(payload.Search_Response)
}
}
eg.Go(safety.RecoverFunc(func() (err error) {
defer func() {
if err != nil {
Expand All @@ -213,7 +202,18 @@ func (l *loader) do(
if err != nil {
return nil, err
}
return nil, grpc.BidirectionalStreamClient(st.(grpc.ClientStream), l.dataProvider, newData, f)

if l.operation == config.StreamInsert {
return nil, grpc.BidirectionalStreamClient(st.(grpc.ClientStream), l.sendDataProvider, func(i *payload.Empty, err error) bool {
f(nil, err)
return true
})
} else {
return nil, grpc.BidirectionalStreamClient(st.(grpc.ClientStream), l.sendDataProvider, func(i *payload.Search_Response, err error) bool {
f(nil, err)
return true
})
}
})
return err
}))
Expand All @@ -222,7 +222,7 @@ func (l *loader) do(
eg.SetLimit(l.concurrency)

for {
r := l.dataProvider()
r := l.sendDataProvider()
if r == nil {
break
}
Expand All @@ -233,8 +233,8 @@ func (l *loader) do(
err = nil
}()
_, err = l.client.Do(egctx, l.addr, func(ctx context.Context, conn *grpc.ClientConn, copts ...grpc.CallOption) (any, error) {
res, err := l.loaderFunc(egctx, conn, r)
f(res, err)
res, err := l.loaderFunc(egctx, conn, *r)
f(&res, err)
return res, err
})

Expand Down
9 changes: 5 additions & 4 deletions pkg/tools/cli/loadtest/service/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,24 @@ import (
"github.com/vdaas/vald/pkg/tools/cli/loadtest/assets"
)

func searchRequestProvider(dataset assets.Dataset) (func() any, int, error) {
func searchRequestProvider(dataset assets.Dataset) (func() *any, int, error) {
size := dataset.QuerySize()
idx := int32(-1)
return func() (ret any) {
return func() (ret *any) {
if i := int(atomic.AddInt32(&idx, 1)); i < size {
v, err := dataset.Query(i)
if err != nil {
return nil
}
ret = &payload.Search_Request{
obj := any(&payload.Search_Request{
Vector: v.([]float32),
Config: &payload.Search_Config{
Num: 10,
Radius: -1,
Epsilon: 0.1,
},
}
})
ret = &obj
}
return ret
}, size, nil
Expand Down

0 comments on commit be30998

Please sign in to comment.