diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index 13727124e78..eb93711eed2 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -102,12 +102,14 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars } var rows [][]sqltypes.Value - err = c.coerceAndVisitResults(res, fieldTypes, func(result *sqltypes.Result) error { + callback := func(result *sqltypes.Result) error { rows = append(rows, result.Rows...) return nil - }, evalengine.ParseSQLMode(vcursor.SQLMode())) - if err != nil { - return nil, err + } + for _, r := range res { + if err = c.coerceAndVisitResultsForOneSource([]*sqltypes.Result{r}, fields, fieldTypes, callback, evalengine.ParseSQLMode(vcursor.SQLMode())); err != nil { + return nil, err + } } return &sqltypes.Result{ @@ -245,12 +247,14 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, // Mutexes for dealing with concurrent access to shared state. var ( - muCallback sync.Mutex // Protects callback - muFields sync.Mutex // Protects field state - condFields = sync.NewCond(&muFields) // Condition var for field arrival - wg errgroup.Group // Wait group for all streaming goroutines - rest = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields - fieldTypes []evalengine.Type // Cached final field types + muCallback sync.Mutex // Protects callback + muFields sync.Mutex // Protects field state + condFields = sync.NewCond(&muFields) // Condition var for field arrival + wg errgroup.Group // Wait group for all streaming goroutines + rest = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields + fieldTypes []evalengine.Type // Cached final field types + resultFields []*querypb.Field // Final fields that need to be set for the first result. + needsCoercion = make([]bool, len(c.Sources)) // Tracks if coercion is needed for each individual source ) // Process each result chunk, considering type coercion. @@ -258,19 +262,8 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, muCallback.Lock() defer muCallback.Unlock() - // Check if type coercion needed for this source. - // We only need to check if fields are not in NoNeedToTypeCheck set. - needsCoercion := false - for idx, field := range rest[srcIdx].Fields { - _, skip := c.NoNeedToTypeCheck[idx] - if !skip && fieldTypes[idx].Type() != field.Type { - needsCoercion = true - break - } - } - // Apply type coercion if needed. - if needsCoercion { + if needsCoercion[srcIdx] { for _, row := range res.Rows { if err := c.coerceValuesTo(row, fieldTypes, sqlmode); err != nil { return err @@ -296,12 +289,29 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, if !slices.Contains(rest, nil) { // We have received fields from all sources. We can now calculate the output types var err error - resultChunk.Fields, fieldTypes, err = c.getFieldTypes(vcursor, rest) + resultFields, fieldTypes, err = c.getFieldTypes(vcursor, rest) if err != nil { muFields.Unlock() return err } + // Check if we need coercion for each source. + for srcIdx, result := range rest { + srcNeedsCoercion := false + for idx, field := range result.Fields { + _, skip := c.NoNeedToTypeCheck[idx] + // We only need to check if fields are not in NoNeedToTypeCheck set. + if !skip && fieldTypes[idx].Type() != field.Type { + srcNeedsCoercion = true + break + } + } + needsCoercion[srcIdx] = srcNeedsCoercion + } + + // We only need to send the fields in the first result. + // We set this field after the coercion check to avoid calculating incorrect needs coercion value. + resultChunk.Fields = resultFields muFields.Unlock() defer condFields.Broadcast() return callback(resultChunk, currIndex) @@ -310,8 +320,11 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, // Wait for fields from all sources. for slices.Contains(rest, nil) { + // This wait call lets go of the muFields lock and acquires it again later after waiting. condFields.Wait() } + // We only need to send fields in the first result. + resultChunk.Fields = nil muFields.Unlock() // Context check to avoid extra work. @@ -368,12 +381,12 @@ func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, firsts[i] = result[0] } - _, fieldTypes, err := c.getFieldTypes(vcursor, firsts) + fields, fieldTypes, err := c.getFieldTypes(vcursor, firsts) if err != nil { return err } for _, res := range results { - if err = c.coerceAndVisitResults(res, fieldTypes, callback, sqlmode); err != nil { + if err = c.coerceAndVisitResultsForOneSource(res, fields, fieldTypes, callback, sqlmode); err != nil { return err } } @@ -381,25 +394,33 @@ func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, return nil } -func (c *Concatenate) coerceAndVisitResults( +func (c *Concatenate) coerceAndVisitResultsForOneSource( res []*sqltypes.Result, + fields []*querypb.Field, fieldTypes []evalengine.Type, callback func(*sqltypes.Result) error, sqlmode evalengine.SQLMode, ) error { + if len(res) == 0 { + return nil + } + needsCoercion := false + for idx, field := range res[0].Fields { + if fieldTypes[idx].Type() != field.Type { + needsCoercion = true + break + } + } + if res[0].Fields != nil { + res[0].Fields = fields + } + for _, r := range res { if len(r.Rows) > 0 && len(fieldTypes) != len(r.Rows[0]) { return errWrongNumberOfColumnsInSelect } - needsCoercion := false - for idx, field := range r.Fields { - if fieldTypes[idx].Type() != field.Type { - needsCoercion = true - break - } - } if needsCoercion { for _, row := range r.Rows { err := c.coerceValuesTo(row, fieldTypes, sqlmode) diff --git a/go/vt/vtgate/engine/concatenate_test.go b/go/vt/vtgate/engine/concatenate_test.go index dd2b1300e9b..39b9ed961b3 100644 --- a/go/vt/vtgate/engine/concatenate_test.go +++ b/go/vt/vtgate/engine/concatenate_test.go @@ -124,27 +124,24 @@ func TestConcatenate_NoErrors(t *testing.T) { if !tx { txStr = "NotInTx" } - t.Run(fmt.Sprintf("%s-%s-Exec", txStr, tc.testName), func(t *testing.T) { - qr, err := concatenate.TryExecute(context.Background(), vcursor, nil, true) + checkResult := func(t *testing.T, qr *sqltypes.Result, err error) { if tc.expectedError == "" { require.NoError(t, err) utils.MustMatch(t, tc.expectedResult.Fields, qr.Fields, "fields") - utils.MustMatch(t, tc.expectedResult.Rows, qr.Rows) + require.NoError(t, sqltypes.RowsEquals(tc.expectedResult.Rows, qr.Rows)) } else { require.Error(t, err) require.Contains(t, err.Error(), tc.expectedError) } + } + t.Run(fmt.Sprintf("%s-%s-Exec", txStr, tc.testName), func(t *testing.T) { + qr, err := concatenate.TryExecute(context.Background(), vcursor, nil, true) + checkResult(t, qr, err) }) t.Run(fmt.Sprintf("%s-%s-StreamExec", txStr, tc.testName), func(t *testing.T) { qr, err := wrapStreamExecute(concatenate, vcursor, nil, true) - if tc.expectedError == "" { - require.NoError(t, err) - require.NoError(t, sqltypes.RowsEquals(tc.expectedResult.Rows, qr.Rows)) - } else { - require.Error(t, err) - require.Contains(t, err.Error(), tc.expectedError) - } + checkResult(t, qr, err) }) } } diff --git a/go/vt/vtgate/engine/fake_primitive_test.go b/go/vt/vtgate/engine/fake_primitive_test.go index e992c2a4623..532d3ebb970 100644 --- a/go/vt/vtgate/engine/fake_primitive_test.go +++ b/go/vt/vtgate/engine/fake_primitive_test.go @@ -24,6 +24,7 @@ import ( "testing" "golang.org/x/sync/errgroup" + "google.golang.org/protobuf/proto" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -111,7 +112,7 @@ func (f *fakePrimitive) syncCall(wantfields bool, callback func(*sqltypes.Result } result := &sqltypes.Result{} for i := 0; i < len(r.Rows); i++ { - result.Rows = append(result.Rows, r.Rows[i]) + result.Rows = append(result.Rows, sqltypes.CopyRow(r.Rows[i])) // Send only two rows at a time. if i%2 == 1 { if err := callback(result); err != nil { @@ -188,6 +189,15 @@ func wrapStreamExecute(prim Primitive, vcursor VCursor, bindVars map[string]*que if result == nil { result = r } else { + if r.Fields != nil { + for i, field := range r.Fields { + aField := field + bField := result.Fields[i] + if !proto.Equal(aField, bField) { + return fmt.Errorf("fields differ: %s <> %s", aField.String(), bField.String()) + } + } + } result.Rows = append(result.Rows, r.Rows...) } return nil