engine: fix race when reading fields in Concatenate #14324

Merged 4 commits on Oct 23, 2023.
Changes from 2 commits.
go/vt/vtgate/engine/concatenate.go: 96 changes (49 additions, 47 deletions)
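The old implementation coordinated field collection with an atomic countdown (`fieldRec`), a `sync.WaitGroup` used as a one-shot latch (`fieldMu`), and a shared `outerErr` written by any failing goroutine. The patch replaces that with a `rest` slice guarded by a mutex, a `sync.Cond` on which every stream waits until all sources have reported their fields, and an `errgroup.Group` that collects errors. The sketch below isolates the condition-variable latch in a runnable form; it is a minimal illustration with made-up names (`rest`, `v`), not the Vitess code:

```go
package main

import (
	"fmt"
	"slices"
	"sync"
)

func main() {
	const n = 3
	var (
		mu   sync.Mutex
		cond = sync.NewCond(&mu)
		rest = make([]*int, n) // nil marks a source whose fields have not arrived
	)

	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			v := idx * 10 // stands in for the first result chunk of this source

			mu.Lock()
			rest[idx] = &v
			if !slices.Contains(rest, nil) {
				cond.Broadcast() // last arrival wakes everyone
			}
			for slices.Contains(rest, nil) {
				cond.Wait() // atomically unlocks mu, re-locks before returning
			}
			mu.Unlock()
			fmt.Printf("source %d: all fields received\n", idx)
		}(i)
	}
	wg.Wait()
}
```

The predicate is re-checked in a loop because `Wait` can return while the predicate is still false; the patch does the same with its `for slices.Contains(rest, nil)` loop.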
```diff
@@ -18,8 +18,10 @@ package engine
 
 import (
 	"context"
+	"slices"
 	"sync"
-	"sync/atomic"
+
+	"golang.org/x/sync/errgroup"
 
 	"vitess.io/vitess/go/sqltypes"
 	querypb "vitess.io/vitess/go/vt/proto/query"
```
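The import changes mirror the mechanics of the fix: `sync/atomic` leaves with the countdown, `slices` supplies the `Contains` check for unfilled slots, and `golang.org/x/sync/errgroup` replaces the hand-rolled WaitGroup plus shared error. As a standalone refresher (not Vitess code): `errgroup.Group` runs each task via `Go`, and `Wait` blocks until all tasks finish and returns the first non-nil error.

```go
package main

import (
	"errors"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	var g errgroup.Group
	for i := 0; i < 3; i++ {
		i := i // capture loop variable (pre-Go 1.22 habit)
		g.Go(func() error {
			if i == 1 {
				return errors.New("source 1 failed")
			}
			return nil // successful sources contribute no error
		})
	}
	// Wait returns the first error, so no shared error variable is needed.
	fmt.Println("err:", g.Wait())
}
```

This is why the patch can drop `outerErr`, whose unsynchronized assignment from multiple goroutines was itself a data race.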
```diff
@@ -238,19 +240,19 @@ func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bin
 func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, in func(*sqltypes.Result) error) error {
 	ctx, cancel := context.WithCancel(inCtx)
 	defer cancel()
-	var outerErr error
-
-	var cbMu sync.Mutex
-	var wg, fieldMu sync.WaitGroup
-	var fieldRec atomic.Int64
-	fieldRec.Store(int64(len(c.Sources)))
-	fieldMu.Add(1)
+	var (
+		muCallback sync.Mutex
+		muFields   sync.Mutex
+		condFields = sync.NewCond(&muFields)
+		wg         errgroup.Group
+		rest       = make([]*sqltypes.Result, len(c.Sources))
+		fields     []*querypb.Field
+	)
 
-	rest := make([]*sqltypes.Result, len(c.Sources))
-	var fields []*querypb.Field
 	callback := func(res *sqltypes.Result, srcIdx int) error {
-		cbMu.Lock()
-		defer cbMu.Unlock()
+		muCallback.Lock()
+		defer muCallback.Unlock()
 
 		needsCoercion := false
 		for idx, field := range rest[srcIdx].Fields {
```
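Note the split into two mutexes: `muFields` (paired with `condFields`) guards the field-collection state, while `muCallback` only serializes delivery to the upstream callback, since chunks from all sources funnel into a single output stream. A toy sketch of that serialization, with hypothetical names (`emit`, `chunks`):

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	var (
		mu     sync.Mutex
		chunks []string
	)
	// emit plays the role of Concatenate's callback: it must not run
	// concurrently because it appends to one shared output stream.
	emit := func(c string) {
		mu.Lock()
		defer mu.Unlock()
		chunks = append(chunks, c)
	}

	var wg sync.WaitGroup
	for src := 0; src < 3; src++ {
		wg.Add(1)
		go func(src int) {
			defer wg.Done()
			for row := 0; row < 2; row++ {
				emit(fmt.Sprintf("source %d, chunk %d", src, row))
			}
		}(src)
	}
	wg.Wait()
	fmt.Println(len(chunks), "chunks delivered; interleaving order is unspecified")
}
```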
```diff
@@ -271,57 +273,57 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor,
 		return in(res)
 	}
 
-	once := sync.Once{}
-
 	for i, source := range c.Sources {
-		wg.Add(1)
 		currIndex, currSource := i, source
 
-		go func() {
-			defer wg.Done()
+		wg.Go(func() error {
 			err := vcursor.StreamExecutePrimitive(ctx, currSource, bindVars, true, func(resultChunk *sqltypes.Result) error {
 				// if we have fields to compare, make sure all the fields are all the same
-				if fieldRec.Load() > 0 && resultChunk.Fields != nil {
-					rest[currIndex] = resultChunk
-					res := fieldRec.Add(-1)
-					if res == 0 {
-						// We have received fields from all sources. We can now calculate the output types
-						var err error
-						fields, err = c.getFields(rest)
-						if err != nil {
-							return err
+				if resultChunk.Fields != nil {
+					muFields.Lock()
+					if rest[currIndex] == nil {
+						rest[currIndex] = resultChunk
+						if !slices.Contains(rest, nil) {
+							muFields.Unlock()
+
+							// We have received fields from all sources. We can now calculate the output types
+							var err error
+							fields, err = c.getFields(rest)
+							if err != nil {
+								return err
+							}
+							resultChunk.Fields = fields
+
+							defer condFields.Broadcast()
+							return callback(resultChunk, currIndex)
 						}
-						resultChunk.Fields = fields
-						defer once.Do(func() {
-							fieldMu.Done()
-						})
-
-						return callback(resultChunk, currIndex)
-					} else {
-						fieldMu.Wait()
 					}
+
+					for slices.Contains(rest, nil) {
+						condFields.Wait()
+					}
+					muFields.Unlock()
 				}
 
 				// If we get here, all the fields have been received
-				select {
-				case <-ctx.Done():
+				if ctx.Err() != nil {
 					return nil
-				default:
-					return callback(resultChunk, currIndex)
 				}
+				return callback(resultChunk, currIndex)
 			})
 			if err != nil {
-				outerErr = err
+				muFields.Lock()
+				if rest[currIndex] == nil {
+					// In case we haven't received any fields yet, we need to set it
+					// empty, or otherwise we will keep waiting forever.
+					rest[currIndex] = &sqltypes.Result{}
+				}
 				cancel()
-				once.Do(func() {
-					fieldMu.Done()
-				})
+				condFields.Broadcast()
+				muFields.Unlock()
 			}
-		}()
+
+			return err
+		})
 	}
-	wg.Wait()
-	return outerErr
+	return wg.Wait()
 }
 
 func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) error {
```
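The error path deserves a close look: without the placeholder write and the `Broadcast`, a source that fails before producing any fields would leave its slot in `rest` nil forever, and every healthy source would block in `condFields.Wait()`. The standalone sketch below reproduces that shape under the same assumptions (a nil-slot latch plus an errgroup); the names are illustrative, not the Vitess code:

```go
package main

import (
	"errors"
	"fmt"
	"slices"
	"sync"

	"golang.org/x/sync/errgroup"
)

func main() {
	const n = 3
	var (
		mu   sync.Mutex
		cond = sync.NewCond(&mu)
		rest = make([]*string, n) // nil means "fields not received yet"
		g    errgroup.Group
	)

	for i := 0; i < n; i++ {
		i := i
		g.Go(func() error {
			if i == 2 {
				// Failing source: fill the slot with an empty placeholder
				// and wake the waiters, otherwise they would block forever.
				mu.Lock()
				if rest[i] == nil {
					empty := ""
					rest[i] = &empty
				}
				cond.Broadcast()
				mu.Unlock()
				return errors.New("source 2 failed before sending fields")
			}

			fields := fmt.Sprintf("fields from source %d", i)
			mu.Lock()
			rest[i] = &fields
			if !slices.Contains(rest, nil) {
				cond.Broadcast() // last arrival wakes everyone
			}
			for slices.Contains(rest, nil) {
				cond.Wait()
			}
			mu.Unlock()
			return nil
		})
	}
	fmt.Println("first error:", g.Wait())
}
```

`g.Wait()` then surfaces the failure; in the patch the same branch also calls `cancel()`, stopping the remaining streams.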