From c4585d7217001170a96eb2d03a3144d8ce309a8e Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Tue, 20 Aug 2024 18:04:03 +0530 Subject: [PATCH] feat: fix engine and add tests Signed-off-by: Manan Gupta --- go/vt/vtgate/engine/fake_primitive_test.go | 10 +++ go/vt/vtgate/engine/fake_vcursor_test.go | 8 +- go/vt/vtgate/engine/timeout_handler.go | 4 +- go/vt/vtgate/engine/timeout_handler_test.go | 85 +++++++++++++++++++++ 4 files changed, 103 insertions(+), 4 deletions(-) create mode 100644 go/vt/vtgate/engine/timeout_handler_test.go diff --git a/go/vt/vtgate/engine/fake_primitive_test.go b/go/vt/vtgate/engine/fake_primitive_test.go index 6ab54fe9e7b..20585be7c88 100644 --- a/go/vt/vtgate/engine/fake_primitive_test.go +++ b/go/vt/vtgate/engine/fake_primitive_test.go @@ -22,6 +22,7 @@ import ( "reflect" "strings" "testing" + "time" "golang.org/x/sync/errgroup" @@ -41,6 +42,9 @@ type fakePrimitive struct { log []string + // sleepTime is the time for which the fake primitive sleeps before returning the results. + sleepTime time.Duration + allResultsInOneCall bool async bool @@ -71,6 +75,9 @@ func (f *fakePrimitive) GetTableName() string { func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { f.log = append(f.log, fmt.Sprintf("Execute %v %v", printBindVars(bindVars), wantfields)) + if f.sleepTime != 0 { + time.Sleep(f.sleepTime) + } if f.results == nil { return nil, f.sendErr } @@ -85,6 +92,9 @@ func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVar func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields)) + if f.sleepTime != 0 { + time.Sleep(f.sleepTime) + } if f.results == nil { return f.sendErr } diff --git a/go/vt/vtgate/engine/fake_vcursor_test.go b/go/vt/vtgate/engine/fake_vcursor_test.go index 5458a384490..3d0bc9e35f1 100644 --- a/go/vt/vtgate/engine/fake_vcursor_test.go +++ b/go/vt/vtgate/engine/fake_vcursor_test.go @@ -54,7 +54,8 @@ var _ SessionActions = (*noopVCursor)(nil) // noopVCursor is used to build other vcursors. type noopVCursor struct { - inTx bool + inTx bool + queryTimeout int } // MySQLVersion implements VCursor. @@ -298,7 +299,10 @@ func (t *noopVCursor) SetQueryTimeout(maxExecutionTime int64) { } func (t *noopVCursor) GetQueryTimeout(queryTimeoutFromComments int) int { - return queryTimeoutFromComments + if queryTimeoutFromComments != 0 { + return queryTimeoutFromComments + } + return t.queryTimeout } func (t *noopVCursor) SetSkipQueryPlanCache(context.Context, bool) error { diff --git a/go/vt/vtgate/engine/timeout_handler.go b/go/vt/vtgate/engine/timeout_handler.go index 1fc919b8475..e62525c3b5e 100644 --- a/go/vt/vtgate/engine/timeout_handler.go +++ b/go/vt/vtgate/engine/timeout_handler.go @@ -54,7 +54,7 @@ func (t *TimeoutHandler) TryExecute(ctx context.Context, vcursor VCursor, bindVa ctx, cancel := addQueryTimeout(ctx, vcursor, t.Timeout) defer cancel() - var complete chan any + complete := make(chan any) go func() { res, err = t.Input.TryExecute(ctx, vcursor, bindVars, wantfields) close(complete) @@ -73,7 +73,7 @@ func (t *TimeoutHandler) TryStreamExecute(ctx context.Context, vcursor VCursor, ctx, cancel := addQueryTimeout(ctx, vcursor, t.Timeout) defer cancel() - var complete chan any + complete := make(chan any) go func() { err = t.Input.TryStreamExecute(ctx, vcursor, bindVars, wantfields, callback) close(complete) diff --git a/go/vt/vtgate/engine/timeout_handler_test.go b/go/vt/vtgate/engine/timeout_handler_test.go new file mode 100644 index 00000000000..9fcf562d372 --- /dev/null +++ b/go/vt/vtgate/engine/timeout_handler_test.go @@ -0,0 +1,85 @@ +package engine + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/sqltypes" +) + +// TestTimeoutHandler tests timeout handler primitive. +func TestTimeoutHandler(t *testing.T) { + tests := []struct { + name string + input *TimeoutHandler + vc VCursor + wantErr string + }{ + { + name: "No timeout", + input: NewTimeoutHandler(&fakePrimitive{ + results: nil, + sleepTime: 100 * time.Millisecond, + }, 0), + vc: &noopVCursor{}, + wantErr: "", + }, { + name: "Timeout without failure", + input: NewTimeoutHandler(&fakePrimitive{ + results: nil, + sleepTime: 100 * time.Millisecond, + }, 1000), + vc: &noopVCursor{}, + wantErr: "", + }, { + name: "Timeout in session", + input: NewTimeoutHandler(&fakePrimitive{ + results: nil, + sleepTime: 2 * time.Second, + }, 0), + vc: &noopVCursor{ + queryTimeout: 100, + }, + wantErr: "VT15001: Query execution was interrupted, maximum statement execution time exceeded", + }, { + name: "Timeout in comments", + input: NewTimeoutHandler(&fakePrimitive{ + results: nil, + sleepTime: 2 * time.Second, + }, 100), + vc: &noopVCursor{}, + wantErr: "VT15001: Query execution was interrupted, maximum statement execution time exceeded", + }, { + name: "Timeout in both", + input: NewTimeoutHandler(&fakePrimitive{ + results: nil, + sleepTime: 2 * time.Second, + }, 100), + vc: &noopVCursor{ + queryTimeout: 4000, + }, + wantErr: "VT15001: Query execution was interrupted, maximum statement execution time exceeded", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := tt.input.TryExecute(context.Background(), tt.vc, nil, false) + if tt.wantErr != "" { + require.EqualError(t, err, tt.wantErr) + } else { + require.NoError(t, err) + } + err = tt.input.TryStreamExecute(context.Background(), tt.vc, nil, false, func(result *sqltypes.Result) error { + return nil + }) + if tt.wantErr != "" { + require.EqualError(t, err, tt.wantErr) + } else { + require.NoError(t, err) + } + }) + } +}