diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index 5ff7a7c96ce..2606bad8052 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -77,6 +77,29 @@ func (cached *CheckCol) CachedSize(alloc bool) int64 { size += cached.CollationEnv.CachedSize(true) return size } +func (cached *Coerce) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field Source vitess.io/vitess/go/vt/vtgate/engine.Primitive + if cc, ok := cached.Source.(cachedObject); ok { + size += cc.CachedSize(true) + } + // field Types []*vitess.io/vitess/go/vt/vtgate/evalengine.Type + { + size += hack.RuntimeAllocSize(int64(cap(cached.Types)) * int64(8)) + for _, elem := range cached.Types { + if elem != nil { + size += hack.RuntimeAllocSize(int64(16)) + } + } + } + return size +} //go:nocheckptr func (cached *Concatenate) CachedSize(alloc bool) int64 { @@ -85,7 +108,7 @@ func (cached *Concatenate) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(32) + size += int64(112) } // field Sources []vitess.io/vitess/go/vt/vtgate/engine.Primitive { @@ -107,6 +130,17 @@ func (cached *Concatenate) CachedSize(alloc bool) int64 { size += hack.RuntimeAllocSize(int64(numBuckets * 208)) } } + // field fields []*vitess.io/vitess/go/vt/proto/query.Field + { + size += hack.RuntimeAllocSize(int64(cap(cached.fields)) * int64(8)) + for _, elem := range cached.fields { + size += elem.CachedSize(true) + } + } + // field fieldTypes []vitess.io/vitess/go/vt/vtgate/evalengine.Type + { + size += hack.RuntimeAllocSize(int64(cap(cached.fieldTypes)) * int64(16)) + } return size } func (cached *DBDDL) CachedSize(alloc bool) int64 { @@ -1153,6 +1187,25 @@ func (cached *ShowExec) CachedSize(alloc bool) int64 { size += cached.ShowFilter.CachedSize(true) return size } +func (cached *SimpleConcatenate) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(24) + } + // field Sources []vitess.io/vitess/go/vt/vtgate/engine.Primitive + { + size += hack.RuntimeAllocSize(int64(cap(cached.Sources)) * int64(16)) + for _, elem := range cached.Sources { + if cc, ok := elem.(cachedObject); ok { + size += cc.CachedSize(true) + } + } + } + return size +} func (cached *SimpleProjection) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) diff --git a/go/vt/vtgate/engine/coerce.go b/go/vt/vtgate/engine/coerce.go new file mode 100644 index 00000000000..780d7e145f9 --- /dev/null +++ b/go/vt/vtgate/engine/coerce.go @@ -0,0 +1,159 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "context" + "fmt" + + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/vtgate/evalengine" +) + +// Coerce is used to change types of incoming columns +type Coerce struct { + Source Primitive + Types []*evalengine.Type +} + +var _ Primitive = (*Coerce)(nil) + +func (c *Coerce) RouteType() string { + return c.Source.RouteType() +} + +func (c *Coerce) GetKeyspaceName() string { + return c.Source.GetKeyspaceName() +} + +func (c *Coerce) GetTableName() string { + return c.Source.GetTableName() +} + +func (c *Coerce) GetFields( + ctx context.Context, + vcursor VCursor, + bvars map[string]*querypb.BindVariable, +) (*sqltypes.Result, error) { + res, err := c.Source.GetFields(ctx, vcursor, bvars) + if err != nil { + return nil, err + } + c.setFields(res) + + return res, nil +} + +func (c *Coerce) setFields(res *sqltypes.Result) { + if len(res.Fields) == 0 { + return + } + for i, t := range c.Types { + if t == nil { + continue + } + + t.SetTypeAndFlags(res.Fields[i]) + } +} + +func (c *Coerce) NeedsTransaction() bool { + return c.Source.NeedsTransaction() +} + +func (c *Coerce) TryExecute( + ctx context.Context, + vcursor VCursor, + bindVars map[string]*querypb.BindVariable, + wantfields bool, +) (*sqltypes.Result, error) { + sqlmode := evalengine.ParseSQLMode(vcursor.SQLMode()) + + res, err := vcursor.ExecutePrimitive(ctx, c.Source, bindVars, wantfields) + if err != nil { + return nil, err + } + + for _, row := range res.Rows { + err := c.coerceValuesTo(row, sqlmode) + if err != nil { + return nil, err + } + } + + c.setFields(res) + return res, nil +} + +func (c *Coerce) coerceValuesTo(row sqltypes.Row, sqlmode evalengine.SQLMode) error { + for i, value := range row { + typ := c.Types[i] + if typ == nil { + // this column does not need to be coerced + continue + } + + newValue, err := evalengine.CoerceTo(value, *typ, sqlmode) + if err != nil { + return err + } + row[i] = newValue + } + return nil +} + +func (c *Coerce) TryStreamExecute( + ctx context.Context, + vcursor VCursor, + bindVars map[string]*querypb.BindVariable, + wantfields bool, + callback func(*sqltypes.Result) error, +) error { + sqlmode := evalengine.ParseSQLMode(vcursor.SQLMode()) + + return vcursor.StreamExecutePrimitive(ctx, c.Source, bindVars, wantfields, func(result *sqltypes.Result) error { + for _, row := range result.Rows { + err := c.coerceValuesTo(row, sqlmode) + if err != nil { + return err + } + } + c.setFields(result) + return callback(result) + }) +} + +func (c *Coerce) Inputs() ([]Primitive, []map[string]any) { + return []Primitive{c.Source}, nil +} + +func (c *Coerce) description() PrimitiveDescription { + var cols []string + for idx, typ := range c.Types { + if typ == nil { + continue + } + cols = append(cols, fmt.Sprintf("%d:%s", idx, typ.Type().String())) + } + return PrimitiveDescription{ + OperatorType: "Coerce", + Other: map[string]any{ + "Fields": cols, + }, + } +} diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index 13727124e78..ec353013106 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -18,8 +18,8 @@ package engine import ( "context" - "slices" "sync" + "sync/atomic" "golang.org/x/sync/errgroup" @@ -40,6 +40,13 @@ type Concatenate struct { // These column offsets do not need to be typed checked - they usually contain weight_string() // columns that are not going to be returned to the user NoNeedToTypeCheck map[int]any + + // the following fields are written to only once, and can then be shared between all users of this plan + typeLoading sync.Once + typesLoaded atomic.Bool + fields []*querypb.Field + fieldTypes []evalengine.Type + typeError error } // NewConcatenate creates a Concatenate primitive. The ignoreCols slice contains the offsets that @@ -96,13 +103,13 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars return nil, err } - fields, fieldTypes, err := c.getFieldTypes(vcursor, res) + err = c.loadTypes(vcursor, res) if err != nil { return nil, err } var rows [][]sqltypes.Value - err = c.coerceAndVisitResults(res, fieldTypes, func(result *sqltypes.Result) error { + err = c.coerceAndVisitResults(res, c.fieldTypes, func(result *sqltypes.Result) error { rows = append(rows, result.Rows...) return nil }, evalengine.ParseSQLMode(vcursor.SQLMode())) @@ -111,11 +118,18 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars } return &sqltypes.Result{ - Fields: fields, + Fields: c.fields, Rows: rows, }, nil } +func (c *Concatenate) loadTypes(vcursor VCursor, res []*sqltypes.Result) error { + c.typeLoading.Do(func() { + c.getFieldTypes(vcursor, res) + }) + return c.typeError +} + func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fieldTypes []evalengine.Type, sqlmode evalengine.SQLMode) error { if len(row) != len(fieldTypes) { return errWrongNumberOfColumnsInSelect @@ -136,9 +150,9 @@ func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fieldTypes []evalengine.T return nil } -func (c *Concatenate) getFieldTypes(vcursor VCursor, res []*sqltypes.Result) ([]*querypb.Field, []evalengine.Type, error) { +func (c *Concatenate) getFieldTypes(vcursor VCursor, res []*sqltypes.Result) { if len(res) == 0 { - return nil, nil, nil + return } typers := make([]evalengine.TypeAggregator, len(res[0].Fields)) @@ -149,11 +163,13 @@ func (c *Concatenate) getFieldTypes(vcursor VCursor, res []*sqltypes.Result) ([] continue } if len(r.Fields) != len(typers) { - return nil, nil, errWrongNumberOfColumnsInSelect + c.typeError = errWrongNumberOfColumnsInSelect + return } for idx, field := range r.Fields { if err := typers[idx].AddField(field, collations); err != nil { - return nil, nil, err + c.typeError = err + return } } } @@ -173,7 +189,9 @@ func (c *Concatenate) getFieldTypes(vcursor VCursor, res []*sqltypes.Result) ([] fields = append(fields, t.ToField(f.Name)) types = append(types, t) } - return fields, types, nil + c.fields = fields + c.fieldTypes = types + c.typesLoaded.Store(true) } func (c *Concatenate) execSources(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) ([]*sqltypes.Result, error) { @@ -229,7 +247,7 @@ func (c *Concatenate) sequentialExec(ctx context.Context, vcursor VCursor, bindV // TryStreamExecute performs a streaming exec. func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool, callback func(*sqltypes.Result) error) error { sqlmode := evalengine.ParseSQLMode(vcursor.SQLMode()) - if vcursor.Session().InTransaction() { + if vcursor.Session().InTransaction() || !c.typesLoaded.Load() { // as we are in a transaction, we need to execute all queries inside a single connection, // which holds the single transaction we have return c.sequentialStreamExec(ctx, vcursor, bindVars, callback, sqlmode) @@ -238,82 +256,57 @@ func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bin return c.parallelStreamExec(ctx, vcursor, bindVars, callback, sqlmode) } -func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, in func(*sqltypes.Result) error, sqlmode evalengine.SQLMode) error { +// parallelStreamExec runs and returns the sub queries in parallel +// it assumes the field types have been loaded +func (c *Concatenate) parallelStreamExec( + inCtx context.Context, + vcursor VCursor, + bindVars map[string]*querypb.BindVariable, + in func(*sqltypes.Result) error, + sqlmode evalengine.SQLMode, +) error { // Scoped context; any early exit triggers cancel() to clean up ongoing work. ctx, cancel := context.WithCancel(inCtx) defer cancel() - // Mutexes for dealing with concurrent access to shared state. - var ( - muCallback sync.Mutex // Protects callback - muFields sync.Mutex // Protects field state - condFields = sync.NewCond(&muFields) // Condition var for field arrival - wg errgroup.Group // Wait group for all streaming goroutines - rest = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields - fieldTypes []evalengine.Type // Cached final field types - ) - // Process each result chunk, considering type coercion. callback := func(res *sqltypes.Result, srcIdx int) error { - muCallback.Lock() - defer muCallback.Unlock() - + if len(res.Rows) == 0 { + return in(res) + } // Check if type coercion needed for this source. // We only need to check if fields are not in NoNeedToTypeCheck set. needsCoercion := false - for idx, field := range rest[srcIdx].Fields { - _, skip := c.NoNeedToTypeCheck[idx] - if !skip && fieldTypes[idx].Type() != field.Type { - needsCoercion = true - break + if len(res.Fields) < len(c.fieldTypes) { + // if we didn't get enough fields, we'll always coerce + needsCoercion = true + } else { + for idx, field := range c.fieldTypes { + _, skip := c.NoNeedToTypeCheck[idx] + if !skip && field.Type() != res.Fields[idx].Type { + needsCoercion = true + break + } } } // Apply type coercion if needed. + // TODO: we should be able to do this only once as well, and remember if we need coercing here or not if needsCoercion { for _, row := range res.Rows { - if err := c.coerceValuesTo(row, fieldTypes, sqlmode); err != nil { + if err := c.coerceValuesTo(row, c.fieldTypes, sqlmode); err != nil { return err } } } return in(res) } - + var wg errgroup.Group // Start streaming query execution in parallel for all sources. for i, source := range c.Sources { currIndex, currSource := i, source wg.Go(func() error { err := vcursor.StreamExecutePrimitive(ctx, currSource, bindVars, true, func(resultChunk *sqltypes.Result) error { - muFields.Lock() - - // Process fields when they arrive; coordinate field agreement across sources. - if resultChunk.Fields != nil && rest[currIndex] == nil { - // Capture the initial result chunk to determine field types later. - rest[currIndex] = resultChunk - - // If this was the last source to report its fields, derive the final output fields. - if !slices.Contains(rest, nil) { - // We have received fields from all sources. We can now calculate the output types - var err error - resultChunk.Fields, fieldTypes, err = c.getFieldTypes(vcursor, rest) - if err != nil { - muFields.Unlock() - return err - } - - muFields.Unlock() - defer condFields.Broadcast() - return callback(resultChunk, currIndex) - } - } - - // Wait for fields from all sources. - for slices.Contains(rest, nil) { - condFields.Wait() - } - muFields.Unlock() - // Context check to avoid extra work. if ctx.Err() != nil { return nil @@ -323,14 +316,7 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, // Error handling and context cleanup for this source. if err != nil { - muFields.Lock() - if rest[currIndex] == nil { - // Signal that this source is done, even if by failure, to unblock field waiting. - rest[currIndex] = &sqltypes.Result{} - } cancel() - condFields.Broadcast() - muFields.Unlock() } return err }) @@ -363,17 +349,19 @@ func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, } } - firsts := make([]*sqltypes.Result, len(c.Sources)) - for i, result := range results { - firsts[i] = result[0] + c.typeLoading.Do(func() { + firsts := make([]*sqltypes.Result, len(c.Sources)) + for i, result := range results { + firsts[i] = result[0] + } + c.getFieldTypes(vcursor, firsts) + }) + if c.typeError != nil { + return c.typeError } - _, fieldTypes, err := c.getFieldTypes(vcursor, firsts) - if err != nil { - return err - } for _, res := range results { - if err = c.coerceAndVisitResults(res, fieldTypes, callback, sqlmode); err != nil { + if err := c.coerceAndVisitResults(res, c.fieldTypes, callback, sqlmode); err != nil { return err } } diff --git a/go/vt/vtgate/engine/concatenate_test.go b/go/vt/vtgate/engine/concatenate_test.go index dd2b1300e9b..26d45c1c5ad 100644 --- a/go/vt/vtgate/engine/concatenate_test.go +++ b/go/vt/vtgate/engine/concatenate_test.go @@ -21,8 +21,11 @@ import ( "errors" "fmt" "strings" + "sync/atomic" "testing" + querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/test/utils" @@ -211,3 +214,52 @@ func TestConcatenateTypes(t *testing.T) { }) } } + +func BenchmarkConcatenateTryExecute(b *testing.B) { + fakeSrc1, fakeSrc2, prim := createConcatenateForTest() + ctx := context.Background() + b.ResetTimer() + vcursor := &noopVCursor{} + count := 0 + for i := 0; i < b.N; i++ { + res, err := prim.TryExecute(ctx, vcursor, map[string]*querypb.BindVariable{}, true) + require.NoError(b, err) + count += len(res.Rows) + fakeSrc1.curResult = 0 + fakeSrc2.curResult = 0 + } +} + +func BenchmarkConcatenateTryStreamExecute(b *testing.B) { + fakeSrc1, fakeSrc2, prim := createConcatenateForTest() + ctx := context.Background() + b.ResetTimer() + vcursor := &noopVCursor{} + var count atomic.Int32 + for i := 0; i < b.N; i++ { + err := prim.TryStreamExecute(ctx, vcursor, map[string]*querypb.BindVariable{}, true, func(result *sqltypes.Result) error { + count.Add(int32(len(result.Rows))) + return nil + }) + require.NoError(b, err) + fakeSrc1.curResult = 0 + fakeSrc2.curResult = 0 + } +} + +func createConcatenateForTest() (*fakePrimitive, *fakePrimitive, *Concatenate) { + fake := r("id|col1|col2", "int64|varchar|varbinary", "1|a1|b1", "2|a2|b2") + var rows []string + for x := range 10 { + rows = append(rows, fmt.Sprintf("%d|a%d|b%d", x, x, x)) + } + result := sqltypes.MakeTestResult(fake.Fields, rows...) + fake.Rows = result.Rows + fakeSrc1 := &fakePrimitive{results: []*sqltypes.Result{fake, fake}} + fakeSrc2 := &fakePrimitive{results: []*sqltypes.Result{fake, fake}} + prim := NewConcatenate([]Primitive{ + fakeSrc1, + fakeSrc2, + }, nil) + return fakeSrc1, fakeSrc2, prim +} diff --git a/go/vt/vtgate/engine/simple_concatenate.go b/go/vt/vtgate/engine/simple_concatenate.go new file mode 100644 index 00000000000..4dd6db4367a --- /dev/null +++ b/go/vt/vtgate/engine/simple_concatenate.go @@ -0,0 +1,261 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "context" + "sync" + "sync/atomic" + + "golang.org/x/sync/errgroup" + + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" +) + +// SimpleConcatenate Primitive is used to concatenate results from multiple sources. +// It does no type checking or coercing - it just concatenates results together, assuming +// the inputs are already correctly typed, and it uses the first source for column names +var _ Primitive = (*SimpleConcatenate)(nil) + +type SimpleConcatenate struct { + Sources []Primitive +} + +// NewSimpleConcatenate creates a SimpleConcatenate primitive. +func NewSimpleConcatenate(Sources []Primitive) *SimpleConcatenate { + return &SimpleConcatenate{Sources: Sources} +} + +// RouteType returns a description of the query routing type used by the primitive +func (c *SimpleConcatenate) RouteType() string { + return "SimpleConcatenate" +} + +// GetKeyspaceName specifies the Keyspace that this primitive routes to +func (c *SimpleConcatenate) GetKeyspaceName() string { + res := c.Sources[0].GetKeyspaceName() + for i := 1; i < len(c.Sources); i++ { + res = formatTwoOptionsNicely(res, c.Sources[i].GetKeyspaceName()) + } + return res +} + +// GetTableName specifies the table that this primitive routes to. +func (c *SimpleConcatenate) GetTableName() string { + res := c.Sources[0].GetTableName() + for i := 1; i < len(c.Sources); i++ { + res = formatTwoOptionsNicely(res, c.Sources[i].GetTableName()) + } + return res +} + +// TryExecute performs a non-streaming exec. +func (c *SimpleConcatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool) (*sqltypes.Result, error) { + if vcursor.Session().InTransaction() { + // as we are in a transaction, we need to execute all queries inside a single transaction + // therefore it needs a sequential execution. + return c.sequentialExec(ctx, vcursor, bindVars) + } + // not in transaction, so execute in parallel. + return c.parallelExec(ctx, vcursor, bindVars) +} + +// TryStreamExecute performs a streaming exec. +func (c *SimpleConcatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool, callback func(*sqltypes.Result) error) error { + if vcursor.Session().InTransaction() { + // as we are in a transaction, we need to execute all queries inside a single connection, + // which holds the single transaction we have + return c.sequentialStreamExec(ctx, vcursor, bindVars, callback) + } + // not in transaction, so execute in parallel. + return c.parallelStreamExec(ctx, vcursor, bindVars, callback) +} + +func (c *SimpleConcatenate) parallelExec( + ctx context.Context, + vcursor VCursor, + bindVars map[string]*querypb.BindVariable, +) (*sqltypes.Result, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + var ( + wg errgroup.Group + rows sync.Mutex + ) + + result := &sqltypes.Result{} + for i, source := range c.Sources { + // the first source will be used to get the fields + // all sources will run in parallel + wg.Go(func() error { + vars := copyBindVars(bindVars) + chunk, err := vcursor.ExecutePrimitive(ctx, source, vars, true) + if err != nil { + cancel() + return err + } + + if i == 0 { + result.Fields = chunk.Fields + } + + if len(chunk.Rows) == 0 { + return nil + } + + rows.Lock() + defer rows.Unlock() + result.Rows = append(result.Rows, chunk.Rows...) + return nil + }) + } + err := wg.Wait() + if err != nil { + return nil, err + } + return result, nil +} + +func (c *SimpleConcatenate) sequentialExec( + ctx context.Context, + vcursor VCursor, + bindVars map[string]*querypb.BindVariable, +) (result *sqltypes.Result, err error) { + for _, src := range c.Sources { + vars := copyBindVars(bindVars) + chunk, err := vcursor.ExecutePrimitive(ctx, src, vars, true) + if err != nil { + return nil, err + } + if result == nil { + result = &sqltypes.Result{ + Fields: chunk.Fields, + SessionStateChanges: chunk.SessionStateChanges, + StatusFlags: chunk.StatusFlags, + Info: chunk.Info, + } + } + result.Rows = append(result.Rows, chunk.Rows...) + } + return +} + +// parallelStreamExec executes the sources in parallel and streams the results. +func (c *SimpleConcatenate) parallelStreamExec( + inCtx context.Context, + vcursor VCursor, + bindVars map[string]*querypb.BindVariable, + callback func(*sqltypes.Result) error, +) error { + // Scoped context; any early exit triggers cancel() to clean up ongoing work. + ctx, cancel := context.WithCancel(inCtx) + defer cancel() + + // Mutex for dealing with concurrent access to shared state. + var ( + muCallback sync.Mutex + wg errgroup.Group + fieldsWg sync.WaitGroup + fields []*querypb.Field + fieldsDone atomic.Bool + ) + + fieldsWg.Add(1) + // Start streaming query execution in parallel for all source + for i, source := range c.Sources { + wg.Go(func() error { + return vcursor.StreamExecutePrimitive(ctx, source, bindVars, true, func(chunk *sqltypes.Result) error { + // Context check to avoid extra work. + if ctx.Err() != nil { + return nil + } + if i == 0 { + // for results coming from the first source, we don't need to block + fieldsAlreadyLoaded := fieldsDone.Swap(true) + if !fieldsAlreadyLoaded { + fields = chunk.Fields + fieldsWg.Done() + } + } else { + fieldsWg.Wait() + chunk.Fields = fields + } + + muCallback.Lock() + defer muCallback.Unlock() + return callback(chunk) + }) + }) + } + // Wait for all sources to complete. + return wg.Wait() +} + +func (c *SimpleConcatenate) sequentialStreamExec( + ctx context.Context, + vcursor VCursor, + bindVars map[string]*querypb.BindVariable, + callback func(*sqltypes.Result) error, +) error { + var fields []*querypb.Field + for i, source := range c.Sources { + err := vcursor.StreamExecutePrimitive(ctx, source, bindVars, true, func(resultChunk *sqltypes.Result) error { + // check if context has expired. + if ctx.Err() != nil { + return ctx.Err() + } + if i == 0 && fields == nil { + fields = resultChunk.Fields + } else { + resultChunk.Fields = fields + } + + return callback(resultChunk) + }) + if err != nil { + return err + } + } + + return nil +} + +// GetFields fetches the field info. +func (c *SimpleConcatenate) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { + return c.Sources[0].GetFields(ctx, vcursor, bindVars) +} + +// NeedsTransaction returns whether a transaction is needed for this primitive +func (c *SimpleConcatenate) NeedsTransaction() bool { + for _, source := range c.Sources { + if source.NeedsTransaction() { + return true + } + } + return false +} + +// Inputs returns the input primitives for this +func (c *SimpleConcatenate) Inputs() ([]Primitive, []map[string]any) { + return c.Sources, nil +} + +func (c *SimpleConcatenate) description() PrimitiveDescription { + return PrimitiveDescription{OperatorType: c.RouteType()} +} diff --git a/go/vt/vtgate/engine/simple_concatenate_test.go b/go/vt/vtgate/engine/simple_concatenate_test.go new file mode 100644 index 00000000000..767214e5995 --- /dev/null +++ b/go/vt/vtgate/engine/simple_concatenate_test.go @@ -0,0 +1,88 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "context" + "fmt" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" +) + +func BenchmarkSimpleConcatenateTryExecute(b *testing.B) { + fakeSrc1, fakeSrc2, prim := createSimpleConcatenateForTest() + ctx := context.Background() + b.ResetTimer() + vcursor := &noopVCursor{} + count := 0 + for i := 0; i < b.N; i++ { + res, err := prim.TryExecute(ctx, vcursor, map[string]*querypb.BindVariable{}, true) + require.NoError(b, err) + count += len(res.Rows) + fakeSrc1.curResult = 0 + fakeSrc2.curResult = 0 + } +} + +func BenchmarkSimpleConcatenateTryStreamExecute(b *testing.B) { + fakeSrc1, fakeSrc2, prim := createSimpleConcatenateForTest() + ctx := context.Background() + b.ResetTimer() + vcursor := &noopVCursor{} + var count atomic.Int32 + for i := 0; i < b.N; i++ { + err := prim.TryStreamExecute(ctx, vcursor, map[string]*querypb.BindVariable{}, true, func(result *sqltypes.Result) error { + count.Add(int32(len(result.Rows))) + return nil + }) + require.NoError(b, err) + fakeSrc1.curResult = 0 + fakeSrc2.curResult = 0 + } +} + +func createSimpleConcatenateForTest() (*fakePrimitive, *fakePrimitive, *SimpleConcatenate) { + fake := r("id|col1|col2", "int64|varchar|varbinary", "1|a1|b1", "2|a2|b2") + var rows []string + for x := range 10 { + rows = append(rows, fmt.Sprintf("%d|a%d|b%d", x, x, x)) + } + result := sqltypes.MakeTestResult(fake.Fields, rows...) + fake.Rows = result.Rows + fakeSrc1 := &fakePrimitive{results: []*sqltypes.Result{fake, fake}} + fakeSrc2 := &fakePrimitive{results: []*sqltypes.Result{fake, fake}} + prim := NewSimpleConcatenate([]Primitive{ + fakeSrc1, + fakeSrc2, + }) + return fakeSrc1, fakeSrc2, prim +} + +func TestName(t *testing.T) { + var x atomic.Bool + for range 100 { + go func() { + old := x.Swap(true) + t.Log(old) + }() + } +} diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index c0b628b1aa8..0cb8d0f0fbd 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -92,7 +92,7 @@ func NewTypeFromField(f *querypb.Field) Type { } } -func (t *Type) ToField(name string) *querypb.Field { +func (t *Type) SetTypeAndFlags(field *querypb.Field) { // need to get the proper flags for the type; usually leaving flags // to 0 is OK, because Vitess' MySQL client will generate the right // ones for the column's type, but here we're also setting the NotNull @@ -101,15 +101,16 @@ func (t *Type) ToField(name string) *querypb.Field { if !t.nullable { flags |= int64(querypb.MySqlFlag_NOT_NULL_FLAG) } + field.Type = t.typ + field.Charset = uint32(t.collation) + field.ColumnLength = uint32(t.size) + field.Decimals = uint32(t.scale) + field.Flags = uint32(flags) +} - f := &querypb.Field{ - Name: name, - Type: t.typ, - Charset: uint32(t.collation), - ColumnLength: uint32(t.size), - Decimals: uint32(t.scale), - Flags: uint32(flags), - } +func (t *Type) ToField(name string) *querypb.Field { + f := &querypb.Field{Name: name} + t.SetTypeAndFlags(f) return f } diff --git a/go/vt/vtgate/planbuilder/concatenate.go b/go/vt/vtgate/planbuilder/concatenate.go index 81cbe3d5b65..5c512ee4c1b 100644 --- a/go/vt/vtgate/planbuilder/concatenate.go +++ b/go/vt/vtgate/planbuilder/concatenate.go @@ -22,6 +22,7 @@ import ( type concatenate struct { sources []logicalPlan + coerced bool // These column offsets do not need to be typed checked - they usually contain weight_string() // columns that are not going to be returned to the user @@ -37,5 +38,12 @@ func (c *concatenate) Primitive() engine.Primitive { sources = append(sources, source.Primitive()) } + if c.coerced { + // types are already handled, let's use the fast concatenate + return &engine.SimpleConcatenate{ + Sources: sources, + } + } + return engine.NewConcatenate(sources, c.noNeedToTypeCheck) } diff --git a/go/vt/vtgate/planbuilder/filter.go b/go/vt/vtgate/planbuilder/filter.go index c3686380446..45b7f2affba 100644 --- a/go/vt/vtgate/planbuilder/filter.go +++ b/go/vt/vtgate/planbuilder/filter.go @@ -18,6 +18,7 @@ package planbuilder import ( "vitess.io/vitess/go/vt/vtgate/engine" + "vitess.io/vitess/go/vt/vtgate/evalengine" ) type ( @@ -35,3 +36,16 @@ func (l *filter) Primitive() engine.Primitive { l.efilter.Input = l.input.Primitive() return l.efilter } + +type coercePlan struct { + input logicalPlan + columns []*evalengine.Type +} + +func (c *coercePlan) Primitive() engine.Primitive { + src := c.input.Primitive() + return &engine.Coerce{ + Source: src, + Types: c.columns, + } +} diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index f0783a5ecfb..c1a4d4d82ee 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -811,7 +811,51 @@ func getAllTableNames(op *operators.Route) ([]string, error) { } func transformUnionPlan(ctx *plancontext.PlanningContext, op *operators.Union) (logicalPlan, error) { - sources, err := slice.MapWithError(op.Sources, func(src operators.Operator) (logicalPlan, error) { + sources, coerced, err := coercedInputs(ctx, op) + if err != nil { + return nil, err + } + + return &concatenate{ + sources: sources, + coerced: coerced, + }, nil +} + +func typeForExpr(ctx *plancontext.PlanningContext, e sqlparser.Expr) (evalengine.Type, bool) { + if typ, found := ctx.SemTable.TypeForExpr(e); found { + return typ, true + } + + cfg := &evalengine.Config{ + ResolveColumn: func(name *sqlparser.ColName) (int, error) { + return 0, nil // we are not going to use these for anything other than getting the type + }, + ResolveType: func(expr sqlparser.Expr) (evalengine.Type, bool) { + return ctx.SemTable.TypeForExpr(e) + }, + Collation: ctx.SemTable.Collation, + Environment: ctx.VSchema.Environment(), + } + evalExpr, err := evalengine.Translate(e, cfg) + if err != nil { + return evalengine.Type{}, false + } + env := evalengine.ExpressionEnv{ + BindVars: nil, + Row: nil, + Fields: nil, + } + typ, err := env.TypeOf(evalExpr) + if err != nil { + return evalengine.Type{}, false + } + ctx.SemTable.ExprTypes[e] = typ + return typ, true +} + +func coercedInputs(ctx *plancontext.PlanningContext, op *operators.Union) ([]logicalPlan, bool, error) { + orgSources, err := slice.MapWithError(op.Sources, func(src operators.Operator) (logicalPlan, error) { plan, err := transformToLogicalPlan(ctx, src) if err != nil { return nil, err @@ -819,17 +863,57 @@ func transformUnionPlan(ctx *plancontext.PlanningContext, op *operators.Union) ( return plan, nil }) if err != nil { - return nil, err + return nil, false, err + } + collationEnv := ctx.VSchema.Environment().CollationEnv() + typers := make([]evalengine.TypeAggregator, len(op.Sources[0].GetColumns(ctx))) + for _, src := range op.Sources { + cols := src.GetColumns(ctx) + for idx, col := range cols { + typ, found := typeForExpr(ctx, col.Expr) + if !found { + return orgSources, false, nil + } + err := typers[idx].Add(typ, collationEnv) + if err != nil { + // let's ignore this and just return the + return orgSources, false, nil + } + } } - if len(sources) == 1 { - return sources[0], nil + newSources := make([]logicalPlan, 0, len(orgSources)) + for srcIdx, src := range op.Sources { + cols := src.GetColumns(ctx) + coerceTypes := make([]*evalengine.Type, len(cols)) + coerce := false + for colIdx, col := range cols { + typ, found := ctx.SemTable.TypeForExpr(col.Expr) + if !found { + return orgSources, false, nil + } + resultType := typers[colIdx].Type() + if resultType.Type() == sqltypes.Unknown { + // if the resulting type is a null type, the type aggregator probably messed up the + // type calculus, and we can't trust these types + return orgSources, false, nil + } + if resultType != typ { + coerceTypes[colIdx] = &resultType + coerce = true + } + } + if coerce { + newSources = append(newSources, &coercePlan{ + input: orgSources[srcIdx], + columns: coerceTypes, + }) + } else { + newSources = append(newSources, orgSources[srcIdx]) + } } - return &concatenate{ - sources: sources, - noNeedToTypeCheck: nil, - }, nil + return newSources, true, nil } func transformLimit(ctx *plancontext.PlanningContext, op *operators.Limit) (logicalPlan, error) { diff --git a/go/vt/vtgate/planbuilder/testdata/cte_cases.json b/go/vt/vtgate/planbuilder/testdata/cte_cases.json index 0d7d9020ac2..80d3ef6e57a 100644 --- a/go/vt/vtgate/planbuilder/testdata/cte_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/cte_cases.json @@ -1901,7 +1901,7 @@ ], "Inputs": [ { - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Projection", diff --git a/go/vt/vtgate/planbuilder/testdata/info_schema57_cases.json b/go/vt/vtgate/planbuilder/testdata/info_schema57_cases.json index 1727e372490..0a11d83dc02 100644 --- a/go/vt/vtgate/planbuilder/testdata/info_schema57_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/info_schema57_cases.json @@ -107,7 +107,7 @@ ], "Inputs": [ { - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Route", @@ -170,7 +170,7 @@ ], "Inputs": [ { - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Route", @@ -788,7 +788,7 @@ "Aggregates": "sum(0) AS sum(found)", "Inputs": [ { - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Route", @@ -827,7 +827,7 @@ "QueryType": "SELECT", "Original": "select found from (select 1 as found from information_schema.`tables` where table_schema = 'music' union all (select 1 as found from information_schema.views where table_schema = 'music' limit 1)) as t", "Instructions": { - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Route", @@ -864,7 +864,7 @@ "QueryType": "SELECT", "Original": "select 1 as found from information_schema.`tables` where table_schema = 'music' and table_schema = 'Music' union all (select 1 as found from information_schema.views where table_schema = 'music' and table_schema = 'user' limit 1)", "Instructions": { - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Route", @@ -901,7 +901,7 @@ "QueryType": "SELECT", "Original": "select 1 as found from information_schema.`tables` where table_schema = 'music' and table_schema = 'Music' union all (select 1 as found from information_schema.views where table_schema = 'music' and table_schema = 'user' limit 1)", "Instructions": { - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Route", @@ -946,7 +946,7 @@ "Inputs": [ { "InputName": "SubQuery", - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Route", @@ -1004,7 +1004,7 @@ ], "Inputs": [ { - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Route", @@ -1057,7 +1057,7 @@ ], "Inputs": [ { - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Route", diff --git a/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json b/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json index 3e15c19abc5..d6b7dd69b70 100644 --- a/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json @@ -107,7 +107,7 @@ ], "Inputs": [ { - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Route", @@ -853,7 +853,7 @@ "Aggregates": "sum(0) AS sum(found)", "Inputs": [ { - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Route", @@ -892,7 +892,7 @@ "QueryType": "SELECT", "Original": "select found from (select 1 as found from information_schema.`tables` where table_schema = 'music' union all (select 1 as found from information_schema.views where table_schema = 'music' limit 1)) as t", "Instructions": { - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Route", @@ -929,7 +929,7 @@ "QueryType": "SELECT", "Original": "select 1 as found from information_schema.`tables` where table_schema = 'music' and table_schema = 'Music' union all (select 1 as found from information_schema.views where table_schema = 'music' and table_schema = 'user' limit 1)", "Instructions": { - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Route", @@ -966,7 +966,7 @@ "QueryType": "SELECT", "Original": "select 1 as found from information_schema.`tables` where table_schema = 'music' and table_schema = 'Music' union all (select 1 as found from information_schema.views where table_schema = 'music' and table_schema = 'user' limit 1)", "Instructions": { - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Route", @@ -1011,7 +1011,7 @@ "Inputs": [ { "InputName": "SubQuery", - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Route", @@ -1069,7 +1069,7 @@ ], "Inputs": [ { - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Route", @@ -1122,7 +1122,7 @@ ], "Inputs": [ { - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Route", diff --git a/go/vt/vtgate/planbuilder/testdata/union_cases.json b/go/vt/vtgate/planbuilder/testdata/union_cases.json index cdbd368478f..87585d6e305 100644 --- a/go/vt/vtgate/planbuilder/testdata/union_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/union_cases.json @@ -443,7 +443,7 @@ ], "Inputs": [ { - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Route", @@ -696,7 +696,7 @@ ], "Inputs": [ { - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Aggregate", @@ -1187,7 +1187,7 @@ ], "Inputs": [ { - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Route", @@ -1278,7 +1278,7 @@ ], "Inputs": [ { - "OperatorType": "Concatenate", + "OperatorType": "SimpleConcatenate", "Inputs": [ { "OperatorType": "Route",