engine: fix race when reading fields in Concatenate
Signed-off-by: Vicent Marti <[email protected]>
vmg committed Oct 20, 2023
1 parent 688db19 commit f23981a
Showing 1 changed file with 42 additions and 47 deletions.
89 changes: 42 additions & 47 deletions go/vt/vtgate/engine/concatenate.go
@@ -18,8 +18,10 @@ package engine
 
 import (
 	"context"
+	"slices"
 	"sync"
-	"sync/atomic"
 
+	"golang.org/x/sync/errgroup"
+
 	"vitess.io/vitess/go/sqltypes"
 	querypb "vitess.io/vitess/go/vt/proto/query"
@@ -238,19 +240,19 @@ func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bin
 func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, in func(*sqltypes.Result) error) error {
 	ctx, cancel := context.WithCancel(inCtx)
 	defer cancel()
-	var outerErr error
 
-	var cbMu sync.Mutex
-	var wg, fieldMu sync.WaitGroup
-	var fieldRec atomic.Int64
-	fieldRec.Store(int64(len(c.Sources)))
-	fieldMu.Add(1)
+	var (
+		muCallback sync.Mutex
+		muFields   sync.Mutex
+		condFields = sync.NewCond(&muFields)
+		wg         errgroup.Group
+		rest       = make([]*sqltypes.Result, len(c.Sources))
+		fields     []*querypb.Field
+	)
 
-	rest := make([]*sqltypes.Result, len(c.Sources))
-	var fields []*querypb.Field
 	callback := func(res *sqltypes.Result, srcIdx int) error {
-		cbMu.Lock()
-		defer cbMu.Unlock()
+		muCallback.Lock()
+		defer muCallback.Unlock()
 
 		needsCoercion := false
 		for idx, field := range rest[srcIdx].Fields {
@@ -271,57 +273,50 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor,
 		return in(res)
 	}
 
-	once := sync.Once{}
-
 	for i, source := range c.Sources {
-		wg.Add(1)
 		currIndex, currSource := i, source
 
-		go func() {
-			defer wg.Done()
+		wg.Go(func() error {
 			err := vcursor.StreamExecutePrimitive(ctx, currSource, bindVars, true, func(resultChunk *sqltypes.Result) error {
-				// if we have fields to compare, make sure all the fields are all the same
-				if fieldRec.Load() > 0 && resultChunk.Fields != nil {
-					rest[currIndex] = resultChunk
-					res := fieldRec.Add(-1)
-					if res == 0 {
-						// We have received fields from all sources. We can now calculate the output types
-						var err error
-						fields, err = c.getFields(rest)
-						if err != nil {
-							return err
-						}
-						resultChunk.Fields = fields
-						defer once.Do(func() {
-							fieldMu.Done()
-						})
-
-						return callback(resultChunk, currIndex)
-					} else {
-						fieldMu.Wait()
-					}
-				}
+				if resultChunk.Fields != nil {
+					muFields.Lock()
+					if rest[currIndex] == nil {
+						rest[currIndex] = resultChunk
+						if !slices.Contains(rest, nil) {
+							muFields.Unlock()
+
+							// We have received fields from all sources. We can now calculate the output types
+							var err error
+							fields, err = c.getFields(rest)
+							if err != nil {
+								return err
+							}
+							resultChunk.Fields = fields
+
+							defer condFields.Broadcast()
+							return callback(resultChunk, currIndex)
+						}
+					}
+					for slices.Contains(rest, nil) {
+						condFields.Wait()
+					}
+					muFields.Unlock()
+				}
 
 				// If we get here, all the fields have been received
-				select {
-				case <-ctx.Done():
+				if ctx.Err() != nil {
 					return nil
-				default:
-					return callback(resultChunk, currIndex)
 				}
+				return callback(resultChunk, currIndex)
 			})
 			if err != nil {
-				outerErr = err
 				cancel()
-				once.Do(func() {
-					fieldMu.Done()
-				})
+				condFields.Broadcast()
 			}
-		}()
+
+			return err
+		})
 	}
-	wg.Wait()
-	return outerErr
+	return wg.Wait()
 }
 
 func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) error {
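In the old code the per-source goroutines coordinated on the shared rest and fields slices through an atomic countdown (fieldRec) plus a sync.WaitGroup gate (fieldMu), which left room for one goroutine to read the field state while another was still writing it, and errors were funneled through the shared outerErr variable. The new code keeps all field bookkeeping under muFields, uses a sync.Cond so every source blocks until each slot of rest is filled, and lets errgroup.Group propagate the first error.

Below is a minimal, self-contained sketch of the same wait-for-all pattern; it is not part of the commit, and the worker values, slots, and combined names are illustrative stand-ins for the per-source field results:

// coordination_sketch.go: standalone sketch (not from the commit) of the pattern
// the new code uses. Each worker publishes into its own slot of a shared slice
// under a mutex; the last one to arrive computes the combined result and
// broadcasts; everyone else waits on a sync.Cond until no slot is nil.
package main

import (
	"fmt"
	"slices"
	"sync"

	"golang.org/x/sync/errgroup"
)

func main() {
	const n = 4

	var (
		mu       sync.Mutex
		cond     = sync.NewCond(&mu)
		slots    = make([]*int, n) // one slot per worker, nil until that worker reports
		combined int               // computed once, by whichever worker fills the last slot
		wg       errgroup.Group    // propagates the first error, like wg.Wait() in the commit
	)

	for i := 0; i < n; i++ {
		i := i
		wg.Go(func() error {
			v := i * 10 // stand-in for the per-source field information

			mu.Lock()
			if slots[i] == nil {
				slots[i] = &v
				if !slices.Contains(slots, nil) {
					// Last worker to report: derive the combined value and wake the others.
					for _, s := range slots {
						combined += *s
					}
					cond.Broadcast()
					mu.Unlock()
					return nil
				}
			}
			// Not the last one: block until every slot has been filled.
			for slices.Contains(slots, nil) {
				cond.Wait()
			}
			mu.Unlock()

			fmt.Printf("worker %d sees combined value %d\n", i, combined)
			return nil
		})
	}

	if err := wg.Wait(); err != nil {
		fmt.Println("error:", err)
	}
}

The property to note, mirrored in the commit: the shared result is only read after the cond-protected wait confirms every slot is filled, so no goroutine can observe it mid-construction.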
