diff --git a/go/vt/sqlparser/comments.go b/go/vt/sqlparser/comments.go index 780f1e67594..dff6f60e531 100644 --- a/go/vt/sqlparser/comments.go +++ b/go/vt/sqlparser/comments.go @@ -554,27 +554,6 @@ func AllowScatterDirective(stmt Statement) bool { return checkDirective(stmt, DirectiveAllowScatter) } -// ForeignKeyChecksState returns the state of foreign_key_checks variable if it is part of a SET_VAR optimizer hint in the comments. -func ForeignKeyChecksState(stmt Statement) *bool { - cmt, ok := stmt.(Commented) - if ok { - fkChecksVal := cmt.GetParsedComments().GetMySQLSetVarValue(sysvars.ForeignKeyChecks) - // If the value of the `foreign_key_checks` optimizer hint is something that doesn't make sense, - // then MySQL just ignores it and treats it like the case, where it is unspecified. We are choosing - // to have the same behaviour here. If the value doesn't match any of the acceptable values, we return nil, - // that signifies that no value was specified. - switch strings.ToLower(fkChecksVal) { - case "on", "1": - fkState := true - return &fkState - case "off", "0": - fkState := false - return &fkState - } - } - return nil -} - func checkDirective(stmt Statement, key string) bool { cmt, ok := stmt.(Commented) if ok { @@ -583,42 +562,43 @@ func checkDirective(stmt Statement, key string) bool { return false } -// GetPriorityFromStatement gets the priority from the provided Statement, using DirectivePriority -func GetPriorityFromStatement(statement Statement) (string, error) { - commentedStatement, ok := statement.(Commented) - // This would mean that the statement lacks comments, so we can't obtain the workload from it. Hence default to - // empty priority +type QueryHints struct { + IgnoreMaxMemoryRows bool + Consolidator querypb.ExecuteOptions_Consolidator + Workload string + ForeignKeyChecks *bool + Priority string + Timeout *int +} + +func BuildQueryHints(stmt Statement) (qh QueryHints, err error) { + qh = QueryHints{} + + comment, ok := stmt.(Commented) if !ok { - return "", nil + return qh, nil } - directives := commentedStatement.GetParsedComments().Directives() - priority, ok := directives.GetString(DirectivePriority, "") - if !ok || priority == "" { - return "", nil - } + directives := comment.GetParsedComments().Directives() - intPriority, err := strconv.Atoi(priority) - if err != nil || intPriority < 0 || intPriority > MaxPriorityValue { - return "", ErrInvalidPriority + qh.Priority, err = getPriority(directives) + if err != nil { + return qh, err } + qh.IgnoreMaxMemoryRows = directives.IsSet(DirectiveIgnoreMaxMemoryRows) + qh.Consolidator = getConsolidator(stmt, directives) + qh.Workload = getWorkload(directives) + qh.ForeignKeyChecks = getForeignKeyChecksState(comment) + qh.Timeout = getQueryTimeout(directives) - return priority, nil + return qh, nil } -// Consolidator returns the consolidator option. -func Consolidator(stmt Statement) querypb.ExecuteOptions_Consolidator { - var comments *ParsedComments - switch stmt := stmt.(type) { - case *Select: - comments = stmt.Comments - default: - return querypb.ExecuteOptions_CONSOLIDATOR_UNSPECIFIED - } - if comments == nil { +// getConsolidator returns the consolidator option. +func getConsolidator(stmt Statement, directives *CommentDirectives) querypb.ExecuteOptions_Consolidator { + if _, isSelect := stmt.(SelectStatement); !isSelect { return querypb.ExecuteOptions_CONSOLIDATOR_UNSPECIFIED } - directives := comments.Directives() strv, isSet := directives.GetString(DirectiveConsolidator, "") if !isSet { return querypb.ExecuteOptions_CONSOLIDATOR_UNSPECIFIED @@ -629,18 +609,56 @@ func Consolidator(stmt Statement) querypb.ExecuteOptions_Consolidator { return querypb.ExecuteOptions_CONSOLIDATOR_UNSPECIFIED } -// GetWorkloadNameFromStatement gets the workload name from the provided Statement, using workloadLabel as the name of +// getWorkload gets the workload name from the provided Statement, using workloadLabel as the name of // the query directive that specifies it. -func GetWorkloadNameFromStatement(statement Statement) string { - commentedStatement, ok := statement.(Commented) - // This would mean that the statement lacks comments, so we can't obtain the workload from it. Hence default to - // empty workload name - if !ok { - return "" +func getWorkload(directives *CommentDirectives) string { + workloadName, _ := directives.GetString(DirectiveWorkloadName, "") + return workloadName +} + +// getForeignKeyChecksState returns the state of foreign_key_checks variable if it is part of a SET_VAR optimizer hint in the comments. +func getForeignKeyChecksState(cmt Commented) *bool { + fkChecksVal := cmt.GetParsedComments().GetMySQLSetVarValue(sysvars.ForeignKeyChecks) + // If the value of the `foreign_key_checks` optimizer hint is something that doesn't make sense, + // then MySQL just ignores it and treats it like the case, where it is unspecified. We are choosing + // to have the same behaviour here. If the value doesn't match any of the acceptable values, we return nil, + // that signifies that no value was specified. + switch strings.ToLower(fkChecksVal) { + case "on", "1": + fkState := true + return &fkState + case "off", "0": + fkState := false + return &fkState } + return nil +} - directives := commentedStatement.GetParsedComments().Directives() - workloadName, _ := directives.GetString(DirectiveWorkloadName, "") +// getPriority gets the priority from the provided Statement, using DirectivePriority +func getPriority(directives *CommentDirectives) (string, error) { + priority, ok := directives.GetString(DirectivePriority, "") + if !ok || priority == "" { + return "", nil + } - return workloadName + intPriority, err := strconv.Atoi(priority) + if err != nil || intPriority < 0 || intPriority > MaxPriorityValue { + return "", ErrInvalidPriority + } + + return priority, nil +} + +// getQueryTimeout gets the query timeout from the provided Statement, using DirectiveQueryTimeout +func getQueryTimeout(directives *CommentDirectives) *int { + timeoutString, ok := directives.GetString(DirectiveQueryTimeout, "") + if !ok || timeoutString == "" { + return nil + } + + timeout, err := strconv.Atoi(timeoutString) + if err != nil || timeout < 0 { + return nil + } + return &timeout } diff --git a/go/vt/sqlparser/comments_test.go b/go/vt/sqlparser/comments_test.go index 42d02e35652..7f36645901c 100644 --- a/go/vt/sqlparser/comments_test.go +++ b/go/vt/sqlparser/comments_test.go @@ -474,8 +474,10 @@ func TestConsolidator(t *testing.T) { for _, test := range testCases { t.Run(test.query, func(t *testing.T) { stmt, _ := parser.Parse(test.query) - got := Consolidator(stmt) - assert.Equalf(t, test.expected, got, fmt.Sprintf("Consolidator(stmt) returned %v but expected %v", got, test.expected)) + qh, err := BuildQueryHints(stmt) + require.NoError(t, err) + assert.Equalf(t, test.expected, qh.Consolidator, + "Consolidator(stmt) returned %v but expected %v", qh.Consolidator, test.expected) }) } } @@ -534,12 +536,12 @@ func TestGetPriorityFromStatement(t *testing.T) { t.Parallel() stmt, err := parser.Parse(testCase.query) assert.NoError(t, err) - actualPriority, actualError := GetPriorityFromStatement(stmt) + qh, err := BuildQueryHints(stmt) if testCase.expectedError != nil { - assert.ErrorIs(t, actualError, testCase.expectedError) + assert.ErrorIs(t, err, testCase.expectedError) } else { assert.NoError(t, err) - assert.Equal(t, testCase.expectedPriority, actualPriority) + assert.Equal(t, testCase.expectedPriority, qh.Priority) } }) } @@ -661,3 +663,38 @@ func TestSetMySQLSetVarValue(t *testing.T) { }) } } + +// TestQueryTimeout tests the extraction of Query_Timeout_MS from the comments. +func TestQueryTimeout(t *testing.T) { + testCases := []struct { + query string + expTimeout int + noTimeout bool + }{{ + query: "select * from a_table", + noTimeout: true, + }, { + query: "select /*vt+ QUERY_TIMEOUT_MS=21 */ * from another_table", + expTimeout: 21, + }, { + query: "select /*vt+ QUERY_TIMEOUT_MS=0 */ * from another_table", + expTimeout: 0, + }, { + query: "select /*vt+ PRIORITY=-42 */ * from another_table", + noTimeout: true, + }} + + parser := NewTestParser() + for _, tc := range testCases { + t.Run(tc.query, func(t *testing.T) { + stmt, err := parser.Parse(tc.query) + assert.NoError(t, err) + qh, _ := BuildQueryHints(stmt) + if tc.noTimeout { + assert.Nil(t, qh.Timeout) + } else { + assert.Equal(t, tc.expTimeout, *qh.Timeout) + } + }) + } +} diff --git a/go/vt/vtgate/engine/fake_vcursor_test.go b/go/vt/vtgate/engine/fake_vcursor_test.go index 498c26db877..4c316201a4c 100644 --- a/go/vt/vtgate/engine/fake_vcursor_test.go +++ b/go/vt/vtgate/engine/fake_vcursor_test.go @@ -61,6 +61,10 @@ type noopVCursor struct { inTx bool } +func (t *noopVCursor) SetExecQueryTimeout(timeout *int) { + panic("implement me") +} + // MySQLVersion implements VCursor. func (t *noopVCursor) Commit(ctx context.Context) error { return nil diff --git a/go/vt/vtgate/engine/primitive.go b/go/vt/vtgate/engine/primitive.go index 30894b99ab8..38e7dcdac68 100644 --- a/go/vt/vtgate/engine/primitive.go +++ b/go/vt/vtgate/engine/primitive.go @@ -173,6 +173,7 @@ type ( SetConsolidator(querypb.ExecuteOptions_Consolidator) SetWorkloadName(string) SetPriority(string) + SetExecQueryTimeout(timeout *int) SetFoundRows(uint64) SetDDLStrategy(string) diff --git a/go/vt/vtgate/engine/route.go b/go/vt/vtgate/engine/route.go index f28dda01a52..64de598c762 100644 --- a/go/vt/vtgate/engine/route.go +++ b/go/vt/vtgate/engine/route.go @@ -142,7 +142,7 @@ func (route *Route) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma // addQueryTimeout adds a query timeout to the context it receives and returns the modified context along with the cancel function. func addQueryTimeout(ctx context.Context, vcursor VCursor, queryTimeout int) (context.Context, context.CancelFunc) { timeout := vcursor.Session().GetQueryTimeout(queryTimeout) - if timeout != 0 { + if timeout > 0 { return context.WithTimeout(ctx, time.Duration(timeout)*time.Millisecond) } return ctx, func() {} diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index 08e1bf09ab7..5dc388f1a3d 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -1105,15 +1105,16 @@ func (e *Executor) getPlan( return nil, vterrors.VT13001("vschema not initialized") } - vcursor.SetIgnoreMaxMemoryRows(sqlparser.IgnoreMaxMaxMemoryRowsDirective(stmt)) - vcursor.SetConsolidator(sqlparser.Consolidator(stmt)) - vcursor.SetWorkloadName(sqlparser.GetWorkloadNameFromStatement(stmt)) - vcursor.UpdateForeignKeyChecksState(sqlparser.ForeignKeyChecksState(stmt)) - priority, err := sqlparser.GetPriorityFromStatement(stmt) + qh, err := sqlparser.BuildQueryHints(stmt) if err != nil { return nil, err } - vcursor.SetPriority(priority) + vcursor.SetIgnoreMaxMemoryRows(qh.IgnoreMaxMemoryRows) + vcursor.SetConsolidator(qh.Consolidator) + vcursor.SetWorkloadName(qh.Workload) + vcursor.UpdateForeignKeyChecksState(qh.ForeignKeyChecks) + vcursor.SetPriority(qh.Priority) + vcursor.SetExecQueryTimeout(qh.Timeout) setVarComment, err := prepareSetVarComment(vcursor, stmt) if err != nil { diff --git a/go/vt/vtgate/executor_test.go b/go/vt/vtgate/executor_test.go index b8e2b996780..5cbc9c6d711 100644 --- a/go/vt/vtgate/executor_test.go +++ b/go/vt/vtgate/executor_test.go @@ -1787,7 +1787,8 @@ func TestGetPlanPriority(t *testing.T) { stmt, err := sqlparser.NewTestParser().Parse(testCase.sql) assert.NoError(t, err) - crticalityFromStatement, _ := sqlparser.GetPriorityFromStatement(stmt) + qh, _ := sqlparser.BuildQueryHints(stmt) + crticalityFromStatement := qh.Priority _, err = r.getPlan(context.Background(), vCursor, testCase.sql, stmt, makeComments("/* some comment */"), map[string]*querypb.BindVariable{}, nil, true, logStats) if testCase.expectedError != nil { diff --git a/go/vt/vtgate/planbuilder/builder.go b/go/vt/vtgate/planbuilder/builder.go index 5d1d4ecd622..27b994b1730 100644 --- a/go/vt/vtgate/planbuilder/builder.go +++ b/go/vt/vtgate/planbuilder/builder.go @@ -73,11 +73,14 @@ func TestBuilder(query string, vschema plancontext.VSchema, keyspace string) (*e // Store the foreign key mode like we do for vcursor. vw, isVw := vschema.(*vschemawrapper.VSchemaWrapper) if isVw { - fkState := sqlparser.ForeignKeyChecksState(stmt) - if fkState != nil { + qh, err := sqlparser.BuildQueryHints(stmt) + if err != nil { + return nil, err + } + if qh.ForeignKeyChecks != nil { // Restore the old volue of ForeignKeyChecksState to not interfere with the next test cases. oldVal := vw.ForeignKeyChecksState - vw.ForeignKeyChecksState = fkState + vw.ForeignKeyChecksState = qh.ForeignKeyChecks defer func() { vw.ForeignKeyChecksState = oldVal }() diff --git a/go/vt/vtgate/vcursor_impl.go b/go/vt/vtgate/vcursor_impl.go index a71ad29184a..79d9fe341ba 100644 --- a/go/vt/vtgate/vcursor_impl.go +++ b/go/vt/vtgate/vcursor_impl.go @@ -929,6 +929,19 @@ func (vc *vcursorImpl) SetPriority(priority string) { } } +func (vc *vcursorImpl) SetExecQueryTimeout(timeout *int) { + if timeout == nil { + if vc.safeSession.GetOptions() == nil { + return + } + vc.safeSession.GetOrCreateOptions().Timeout = nil + return + } + vc.safeSession.GetOrCreateOptions().Timeout = &querypb.ExecuteOptions_AuthoritativeTimeout{ + AuthoritativeTimeout: int64(*timeout), + } +} + // SetConsolidator implements the SessionActions interface func (vc *vcursorImpl) SetConsolidator(consolidator querypb.ExecuteOptions_Consolidator) { // Avoid creating session Options when they do not yet exist and the