Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tests: add tests for go/vt/sqlparser #15057

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 162 additions & 14 deletions go/vt/sqlparser/analyzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ import (
)

func TestPreview(t *testing.T) {
testcases := []struct {
sql string
want StatementType
tt := []struct {
sql string
expected StatementType
}{
{"select ...", StmtSelect},
{" select ...", StmtSelect},
Expand Down Expand Up @@ -73,28 +73,37 @@ func TestPreview(t *testing.T) {
{"truncate", StmtDDL},
{"flush", StmtFlush},
{"unknown", StmtUnknown},

{"/* leading comment */ select ...", StmtSelect},
{"/* leading comment */ (select ...", StmtSelect},
{"/* leading comment */ /* leading comment 2 */ select ...", StmtSelect},
{"/*! MySQL-specific comment */", StmtComment},
{"/*!50708 MySQL-version comment */", StmtComment},
{"-- leading single line comment \n select ...", StmtSelect},
{"-- leading single line comment \n -- leading single line comment 2\n select ...", StmtSelect},

{"/* leading comment no end select ...", StmtUnknown},
{"-- leading single line comment no end select ...", StmtUnknown},
{"/*!40000 ALTER TABLE `t1` DISABLE KEYS */", StmtComment},
{"release", StmtRelease},
{"rollback", StmtRollback},
{"rollback ....", StmtSRollback},
{"kill", StmtKill},
{"savepoint", StmtSavepoint},
{"lock", StmtLockTables},
{"unlock", StmtUnlockTables},
{"stream", StmtStream},
{"vstream", StmtVStream},
{"revert", StmtRevert},
}
for _, tcase := range testcases {
if got := Preview(tcase.sql); got != tcase.want {
t.Errorf("Preview(%s): %v, want %v", tcase.sql, got, tcase.want)
}
for _, tc := range tt {
t.Run(tc.sql, func(t *testing.T) {
out := Preview(tc.sql)
assert.Equal(t, tc.expected, out)
})
}
}

func TestIsDML(t *testing.T) {
testcases := []struct {
tt := []struct {
sql string
want bool
}{
Expand All @@ -109,10 +118,33 @@ func TestIsDML(t *testing.T) {
{"", false},
{" ", false},
}
for _, tcase := range testcases {
if got := IsDML(tcase.sql); got != tcase.want {
t.Errorf("IsDML(%s): %v, want %v", tcase.sql, got, tcase.want)
}
for _, tc := range tt {
t.Run(tc.sql, func(t *testing.T) {
got := IsDML(tc.sql)
assert.Equal(t, tc.want, got)
})
}
}

func TestIsDMLStatement(t *testing.T) {
dmlStatements := []Statement{
&Insert{},
&Update{},
&Delete{},
}

nonDmlStatements := []Statement{
&Select{},
&Set{},
&Show{},
}

for _, stmt := range dmlStatements {
assert.True(t, IsDMLStatement(stmt), "Expected true, got false for %v", stmt)
}

for _, stmt := range nonDmlStatements {
assert.False(t, IsDMLStatement(stmt), "Expected false, got true for %v", stmt)
}
}

Expand Down Expand Up @@ -338,3 +370,119 @@ func TestIsNull(t *testing.T) {
}
}
}

func TestStatementTypeString(t *testing.T) {
testcases := []struct {
stmtType StatementType
expected string
}{
{StmtSelect, "SELECT"},
{StmtStream, "STREAM"},
{StmtVStream, "VSTREAM"},
{StmtRevert, "REVERT"},
{StmtInsert, "INSERT"},
{StmtReplace, "REPLACE"},
{StmtUpdate, "UPDATE"},
{StmtDelete, "DELETE"},
{StmtDDL, "DDL"},
{StmtBegin, "BEGIN"},
{StmtCommit, "COMMIT"},
{StmtRollback, "ROLLBACK"},
{StmtSet, "SET"},
{StmtShow, "SHOW"},
{StmtUse, "USE"},
{StmtOther, "OTHER"},
{StmtAnalyze, "ANALYZE"},
{StmtPriv, "PRIV"},
{StmtExplain, "EXPLAIN"},
{StmtSavepoint, "SAVEPOINT"},
{StmtSRollback, "SAVEPOINT_ROLLBACK"},
{StmtRelease, "RELEASE"},
{StmtLockTables, "LOCK_TABLES"},
{StmtUnlockTables, "UNLOCK_TABLES"},
{StmtFlush, "FLUSH"},
{StmtCallProc, "CALL_PROC"},
{StmtCommentOnly, "COMMENT_ONLY"},
{StmtPrepare, "PREPARE"},
{StmtExecute, "EXECUTE"},
{StmtDeallocate, "DEALLOCATE PREPARE"},
{StmtKill, "KILL"},
{StmtUnknown, "UNKNOWN"},
}

for _, tc := range testcases {
t.Run(tc.expected, func(t *testing.T) {
result := tc.stmtType.String()
assert.Equal(t, tc.expected, result)
})
}
}

func TestASTToStatementType(t *testing.T) {
testcases := []struct {
name string
ast Statement
want StatementType
}{
{"Select", &Select{}, StmtSelect},
{"Insert", &Insert{}, StmtInsert},
{"Update", &Update{}, StmtUpdate},
{"Delete", &Delete{}, StmtDelete},
{"Set", &Set{}, StmtSet},
{"Show", &Show{}, StmtShow},
{"AlterVschema", &AlterVschema{}, StmtDDL},
{"RevertMigration", &RevertMigration{}, StmtRevert},
{"ShowMigrationLogs", &ShowMigrationLogs{}, StmtShowMigrationLogs},
{"Use", &Use{}, StmtUse},
{"OtherAdmin", &OtherAdmin{}, StmtOther},
{"Analyze", &Analyze{}, StmtAnalyze},
{"VExplainStmt", &VExplainStmt{}, StmtExplain},
{"Begin", &Begin{}, StmtBegin},
{"Commit", &Commit{}, StmtCommit},
{"Rollback", &Rollback{}, StmtRollback},
{"Savepoint", &Savepoint{}, StmtSavepoint},
{"SRollback", &SRollback{}, StmtSRollback},
{"Release", &Release{}, StmtRelease},
{"LockTables", &LockTables{}, StmtLockTables},
{"UnlockTables", &UnlockTables{}, StmtUnlockTables},
{"Flush", &Flush{}, StmtFlush},
{"CallProc", &CallProc{}, StmtCallProc},
{"Stream", &Stream{}, StmtStream},
{"VStream", &VStream{}, StmtVStream},
{"CommentOnly", &CommentOnly{}, StmtCommentOnly},
{"PrepareStmt", &PrepareStmt{}, StmtPrepare},
{"ExecuteStmt", &ExecuteStmt{}, StmtExecute},
{"DeallocateStmt", &DeallocateStmt{}, StmtDeallocate},
{"Kill", &Kill{}, StmtKill},
}

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
result := ASTToStatementType(tc.ast)
assert.Equal(t, tc.want, result)
})
}
}

func TestMustRewriteAST(t *testing.T) {
testCases := []struct {
name string
stmt Statement
hasSelectLimit bool
}{
{"Set", &Set{}, true},
{"Show with ShowBasic", &Show{Internal: &ShowBasic{}}, true},
{"Show", &Show{}, false},
{"Select", &Select{}, true},
{"Insert", &Insert{}, false},
{"Update", &Update{}, false},
{"Delete", &Delete{}, false},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := MustRewriteAST(tc.stmt, tc.hasSelectLimit)
assert.Equal(t, tc.hasSelectLimit, result, "Test case: %s", tc.stmt)
})
}
}
Loading