From 9dc7651435cf57e106b1bca93765e4d0a1dd11c6 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Thu, 19 Dec 2024 17:24:29 +0530 Subject: [PATCH] fix: merge sort for receiving fields in multiple packets Signed-off-by: Harshit Gangal --- go/vt/vtgate/engine/limit.go | 30 ++++++++----------- go/vt/vtgate/engine/merge_sort.go | 14 ++++----- go/vt/vtgate/executor_select_test.go | 43 ++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 25 deletions(-) diff --git a/go/vt/vtgate/engine/limit.go b/go/vt/vtgate/engine/limit.go index a142fc8274c..56cf0ab87eb 100644 --- a/go/vt/vtgate/engine/limit.go +++ b/go/vt/vtgate/engine/limit.go @@ -112,30 +112,17 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars defer mu.Unlock() inputSize := len(qr.Rows) - if inputSize == 0 { - if wantfields && len(qr.Fields) != 0 { - wantfields = false - } - return callback(qr) - } - - // If this is the first callback and fields are requested, send the fields immediately. - if wantfields && len(qr.Fields) != 0 { - wantfields = false - // otherwise, we need to send the fields first, and then the rows - if err := callback(&sqltypes.Result{Fields: qr.Fields}); err != nil { - return err - } - } - // If we still need to skip `offset` rows before returning any to the client: if offset > 0 { if inputSize <= offset { // not enough to return anything yet, but we still want to pass on metadata such as last_insert_id offset -= inputSize - if !l.mustRetrieveAll(vcursor) { + if !wantfields && !l.mustRetrieveAll(vcursor) { return nil } + if len(qr.Fields) > 0 { + wantfields = false + } qr.Rows = nil return callback(qr) } @@ -147,9 +134,12 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars // At this point, we've dealt with the offset. Now handle the count (limit). if count == 0 { // If count is zero, we've fetched everything we need. - if !l.mustRetrieveAll(vcursor) { + if !wantfields && !l.mustRetrieveAll(vcursor) { return io.EOF } + if len(qr.Fields) > 0 { + wantfields = false + } // If we require the complete input, or we are in a transaction, we cannot return io.EOF early. // Instead, we return empty results as needed until input ends. @@ -157,6 +147,10 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars return callback(qr) } + if len(qr.Fields) > 0 { + wantfields = false + } + // reduce count till 0. resultSize := len(qr.Rows) if count > resultSize { diff --git a/go/vt/vtgate/engine/merge_sort.go b/go/vt/vtgate/engine/merge_sort.go index fac57c37ccb..8678ff49241 100644 --- a/go/vt/vtgate/engine/merge_sort.go +++ b/go/vt/vtgate/engine/merge_sort.go @@ -21,13 +21,11 @@ import ( "io" "vitess.io/vitess/go/mysql/sqlerror" - "vitess.io/vitess/go/vt/vtgate/evalengine" - "vitess.io/vitess/go/sqltypes" - querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" ) // StreamExecutor is a subset of Primitive that MergeSort @@ -216,9 +214,10 @@ func (ms *MergeSort) description() PrimitiveDescription { // routine that pulls the rows out of each streamHandle can abort the stream // by calling canceling the context. type streamHandle struct { - fields chan []*querypb.Field - row chan []sqltypes.Value - err error + fields chan []*querypb.Field + fieldSeen bool + row chan []sqltypes.Value + err error } // runOnestream starts a streaming query on one shard, and returns a streamHandle for it. @@ -233,7 +232,8 @@ func runOneStream(ctx context.Context, vcursor VCursor, input StreamExecutor, bi defer close(handle.row) handle.err = input.StreamExecute(ctx, vcursor, bindVars, wantfields, func(qr *sqltypes.Result) error { - if len(qr.Fields) != 0 { + if !handle.fieldSeen && len(qr.Fields) != 0 { + handle.fieldSeen = true select { case handle.fields <- qr.Fields: case <-ctx.Done(): diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 86aafaefba4..411f19bb30d 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -3302,6 +3302,49 @@ func TestSelectFromInformationSchema(t *testing.T) { sbc1.StringQueries()) } +func TestStreamOrderByWithMultipleResults(t *testing.T) { + ctx := utils.LeakCheckContext(t) + + // Special setup: Don't use createExecutorEnv. + cell := "aa" + hc := discovery.NewFakeHealthCheck(nil) + u := createSandbox(KsTestUnsharded) + s := createSandbox(KsTestSharded) + s.VSchema = executorVSchema + u.VSchema = unshardedVSchema + serv := newSandboxForCells(ctx, []string{cell}) + resolver := newTestResolver(ctx, hc, serv, cell) + shards := []string{"-20", "20-40", "40-60", "60-80", "80-a0", "a0-c0", "c0-e0", "e0-"} + count := 1 + for _, shard := range shards { + sbc := hc.AddTestTablet(cell, shard, 1, "TestExecutor", shard, topodatapb.TabletType_PRIMARY, true, 1, nil) + sbc.SetResults([]*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col|weight_string(id)", "int32|int32|varchar"), fmt.Sprintf("%d|%d|NULL", count, count)), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col|weight_string(id)", "int32|int32|varchar"), fmt.Sprintf("%d|%d|NULL", count+10, count)), + }) + count++ + } + queryLogger := streamlog.New[*logstats.LogStats]("VTGate", queryLogBufferSize) + plans := DefaultPlanCache() + executor := NewExecutor(ctx, vtenv.NewTestEnv(), serv, cell, resolver, true, false, testBufferSize, plans, nil, false, querypb.ExecuteOptions_Gen4, 0) + executor.SetQueryLogger(queryLogger) + defer executor.Close() + // some sleep for all goroutines to start + time.Sleep(100 * time.Millisecond) + before := runtime.NumGoroutine() + + query := "select id, col from user order by id" + gotResult, err := executorStream(ctx, executor, query) + require.NoError(t, err) + + wantResult := sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col", "int32|int32"), + "1|1", "2|2", "3|3", "4|4", "5|5", "6|6", "7|7", "8|8", "11|1", "12|2", "13|3", "14|4", "15|5", "16|6", "17|7", "18|8") + assert.Equal(t, fmt.Sprintf("%v", wantResult.Rows), fmt.Sprintf("%v", gotResult.Rows)) + // some sleep to close all goroutines. + time.Sleep(100 * time.Millisecond) + assert.GreaterOrEqual(t, before, runtime.NumGoroutine(), "left open goroutines lingering") +} + func TestStreamOrderByLimitWithMultipleResults(t *testing.T) { ctx := utils.LeakCheckContext(t)