diff --git a/go/vt/sqlparser/analyzer_test.go b/go/vt/sqlparser/analyzer_test.go index 0a2de52ef19..4d22a02eb4f 100644 --- a/go/vt/sqlparser/analyzer_test.go +++ b/go/vt/sqlparser/analyzer_test.go @@ -23,9 +23,9 @@ import ( ) func TestPreview(t *testing.T) { - testcases := []struct { - sql string - want StatementType + tt := []struct { + sql string + expected StatementType }{ {"select ...", StmtSelect}, {" select ...", StmtSelect}, @@ -73,7 +73,6 @@ func TestPreview(t *testing.T) { {"truncate", StmtDDL}, {"flush", StmtFlush}, {"unknown", StmtUnknown}, - {"/* leading comment */ select ...", StmtSelect}, {"/* leading comment */ (select ...", StmtSelect}, {"/* leading comment */ /* leading comment 2 */ select ...", StmtSelect}, @@ -81,20 +80,30 @@ func TestPreview(t *testing.T) { {"/*!50708 MySQL-version comment */", StmtComment}, {"-- leading single line comment \n select ...", StmtSelect}, {"-- leading single line comment \n -- leading single line comment 2\n select ...", StmtSelect}, - {"/* leading comment no end select ...", StmtUnknown}, {"-- leading single line comment no end select ...", StmtUnknown}, {"/*!40000 ALTER TABLE `t1` DISABLE KEYS */", StmtComment}, + {"release", StmtRelease}, + {"rollback", StmtRollback}, + {"rollback ....", StmtSRollback}, + {"kill", StmtKill}, + {"savepoint", StmtSavepoint}, + {"lock", StmtLockTables}, + {"unlock", StmtUnlockTables}, + {"stream", StmtStream}, + {"vstream", StmtVStream}, + {"revert", StmtRevert}, } - for _, tcase := range testcases { - if got := Preview(tcase.sql); got != tcase.want { - t.Errorf("Preview(%s): %v, want %v", tcase.sql, got, tcase.want) - } + for _, tc := range tt { + t.Run(tc.sql, func(t *testing.T) { + out := Preview(tc.sql) + assert.Equal(t, tc.expected, out) + }) } } func TestIsDML(t *testing.T) { - testcases := []struct { + tt := []struct { sql string want bool }{ @@ -109,10 +118,33 @@ func TestIsDML(t *testing.T) { {"", false}, {" ", false}, } - for _, tcase := range testcases { - if got := IsDML(tcase.sql); got != tcase.want { - t.Errorf("IsDML(%s): %v, want %v", tcase.sql, got, tcase.want) - } + for _, tc := range tt { + t.Run(tc.sql, func(t *testing.T) { + got := IsDML(tc.sql) + assert.Equal(t, tc.want, got) + }) + } +} + +func TestIsDMLStatement(t *testing.T) { + dmlStatements := []Statement{ + &Insert{}, + &Update{}, + &Delete{}, + } + + nonDmlStatements := []Statement{ + &Select{}, + &Set{}, + &Show{}, + } + + for _, stmt := range dmlStatements { + assert.True(t, IsDMLStatement(stmt), "Expected true, got false for %v", stmt) + } + + for _, stmt := range nonDmlStatements { + assert.False(t, IsDMLStatement(stmt), "Expected false, got true for %v", stmt) } } @@ -338,3 +370,119 @@ func TestIsNull(t *testing.T) { } } } + +func TestStatementTypeString(t *testing.T) { + testcases := []struct { + stmtType StatementType + expected string + }{ + {StmtSelect, "SELECT"}, + {StmtStream, "STREAM"}, + {StmtVStream, "VSTREAM"}, + {StmtRevert, "REVERT"}, + {StmtInsert, "INSERT"}, + {StmtReplace, "REPLACE"}, + {StmtUpdate, "UPDATE"}, + {StmtDelete, "DELETE"}, + {StmtDDL, "DDL"}, + {StmtBegin, "BEGIN"}, + {StmtCommit, "COMMIT"}, + {StmtRollback, "ROLLBACK"}, + {StmtSet, "SET"}, + {StmtShow, "SHOW"}, + {StmtUse, "USE"}, + {StmtOther, "OTHER"}, + {StmtAnalyze, "ANALYZE"}, + {StmtPriv, "PRIV"}, + {StmtExplain, "EXPLAIN"}, + {StmtSavepoint, "SAVEPOINT"}, + {StmtSRollback, "SAVEPOINT_ROLLBACK"}, + {StmtRelease, "RELEASE"}, + {StmtLockTables, "LOCK_TABLES"}, + {StmtUnlockTables, "UNLOCK_TABLES"}, + {StmtFlush, "FLUSH"}, + {StmtCallProc, "CALL_PROC"}, + {StmtCommentOnly, "COMMENT_ONLY"}, + {StmtPrepare, "PREPARE"}, + {StmtExecute, "EXECUTE"}, + {StmtDeallocate, "DEALLOCATE PREPARE"}, + {StmtKill, "KILL"}, + {StmtUnknown, "UNKNOWN"}, + } + + for _, tc := range testcases { + t.Run(tc.expected, func(t *testing.T) { + result := tc.stmtType.String() + assert.Equal(t, tc.expected, result) + }) + } +} + +func TestASTToStatementType(t *testing.T) { + testcases := []struct { + name string + ast Statement + want StatementType + }{ + {"Select", &Select{}, StmtSelect}, + {"Insert", &Insert{}, StmtInsert}, + {"Update", &Update{}, StmtUpdate}, + {"Delete", &Delete{}, StmtDelete}, + {"Set", &Set{}, StmtSet}, + {"Show", &Show{}, StmtShow}, + {"AlterVschema", &AlterVschema{}, StmtDDL}, + {"RevertMigration", &RevertMigration{}, StmtRevert}, + {"ShowMigrationLogs", &ShowMigrationLogs{}, StmtShowMigrationLogs}, + {"Use", &Use{}, StmtUse}, + {"OtherAdmin", &OtherAdmin{}, StmtOther}, + {"Analyze", &Analyze{}, StmtAnalyze}, + {"VExplainStmt", &VExplainStmt{}, StmtExplain}, + {"Begin", &Begin{}, StmtBegin}, + {"Commit", &Commit{}, StmtCommit}, + {"Rollback", &Rollback{}, StmtRollback}, + {"Savepoint", &Savepoint{}, StmtSavepoint}, + {"SRollback", &SRollback{}, StmtSRollback}, + {"Release", &Release{}, StmtRelease}, + {"LockTables", &LockTables{}, StmtLockTables}, + {"UnlockTables", &UnlockTables{}, StmtUnlockTables}, + {"Flush", &Flush{}, StmtFlush}, + {"CallProc", &CallProc{}, StmtCallProc}, + {"Stream", &Stream{}, StmtStream}, + {"VStream", &VStream{}, StmtVStream}, + {"CommentOnly", &CommentOnly{}, StmtCommentOnly}, + {"PrepareStmt", &PrepareStmt{}, StmtPrepare}, + {"ExecuteStmt", &ExecuteStmt{}, StmtExecute}, + {"DeallocateStmt", &DeallocateStmt{}, StmtDeallocate}, + {"Kill", &Kill{}, StmtKill}, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + result := ASTToStatementType(tc.ast) + assert.Equal(t, tc.want, result) + }) + } +} + +func TestMustRewriteAST(t *testing.T) { + testCases := []struct { + name string + stmt Statement + hasSelectLimit bool + }{ + {"Set", &Set{}, true}, + {"Show with ShowBasic", &Show{Internal: &ShowBasic{}}, true}, + {"Show", &Show{}, false}, + {"Select", &Select{}, true}, + {"Insert", &Insert{}, false}, + {"Update", &Update{}, false}, + {"Delete", &Delete{}, false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := MustRewriteAST(tc.stmt, tc.hasSelectLimit) + assert.Equal(t, tc.hasSelectLimit, result, "Test case: %s", tc.stmt) + }) + } +}