Skip to content

Commit

Permalink
Protect ExecuteFetchAsDBA from multi-statements, excluding CREATE TAB…
Browse files Browse the repository at this point in the history
…LE and CREATE VIEW sequence. Fix allowZeroInDate

Signed-off-by: Shlomi Noach <[email protected]>
  • Loading branch information
shlomi-noach committed Jan 15, 2024
1 parent 18fe384 commit e50516f
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 47 deletions.
48 changes: 45 additions & 3 deletions go/test/endtoend/clustertest/vtctld_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,51 @@ func testTabletStatus(t *testing.T) {
}

func testExecuteAsDba(t *testing.T) {
result, err := clusterInstance.VtctlclientProcess.ExecuteCommandWithOutput("ExecuteFetchAsDba", clusterInstance.Keyspaces[0].Shards[0].Vttablets[0].Alias, `SELECT 1 AS a`)
require.NoError(t, err)
assert.Equal(t, result, oneTableOutput)
tcases := []struct {
query string
result string
expectErr bool
}{
{
query: "",
expectErr: true,
},
{
query: "SELECT 1 AS a",
result: oneTableOutput,
},
{
query: "SELECT 1 AS a; SELECT 1 AS a",
expectErr: true,
},
{
query: "create table t(id int)",
result: "",
},
{
query: "create table if not exists t(id int)",
result: "",
},
{
query: "create table if not exists t(id int); create table if not exists t(id int);",
result: "",
},
{
query: "create table if not exists t(id int); create table if not exists t(id int); SELECT 1 AS a",
expectErr: true,
},
}
for _, tcase := range tcases {
t.Run(tcase.query, func(t *testing.T) {
result, err := clusterInstance.VtctlclientProcess.ExecuteCommandWithOutput("ExecuteFetchAsDba", clusterInstance.Keyspaces[0].Shards[0].Vttablets[0].Alias, tcase.query)
if tcase.expectErr {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tcase.result, result)
}
})
}
}

