engine: fix race when reading fields in Concatenate #14324

Merged · 4 commits · Oct 23, 2023
Changes from all commits
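For context on the race being fixed: in the old code below, each source decremented the atomic counter fieldRec when its first chunk arrived, and only the goroutine that brought the counter to zero then computed the shared fields slice. A source that observed the counter already at zero skipped the wait entirely and could read fields while it was still being written. The following is a minimal sketch of that pattern, not Vitess code; the names remaining and fields are illustrative. Running it with `go run -race` should flag the unsynchronized read on most runs.

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
)

func main() {
    const sources = 4

    var remaining atomic.Int64
    remaining.Store(sources)

    var fields []string // shared schema, written once by the "last" source
    var wg sync.WaitGroup

    for i := 0; i < sources; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            if remaining.Add(-1) == 0 {
                // The counter reaches zero BEFORE the shared slice is written...
                fields = []string{"id", "name"}
                return
            }
            if remaining.Load() == 0 {
                // ...so a goroutine that observes zero here can read `fields`
                // concurrently with the write above: the data race.
                _ = len(fields)
            }
        }()
    }
    wg.Wait()
    fmt.Println(fields)
}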
122 changes: 69 additions & 53 deletions go/vt/vtgate/engine/concatenate.go
@@ -18,8 +18,10 @@ package engine

import (
    "context"
    "slices"
    "sync"
    "sync/atomic"

    "golang.org/x/sync/errgroup"

    "vitess.io/vitess/go/sqltypes"
    querypb "vitess.io/vitess/go/vt/proto/query"
@@ -236,92 +238,106 @@ func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bin
}

func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, in func(*sqltypes.Result) error) error {
    // Scoped context; any early exit triggers cancel() to clean up ongoing work.
    ctx, cancel := context.WithCancel(inCtx)
    defer cancel()

    // Mutexes for dealing with concurrent access to shared state.
    var (
        muCallback sync.Mutex                                 // Protects callback
        muFields   sync.Mutex                                 // Protects field state
        condFields = sync.NewCond(&muFields)                  // Condition var for field arrival
        wg         errgroup.Group                             // Wait group for all streaming goroutines
        rest       = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields
        fields     []*querypb.Field                           // Cached final field types
    )

    // Process each result chunk, considering type coercion.
    callback := func(res *sqltypes.Result, srcIdx int) error {
        muCallback.Lock()
        defer muCallback.Unlock()

        // Check if type coercion is needed for this source.
        // We only need to check fields that are not in the NoNeedToTypeCheck set.
        needsCoercion := false
        for idx, field := range rest[srcIdx].Fields {
            _, skip := c.NoNeedToTypeCheck[idx]
            if !skip && fields[idx].Type != field.Type {
                needsCoercion = true
                break
            }
        }

        // Apply type coercion if needed.
        if needsCoercion {
            for _, row := range res.Rows {
                if err := c.coerceValuesTo(row, fields); err != nil {
                    return err
                }
            }
        }
        return in(res)
    }

    // Start streaming query execution in parallel for all sources.
    for i, source := range c.Sources {
        currIndex, currSource := i, source

        wg.Go(func() error {
            err := vcursor.StreamExecutePrimitive(ctx, currSource, bindVars, true, func(resultChunk *sqltypes.Result) error {
                // Process fields when they arrive; coordinate field agreement across sources.
                if resultChunk.Fields != nil {
                    muFields.Lock()

                    // Capture the initial result chunk to determine field types later.
                    if rest[currIndex] == nil {
                        rest[currIndex] = resultChunk

                        // If this was the last source to report its fields, derive the final output fields.
                        if !slices.Contains(rest, nil) {
                            muFields.Unlock()

                            // We have received fields from all sources. We can now calculate the output types.
                            var err error
                            fields, err = c.getFields(rest)
                            if err != nil {
                                return err
                            }
                            resultChunk.Fields = fields

                            defer condFields.Broadcast()
                            return callback(resultChunk, currIndex)
                        }
                    }

                    // Wait for fields from all sources.
                    for slices.Contains(rest, nil) {
                        condFields.Wait()
                    }
                    muFields.Unlock()
                }

                // Context check to avoid extra work.
                if ctx.Err() != nil {
                    return nil
                }
                return callback(resultChunk, currIndex)
            })

            // Error handling and context cleanup for this source.
            if err != nil {
                muFields.Lock()
                if rest[currIndex] == nil {
                    // Signal that this source is done, even if by failure, to unblock field waiting.
                    rest[currIndex] = &sqltypes.Result{}
                }
                cancel()
                condFields.Broadcast()
                muFields.Unlock()
            }

            return err
        })
    }

    // Wait for all sources to complete.
    return wg.Wait()
}
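Isolated from the engine specifics, the coordination pattern the new parallelStreamExec relies on looks like the sketch below: a toy, self-contained program in which the names rest, final, and sources are illustrative, not Vitess APIs. Workers record their first result into a shared slice under a mutex, the last one to arrive derives the merged state and broadcasts on the sync.Cond, and everyone else waits until no slot is nil before reading the shared state.

package main

import (
    "fmt"
    "slices"
    "sync"

    "golang.org/x/sync/errgroup"
)

func main() {
    const sources = 4

    var (
        mu    sync.Mutex
        cond  = sync.NewCond(&mu)
        rest  = make([]*string, sources) // first result from each source
        final string                     // derived once every slot is filled
        g     errgroup.Group
    )

    for i := 0; i < sources; i++ {
        i := i
        g.Go(func() error {
            first := fmt.Sprintf("fields-%d", i)

            mu.Lock()
            if rest[i] == nil {
                rest[i] = &first
                if !slices.Contains(rest, nil) {
                    // Last arrival: derive the shared state while no one
                    // else can be reading it, then wake all waiters.
                    final = "merged-" + first
                    cond.Broadcast()
                    mu.Unlock()
                    return nil
                }
            }
            // Wait until every source has reported its first result.
            for slices.Contains(rest, nil) {
                cond.Wait()
            }
            mu.Unlock()

            _ = final // safe: the write is ordered before this read
            return nil
        })
    }

    if err := g.Wait(); err != nil {
        fmt.Println("error:", err)
        return
    }
    fmt.Println(final)
}

One deliberate simplification: the patch itself unlocks muFields before calling getFields and defers the Broadcast until after the callback, which is still safe because waiters only wake after the Broadcast; the sketch keeps the write under the lock to make the ordering obvious.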

func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) error {