From f7921dbfca393582e1d1ab1e94192ded8226ffad Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 19 Dec 2024 08:24:40 +0100 Subject: [PATCH] feat: clean up limit code Signed-off-by: Andres Taylor --- go/vt/vtgate/engine/limit.go | 25 ++++++-- go/vt/vtgate/engine/limit_test.go | 95 +++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+), 6 deletions(-) diff --git a/go/vt/vtgate/engine/limit.go b/go/vt/vtgate/engine/limit.go index 824689d2859..a142fc8274c 100644 --- a/go/vt/vtgate/engine/limit.go +++ b/go/vt/vtgate/engine/limit.go @@ -89,6 +89,10 @@ func (l *Limit) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st return result, nil } +func (l *Limit) mustRetrieveAll(vcursor VCursor) bool { + return l.RequireCompleteInput || vcursor.Session().InTransaction() +} + // TryStreamExecute satisfies the Primitive interface. func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { count, offset, err := l.getCountAndOffset(ctx, vcursor, bindVars) @@ -107,22 +111,31 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars mu.Lock() defer mu.Unlock() + inputSize := len(qr.Rows) + if inputSize == 0 { + if wantfields && len(qr.Fields) != 0 { + wantfields = false + } + return callback(qr) + } + // If this is the first callback and fields are requested, send the fields immediately. if wantfields && len(qr.Fields) != 0 { + wantfields = false + // otherwise, we need to send the fields first, and then the rows if err := callback(&sqltypes.Result{Fields: qr.Fields}); err != nil { return err } } - inputSize := len(qr.Rows) - if inputSize == 0 { - return callback(qr) - } // If we still need to skip `offset` rows before returning any to the client: if offset > 0 { if inputSize <= offset { // not enough to return anything yet, but we still want to pass on metadata such as last_insert_id offset -= inputSize + if !l.mustRetrieveAll(vcursor) { + return nil + } qr.Rows = nil return callback(qr) } @@ -134,7 +147,7 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars // At this point, we've dealt with the offset. Now handle the count (limit). if count == 0 { // If count is zero, we've fetched everything we need. - if !l.RequireCompleteInput && !vcursor.Session().InTransaction() { + if !l.mustRetrieveAll(vcursor) { return io.EOF } @@ -159,7 +172,7 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars // If we required complete input or are in a transaction, we must not exit early. // We'll return empty batches until the input is done. - if l.RequireCompleteInput || vcursor.Session().InTransaction() { + if l.mustRetrieveAll(vcursor) { return nil } diff --git a/go/vt/vtgate/engine/limit_test.go b/go/vt/vtgate/engine/limit_test.go index 8b91dadecb5..4f24141c8da 100644 --- a/go/vt/vtgate/engine/limit_test.go +++ b/go/vt/vtgate/engine/limit_test.go @@ -353,6 +353,101 @@ func TestLimitOffsetExecute(t *testing.T) { t.Errorf("l.Execute:\n got %v, want\n%v", result, wantResult) } } +func TestLimitStreamExecut2e(t *testing.T) { + fields := sqltypes.MakeTestFields( + "col1|col2", + "int64|varchar", + ) + inputResult := sqltypes.MakeTestResult( + fields, + "a|1", + "b|2", + "c|3", + ) + + tests := []struct { + name string + countExpr evalengine.Expr + bindVars map[string]*querypb.BindVariable + want []*sqltypes.Result + RequireCompleteInput bool + }{{ + name: "limit smaller than input (literal)", + countExpr: evalengine.NewLiteralInt(2), + want: sqltypes.MakeTestStreamingResults( + fields, + "a|1", + "b|2", + ), + }, { + name: "limit smaller than input (literal) - require complete input", + countExpr: evalengine.NewLiteralInt(2), + RequireCompleteInput: true, + want: sqltypes.MakeTestStreamingResults( + fields, + "a|1", + "b|2", + "---", // this extra result is required by RequireCompleteInput + ), + }, { + name: "limit smaller than input (bind var)", + countExpr: evalengine.NewBindVar("l", evalengine.NewType(sqltypes.Int64, collations.CollationBinaryID)), + bindVars: map[string]*querypb.BindVariable{"l": sqltypes.Int64BindVariable(2)}, + want: sqltypes.MakeTestStreamingResults( + fields, + "a|1", + "b|2", + ), + }, { + name: "limit equal to input", + countExpr: evalengine.NewLiteralInt(3), + want: sqltypes.MakeTestStreamingResults( + fields, + "a|1", + "b|2", + "---", + "c|3", + ), + }, { + name: "limit higher than input", + countExpr: evalengine.NewLiteralInt(4), + // same as limit=3 + want: sqltypes.MakeTestStreamingResults( + fields, + "a|1", + "b|2", + "---", + "c|3", + ), + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fp := &fakePrimitive{ + results: []*sqltypes.Result{inputResult}, + } + + l := &Limit{ + Count: tt.countExpr, + RequireCompleteInput: tt.RequireCompleteInput, + Input: fp, + } + + var results []*sqltypes.Result + err := l.TryStreamExecute(context.Background(), &noopVCursor{}, tt.bindVars, true, func(qr *sqltypes.Result) error { + results = append(results, qr) + return nil + }) + require.NoError(t, err) + require.Len(t, results, len(tt.want)) + for i, result := range results { + if !result.Equal(tt.want[i]) { + t.Errorf("l.StreamExecute:\n%s, want\n%s", sqltypes.PrintResults(results), sqltypes.PrintResults(tt.want)) + } + } + }) + } +} func TestLimitStreamExecute(t *testing.T) { bindVars := make(map[string]*querypb.BindVariable)