func testExecuteAsApp(t *testing.T) {
Expand Down
67 changes: 34 additions & 33 deletions go/vt/vttablet/tabletmanager/rpc_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ package tabletmanager

import (
"context"
"errors"
"io"

"vitess.io/vitess/go/constants/sidecar"
"vitess.io/vitess/go/sqlescape"
Expand All @@ -33,26 +31,32 @@ import (
"vitess.io/vitess/go/vt/proto/vtrpc"
)

// queriesHaveAllowZeroInDateDirective reutrns 'true' when at least one of the queries
// analyzeExecuteFetchAsDbaMultiQuery reutrns 'true' when at least one of the queries
// in the given SQL has a `/*vt+ allowZeroInDate=true */` directive.
func queriesHaveAllowZeroInDateDirective(sql string, parser *sqlparser.Parser) bool {
tokenizer := parser.NewStringTokenizer(sql)
for {
stmt, err := sqlparser.ParseNext(tokenizer)
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return false
func analyzeExecuteFetchAsDbaMultiQuery(sql string, parser *sqlparser.Parser) (statements []sqlparser.Statement, allCreateTableViewQueries bool, allowZeroInDate bool, err error) {
statements, err = parser.SplitStatements(sql)
if err != nil {
return nil, false, false, err
}
if len(statements) == 0 {
return nil, false, false, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "no statements found in query: %s", sql)
}
allCreateTableViewQueries = true
for _, stmt := range statements {
switch stmt.(type) {
case *sqlparser.CreateTable, *sqlparser.CreateView:
default:
allCreateTableViewQueries = false
}

if cmnt, ok := stmt.(sqlparser.Commented); ok {
directives := cmnt.GetParsedComments().Directives()
if directives.IsSet("allowZeroInDate") {
return true
allowZeroInDate = true
}
}
}
return false
return statements, allCreateTableViewQueries, allowZeroInDate, nil
}

// ExecuteFetchAsDba will execute the given query, possibly disabling binlogs and reload schema.
Expand Down Expand Up @@ -81,25 +85,17 @@ func (tm *TabletManager) ExecuteFetchAsDba(ctx context.Context, req *tabletmanag
_, _ = conn.ExecuteFetch("USE "+sqlescape.EscapeID(req.DbName), 1, false)
}

allowZeroInDate := false
tokenizer := tm.SQLParser.NewStringTokenizer(string(req.Query))
for {
stmt, err := sqlparser.ParseNext(tokenizer)
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "could not parse statement in ExecuteFetchAsDba: %v: %v", string(req.Query), err)
}
if cmnt, ok := stmt.(sqlparser.Commented); ok {
directives := cmnt.GetParsedComments().Directives()
if directives.IsSet("allowZeroInDate") {
// --allow-zero-in-date Applies to DDLs. As a backport solution to
// https://github.com/vitessio/vitess/issues/14952, it is enough that
// one of the DDLs has the `allowZeroInDate` directive, that we allow
// zero in date for all queries.
allowZeroInDate = true
}
statements, allCreateTableViewQueries, allowZeroInDate, err := analyzeExecuteFetchAsDbaMultiQuery(string(req.Query), tm.SQLParser)
if err != nil {
return nil, err
}
if len(statements) > 1 {
// Up to v19, we allow multi-statement SQL in ExecuteFetchAsDba, but only for the specific case
// where all statements are CREATE TABLE or CREATE VIEW. This is to support `ApplySchema --batch-size`.
// In v20, we will not support multi statements whatsoever.
// v20 will throw an error by virtua of using ExecuteFetch instead of ExecuteFetchMulti.
if !allCreateTableViewQueries {
return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "multi statement queries are not supported in ExecuteFetchAsDba unless all are CREATE TABLE or CREATE VIEW")
}
}
if allowZeroInDate {
Expand All @@ -108,10 +104,15 @@ func (tm *TabletManager) ExecuteFetchAsDba(ctx context.Context, req *tabletmanag
}
}
// Replace any provided sidecar database qualifiers with the correct one.
// TODO(shlomi): we use ReplaceTableQualifiersMultiQuery for backwards compatibility. In v20 we will not accept
// multi statement queries in ExecuteFetchAsDBA. This will be rewritten as ReplaceTableQualifiers()
uq, err := tm.SQLParser.ReplaceTableQualifiersMultiQuery(string(req.Query), sidecar.DefaultName, sidecar.GetName())
if err != nil {
return nil, err
}
// TODO(shlomi): we use ExecuteFetchMulti for backwards compatibility. In v20 we will not accept
// multi statement queries in ExecuteFetchAsDBA. This will be rewritten as:
// (in v20): result, err := ExecuteFetch(uq, int(req.MaxRows), true /*wantFields*/)
result, more, err := conn.ExecuteFetchMulti(uq, int(req.MaxRows), true /*wantFields*/)
for more {
_, more, _, err = conn.ReadQueryResult(0, false)
Expand Down
59 changes: 48 additions & 11 deletions go/vt/vttablet/tabletmanager/rpc_query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,29 +35,66 @@ import (
tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata"
)

func TestQueriesHaveAllowZeroInDateDirective(t *testing.T) {
func TestAnalyzeExecuteFetchAsDbaMultiQuery(t *testing.T) {
tcases := []struct {
query string
expected bool
query string
stmts int
allowZeroInDate bool
allCreate bool
expectErr bool
}{
{
query: "create table t(id int)",
expected: false,
query: "",
expectErr: true,
},
{
query: "create /*vt+ allowZeroInDate=true */ table t (id int)",
expected: true,
query: "select * from t1 ; select * from t2",
stmts: 2,
},
{
query: "create table a (id int) ; create /*vt+ allowZeroInDate=true */ table b (id int)",
expected: true,
query: "create table t(id int)",
stmts: 1,
allCreate: true,
},
{
query: "create table t(id int); create view v as select 1 from dual",
stmts: 2,
allCreate: true,
},
{
query: "create table t(id int); create view v as select 1 from dual; drop table t3",
stmts: 3,
allCreate: false,
},
{
query: "create /*vt+ allowZeroInDate=true */ table t (id int)",
stmts: 1,
allCreate: true,
allowZeroInDate: true,
},
{
query: "create table a (id int) ; create /*vt+ allowZeroInDate=true */ table b (id int)",
stmts: 2,
allCreate: true,
allowZeroInDate: true,
},
{
query: "create table a (id int) ; --comment ; what",
expectErr: true,
},
}
for _, tcase := range tcases {
t.Run(tcase.query, func(t *testing.T) {
parser := sqlparser.NewTestParser()
got := queriesHaveAllowZeroInDateDirective(tcase.query, parser)
assert.Equal(t, tcase.expected, got)
statements, allCreate, allowZeroInDate, err := analyzeExecuteFetchAsDbaMultiQuery(tcase.query, parser)
if tcase.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tcase.stmts, len(statements))
assert.Equal(t, tcase.allCreate, allCreate)
assert.Equal(t, tcase.allowZeroInDate, allowZeroInDate)
}
})
}
}
Expand Down

0 comments on commit e50516f

Please sign in to comment.