diff --git a/.golangci.yml b/.golangci.yml index 9acf83cb9be..656e94021f0 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -91,6 +91,9 @@ issues: linters: - errcheck # This code is autogenerated and should be permanently excluded. + - path: '^go/vt/sqlparser/(ast_format|ast_format_fast).go' + linters: + - errcheck - path: '^go/vt/sqlparser/goyacc' linters: - errcheck diff --git a/go/tools/astfmtgen/main.go b/go/tools/astfmtgen/main.go index 4d30c6fa1d6..37be6dba71f 100644 --- a/go/tools/astfmtgen/main.go +++ b/go/tools/astfmtgen/main.go @@ -104,24 +104,27 @@ func (r *Rewriter) replaceAstfmtCalls(cursor *astutil.Cursor) bool { } case *ast.ExprStmt: if call, ok := v.X.(*ast.CallExpr); ok { - if r.isPrintfCall(call) { + switch r.methodName(call) { + case "astPrintf": return r.rewriteAstPrintf(cursor, call) + case "literal": + callexpr := call.Fun.(*ast.SelectorExpr) + callexpr.Sel.Name = "WriteString" + return true } } } return true } -func (r *Rewriter) isPrintfCall(n *ast.CallExpr) bool { - s, ok := n.Fun.(*ast.SelectorExpr) - if !ok { - return false - } - id := s.Sel - if id != nil && !r.pkg.TypesInfo.Types[id].IsType() { - return id.Name == "astPrintf" +func (r *Rewriter) methodName(n *ast.CallExpr) string { + if call, ok := n.Fun.(*ast.SelectorExpr); ok { + id := call.Sel + if id != nil && !r.pkg.TypesInfo.Types[id].IsType() { + return id.Name + } } - return false + return "" } func (r *Rewriter) rewriteLiteral(rcv ast.Expr, method string, arg ast.Expr) ast.Stmt { @@ -138,7 +141,10 @@ func (r *Rewriter) rewriteLiteral(rcv ast.Expr, method string, arg ast.Expr) ast func (r *Rewriter) rewriteAstPrintf(cursor *astutil.Cursor, expr *ast.CallExpr) bool { callexpr := expr.Fun.(*ast.SelectorExpr) lit := expr.Args[1].(*ast.BasicLit) - format, _ := strconv.Unquote(lit.Value) + format, err := strconv.Unquote(lit.Value) + if err != nil { + panic("bad literal argument") + } end := len(format) fieldnum := 0 @@ -172,6 +178,10 @@ func (r *Rewriter) rewriteAstPrintf(cursor *astutil.Cursor, expr *ast.CallExpr) break } i++ // '%' + if format[i] == '#' { + i++ + } + token := format[i] switch token { case 'c': diff --git a/go/vt/sqlparser/ast_equals.go b/go/vt/sqlparser/ast_equals.go index 44d35736920..9ace7909330 100644 --- a/go/vt/sqlparser/ast_equals.go +++ b/go/vt/sqlparser/ast_equals.go @@ -5369,6 +5369,7 @@ func EqualsRefOfTableOption(a, b *TableOption) bool { } return a.Name == b.Name && a.String == b.String && + a.CaseSensitive == b.CaseSensitive && EqualsRefOfLiteral(a.Value, b.Value) && EqualsTableNames(a.Tables, b.Tables) } diff --git a/go/vt/sqlparser/ast_format.go b/go/vt/sqlparser/ast_format.go index 8414d91b77f..00d4d59c9e3 100644 --- a/go/vt/sqlparser/ast_format.go +++ b/go/vt/sqlparser/ast_format.go @@ -17,7 +17,6 @@ limitations under the License. package sqlparser import ( - "strconv" "strings" "vitess.io/vitess/go/sqltypes" @@ -31,20 +30,20 @@ func (node *Select) Format(buf *TrackedBuffer) { buf.astPrintf(node, "select %v", node.Comments) if node.Distinct { - buf.WriteString(DistinctStr) + buf.literal(DistinctStr) } if node.Cache != nil { if *node.Cache { - buf.WriteString(SQLCacheStr) + buf.literal(SQLCacheStr) } else { - buf.WriteString(SQLNoCacheStr) + buf.literal(SQLNoCacheStr) } } if node.StraightJoinHint { - buf.WriteString(StraightJoinHint) + buf.literal(StraightJoinHint) } if node.SQLCalcFoundRows { - buf.WriteString(SQLCalcFoundRowsStr) + buf.literal(SQLCalcFoundRowsStr) } buf.astPrintf(node, "%v from ", node.SelectExprs) @@ -69,13 +68,13 @@ func (node *Union) Format(buf *TrackedBuffer) { buf.astPrintf(node, "%v", node.Left) } - buf.WriteString(" ") + buf.WriteByte(' ') if node.Distinct { - buf.WriteString(UnionStr) + buf.literal(UnionStr) } else { - buf.WriteString(UnionAllStr) + buf.literal(UnionAllStr) } - buf.WriteString(" ") + buf.WriteByte(' ') if requiresParen(node.Right) { buf.astPrintf(node, "(%v)", node.Right) @@ -156,7 +155,7 @@ func (node *Delete) Format(buf *TrackedBuffer) { } buf.astPrintf(node, "delete %v", node.Comments) if node.Ignore { - buf.WriteString("ignore ") + buf.literal("ignore ") } if node.Targets != nil { buf.astPrintf(node, "%v ", node.Targets) @@ -179,7 +178,7 @@ func (node *SetTransaction) Format(buf *TrackedBuffer) { for i, char := range node.Characteristics { if i > 0 { - buf.WriteString(", ") + buf.literal(", ") } buf.astPrintf(node, "%v", char) } @@ -198,7 +197,7 @@ func (node *DropDatabase) Format(buf *TrackedBuffer) { func (node *Flush) Format(buf *TrackedBuffer) { buf.astPrintf(node, "%s", FlushStr) if node.IsLocal { - buf.WriteString(" local") + buf.literal(" local") } if len(node.FlushOptions) != 0 { prefix := " " @@ -207,15 +206,15 @@ func (node *Flush) Format(buf *TrackedBuffer) { prefix = ", " } } else { - buf.WriteString(" tables") + buf.literal(" tables") if len(node.TableNames) != 0 { buf.astPrintf(node, " %v", node.TableNames) } if node.ForExport { - buf.WriteString(" for export") + buf.literal(" for export") } if node.WithLock { - buf.WriteString(" with read lock") + buf.literal(" with read lock") } } } @@ -299,14 +298,14 @@ func (node *PartitionSpec) Format(buf *TrackedBuffer) { buf.astPrintf(node, "%s ", ReorganizeStr) for i, n := range node.Names { if i != 0 { - buf.WriteString(", ") + buf.literal(", ") } buf.astPrintf(node, "%v", n) } - buf.WriteString(" into (") + buf.literal(" into (") for i, pd := range node.Definitions { if i != 0 { - buf.WriteString(", ") + buf.literal(", ") } buf.astPrintf(node, "%v", pd) } @@ -317,14 +316,14 @@ func (node *PartitionSpec) Format(buf *TrackedBuffer) { buf.astPrintf(node, "%s ", DropPartitionStr) for i, n := range node.Names { if i != 0 { - buf.WriteString(", ") + buf.literal(", ") } buf.astPrintf(node, "%v", n) } case DiscardAction: buf.astPrintf(node, "%s ", DiscardStr) if node.IsAll { - buf.WriteString("all") + buf.literal("all") } else { prefix := "" for _, n := range node.Names { @@ -332,11 +331,11 @@ func (node *PartitionSpec) Format(buf *TrackedBuffer) { prefix = ", " } } - buf.WriteString(" tablespace") + buf.literal(" tablespace") case ImportAction: buf.astPrintf(node, "%s ", ImportStr) if node.IsAll { - buf.WriteString("all") + buf.literal("all") } else { prefix := "" for _, n := range node.Names { @@ -344,11 +343,11 @@ func (node *PartitionSpec) Format(buf *TrackedBuffer) { prefix = ", " } } - buf.WriteString(" tablespace") + buf.literal(" tablespace") case TruncateAction: buf.astPrintf(node, "%s ", TruncatePartitionStr) if node.IsAll { - buf.WriteString("all") + buf.literal("all") } else { prefix := "" for _, n := range node.Names { @@ -361,12 +360,12 @@ func (node *PartitionSpec) Format(buf *TrackedBuffer) { case ExchangeAction: buf.astPrintf(node, "%s %v with table %v", ExchangeStr, node.Names[0], node.TableName) if node.WithoutValidation { - buf.WriteString(" without validation") + buf.literal(" without validation") } case AnalyzeAction: buf.astPrintf(node, "%s ", AnalyzePartitionStr) if node.IsAll { - buf.WriteString("all") + buf.literal("all") } else { prefix := "" for _, n := range node.Names { @@ -377,7 +376,7 @@ func (node *PartitionSpec) Format(buf *TrackedBuffer) { case CheckAction: buf.astPrintf(node, "%s ", CheckStr) if node.IsAll { - buf.WriteString("all") + buf.literal("all") } else { prefix := "" for _, n := range node.Names { @@ -388,7 +387,7 @@ func (node *PartitionSpec) Format(buf *TrackedBuffer) { case OptimizeAction: buf.astPrintf(node, "%s ", OptimizeStr) if node.IsAll { - buf.WriteString("all") + buf.literal("all") } else { prefix := "" for _, n := range node.Names { @@ -399,7 +398,7 @@ func (node *PartitionSpec) Format(buf *TrackedBuffer) { case RebuildAction: buf.astPrintf(node, "%s ", RebuildStr) if node.IsAll { - buf.WriteString("all") + buf.literal("all") } else { prefix := "" for _, n := range node.Names { @@ -410,7 +409,7 @@ func (node *PartitionSpec) Format(buf *TrackedBuffer) { case RepairAction: buf.astPrintf(node, "%s ", RepairStr) if node.IsAll { - buf.WriteString("all") + buf.literal("all") } else { prefix := "" for _, n := range node.Names { @@ -419,9 +418,9 @@ func (node *PartitionSpec) Format(buf *TrackedBuffer) { } } case RemoveAction: - buf.WriteString(RemoveStr) + buf.literal(RemoveStr) case UpgradeAction: - buf.WriteString(UpgradeStr) + buf.literal(UpgradeStr) default: panic("unimplemented") } @@ -439,7 +438,7 @@ func (node *PartitionDefinition) Format(buf *TrackedBuffer) { func (node *PartitionValueRange) Format(buf *TrackedBuffer) { buf.astPrintf(node, "values %s", node.Type.ToString()) if node.Maxvalue { - buf.WriteString(" maxvalue") + buf.literal(" maxvalue") } else { buf.astPrintf(node, " %v", node.Range) } @@ -447,16 +446,16 @@ func (node *PartitionValueRange) Format(buf *TrackedBuffer) { // Format formats the node. func (node *PartitionOption) Format(buf *TrackedBuffer) { - buf.WriteString("partition by") + buf.literal("partition by") if node.IsLinear { - buf.WriteString(" linear") + buf.literal(" linear") } switch node.Type { case HashType: buf.astPrintf(node, " hash (%v)", node.Expr) case KeyType: - buf.WriteString(" key") + buf.literal(" key") if node.KeyAlgorithm != 0 { buf.astPrintf(node, " algorithm = %d", node.KeyAlgorithm) } @@ -477,29 +476,29 @@ func (node *PartitionOption) Format(buf *TrackedBuffer) { buf.astPrintf(node, " %v", node.SubPartition) } if node.Definitions != nil { - buf.WriteString(" (") + buf.literal(" (") for i, pd := range node.Definitions { if i != 0 { - buf.WriteString(", ") + buf.literal(", ") } buf.astPrintf(node, "%v", pd) } - buf.WriteString(")") + buf.WriteByte(')') } } // Format formats the node. func (node *SubPartition) Format(buf *TrackedBuffer) { - buf.WriteString("subpartition by") + buf.literal("subpartition by") if node.IsLinear { - buf.WriteString(" linear") + buf.literal(" linear") } switch node.Type { case HashType: buf.astPrintf(node, " hash (%v)", node.Expr) case KeyType: - buf.WriteString(" key") + buf.literal(" key") if node.KeyAlgorithm != 0 { buf.astPrintf(node, " algorithm = %d", node.KeyAlgorithm) } @@ -531,7 +530,7 @@ func (ts *TableSpec) Format(buf *TrackedBuffer) { buf.astPrintf(ts, "\n)") for i, opt := range ts.Options { if i != 0 { - buf.WriteString(",\n ") + buf.literal(",\n ") } buf.astPrintf(ts, " %s", opt.Name) if opt.String != "" { @@ -554,7 +553,7 @@ func (col *ColumnDefinition) Format(buf *TrackedBuffer) { // Format returns a canonical string representation of the type and all relevant options func (ct *ColumnType) Format(buf *TrackedBuffer) { - buf.astPrintf(ct, "%s", ct.Type) + buf.astPrintf(ct, "%#s", ct.Type) if ct.Length != nil && ct.Scale != nil { buf.astPrintf(ct, "(%v,%v)", ct.Length, ct.Scale) @@ -574,11 +573,11 @@ func (ct *ColumnType) Format(buf *TrackedBuffer) { buf.astPrintf(ct, " %s", keywordStrings[ZEROFILL]) } if ct.Charset != "" { - buf.astPrintf(ct, " %s %s %s", keywordStrings[CHARACTER], keywordStrings[SET], ct.Charset) + buf.astPrintf(ct, " %s %s %#s", keywordStrings[CHARACTER], keywordStrings[SET], ct.Charset) } if ct.Options != nil { if ct.Options.Collate != "" { - buf.astPrintf(ct, " %s %s", keywordStrings[COLLATE], ct.Options.Collate) + buf.astPrintf(ct, " %s %#s", keywordStrings[COLLATE], ct.Options.Collate) } if ct.Options.Null != nil && ct.Options.As == nil { if *ct.Options.Null { @@ -728,15 +727,15 @@ func (c *ConstraintDefinition) Format(buf *TrackedBuffer) { func (a ReferenceAction) Format(buf *TrackedBuffer) { switch a { case Restrict: - buf.WriteString("restrict") + buf.literal("restrict") case Cascade: - buf.WriteString("cascade") + buf.literal("cascade") case NoAction: - buf.WriteString("no action") + buf.literal("no action") case SetNull: - buf.WriteString("set null") + buf.literal("set null") case SetDefault: - buf.WriteString("set default") + buf.literal("set default") } } @@ -793,17 +792,17 @@ func (node *Use) Format(buf *TrackedBuffer) { // Format formats the node. func (node *Commit) Format(buf *TrackedBuffer) { - buf.WriteString("commit") + buf.literal("commit") } // Format formats the node. func (node *Begin) Format(buf *TrackedBuffer) { - buf.WriteString("begin") + buf.literal("begin") } // Format formats the node. func (node *Rollback) Format(buf *TrackedBuffer) { - buf.WriteString("rollback") + buf.literal("rollback") } // Format formats the node. @@ -854,7 +853,7 @@ func (node *PrepareStmt) Format(buf *TrackedBuffer) { func (node *ExecuteStmt) Format(buf *TrackedBuffer) { buf.astPrintf(node, "execute %v%v", node.Comments, node.Name) if len(node.Arguments) > 0 { - buf.WriteString(" using ") + buf.literal(" using ") } var prefix string for _, n := range node.Arguments { @@ -875,12 +874,12 @@ func (node *CallProc) Format(buf *TrackedBuffer) { // Format formats the node. func (node *OtherRead) Format(buf *TrackedBuffer) { - buf.WriteString("otherread") + buf.literal("otherread") } // Format formats the node. func (node *OtherAdmin) Format(buf *TrackedBuffer) { - buf.WriteString("otheradmin") + buf.literal("otheradmin") } // Format formats the node. @@ -933,7 +932,7 @@ func (node Columns) Format(buf *TrackedBuffer) { buf.astPrintf(node, "%s%v", prefix, n) prefix = ", " } - buf.WriteString(")") + buf.WriteByte(')') } // Format formats the node @@ -946,7 +945,7 @@ func (node Partitions) Format(buf *TrackedBuffer) { buf.astPrintf(node, "%s%v", prefix, n) prefix = ", " } - buf.WriteString(")") + buf.WriteByte(')') } // Format formats the node. @@ -1161,7 +1160,7 @@ func (node *Subquery) Format(buf *TrackedBuffer) { // Format formats the node. func (node *DerivedTable) Format(buf *TrackedBuffer) { if node.Lateral { - buf.WriteString("lateral ") + buf.literal("lateral ") } buf.astPrintf(node, "(%v)", node.Select) } @@ -1217,10 +1216,10 @@ func (node *TrimFuncExpr) Format(buf *TrackedBuffer) { } if (node.Type.ToString() != "") || (node.TrimArg != nil) { - buf.WriteString("from ") + buf.literal("from ") } buf.astPrintf(node, "%v", node.StringArg) - buf.WriteString(")") + buf.WriteByte(')') } // Format formats the node. @@ -1243,7 +1242,7 @@ func (node *CurTimeFuncExpr) Format(buf *TrackedBuffer) { // Format formats the node. func (node *CollateExpr) Format(buf *TrackedBuffer) { - buf.astPrintf(node, "%v collate %s", node.Expr, node.Collation) + buf.astPrintf(node, "%v collate %#s", node.Expr, node.Collation) } // Format formats the node. @@ -1262,7 +1261,7 @@ func (node *FuncExpr) Format(buf *TrackedBuffer) { if containEscapableChars(funcName, NoAt) { writeEscapedString(buf, funcName) } else { - buf.WriteString(funcName) + buf.literal(funcName) } buf.astPrintf(node, "(%s%v)", distinct, node.Exprs) } @@ -1357,9 +1356,9 @@ func (node *CaseExpr) Format(buf *TrackedBuffer) { func (node *Default) Format(buf *TrackedBuffer) { buf.astPrintf(node, "default") if node.ColName != "" { - buf.WriteString("(") + buf.WriteByte('(') formatID(buf, node.ColName, NoAt) - buf.WriteString(")") + buf.WriteByte(')') } } @@ -1449,8 +1448,8 @@ func (node SetExprs) Format(buf *TrackedBuffer) { // Format formats the node. func (node *SetExpr) Format(buf *TrackedBuffer) { if node.Scope != ImplicitScope { - buf.WriteString(node.Scope.ToString()) - buf.WriteString(" ") + buf.literal(node.Scope.ToString()) + buf.WriteByte(' ') } // We don't have to backtick set variable names. switch { @@ -1474,6 +1473,9 @@ func (node OnDup) Format(buf *TrackedBuffer) { // Format formats the node. func (node ColIdent) Format(buf *TrackedBuffer) { + if node.IsEmpty() { + return + } for i := NoAt; i < node.at; i++ { buf.WriteByte('@') } @@ -1487,40 +1489,40 @@ func (node TableIdent) Format(buf *TrackedBuffer) { // Format formats the node. func (node IsolationLevel) Format(buf *TrackedBuffer) { - buf.WriteString("isolation level ") + buf.literal("isolation level ") switch node { case ReadUncommitted: - buf.WriteString(ReadUncommittedStr) + buf.literal(ReadUncommittedStr) case ReadCommitted: - buf.WriteString(ReadCommittedStr) + buf.literal(ReadCommittedStr) case RepeatableRead: - buf.WriteString(RepeatableReadStr) + buf.literal(RepeatableReadStr) case Serializable: - buf.WriteString(SerializableStr) + buf.literal(SerializableStr) default: - buf.WriteString("Unknown Isolation level value") + buf.literal("Unknown Isolation level value") } } // Format formats the node. func (node AccessMode) Format(buf *TrackedBuffer) { if node == ReadOnly { - buf.WriteString(TxReadOnly) + buf.literal(TxReadOnly) } else { - buf.WriteString(TxReadWrite) + buf.literal(TxReadWrite) } } // Format formats the node. func (node *Load) Format(buf *TrackedBuffer) { - buf.WriteString("AST node missing for Load type") + buf.literal("AST node missing for Load type") } // Format formats the node. func (node *ShowBasic) Format(buf *TrackedBuffer) { - buf.WriteString("show") + buf.literal("show") if node.Full { - buf.WriteString(" full") + buf.literal(" full") } buf.astPrintf(node, "%s", node.Command.ToString()) if !node.Tbl.IsEmpty() { @@ -1558,36 +1560,38 @@ func (node *SelectInto) Format(buf *TrackedBuffer) { func (node *CreateDatabase) Format(buf *TrackedBuffer) { buf.astPrintf(node, "create database %v", node.Comments) if node.IfNotExists { - buf.WriteString("if not exists ") + buf.literal("if not exists ") } buf.astPrintf(node, "%v", node.DBName) if node.CreateOptions != nil { for _, createOption := range node.CreateOptions { if createOption.IsDefault { - buf.WriteString(" default") + buf.literal(" default") } - buf.WriteString(createOption.Type.ToString()) - buf.WriteString(" " + createOption.Value) + buf.literal(createOption.Type.ToString()) + buf.WriteByte(' ') + buf.literal(createOption.Value) } } } // Format formats the node. func (node *AlterDatabase) Format(buf *TrackedBuffer) { - buf.WriteString("alter database") + buf.literal("alter database") if !node.DBName.IsEmpty() { buf.astPrintf(node, " %v", node.DBName) } if node.UpdateDataDirectory { - buf.WriteString(" upgrade data directory name") + buf.literal(" upgrade data directory name") } if node.AlterOptions != nil { for _, createOption := range node.AlterOptions { if createOption.IsDefault { - buf.WriteString(" default") + buf.literal(" default") } - buf.WriteString(createOption.Type.ToString()) - buf.WriteString(" " + createOption.Value) + buf.literal(createOption.Type.ToString()) + buf.WriteByte(' ') + buf.literal(createOption.Value) } } } @@ -1596,12 +1600,12 @@ func (node *AlterDatabase) Format(buf *TrackedBuffer) { func (node *CreateTable) Format(buf *TrackedBuffer) { buf.astPrintf(node, "create %v", node.Comments) if node.Temp { - buf.WriteString("temporary ") + buf.literal("temporary ") } - buf.WriteString("table ") + buf.literal("table ") if node.IfNotExists { - buf.WriteString("if not exists ") + buf.literal("if not exists ") } buf.astPrintf(node, "%v", node.Table) @@ -1617,7 +1621,7 @@ func (node *CreateTable) Format(buf *TrackedBuffer) { func (node *CreateView) Format(buf *TrackedBuffer) { buf.astPrintf(node, "create %v", node.Comments) if node.IsReplace { - buf.WriteString("or replace ") + buf.literal("or replace ") } if node.Algorithm != "" { buf.astPrintf(node, "algorithm = %s ", node.Algorithm) @@ -1645,7 +1649,7 @@ func (node *LockTables) Format(buf *TrackedBuffer) { // Format formats the UnlockTables node. func (node *UnlockTables) Format(buf *TrackedBuffer) { - buf.WriteString("unlock tables") + buf.literal("unlock tables") } // Format formats the node. @@ -1668,9 +1672,9 @@ func (node *AlterView) Format(buf *TrackedBuffer) { } func (definer *Definer) Format(buf *TrackedBuffer) { - buf.WriteString(definer.Name) + buf.astPrintf(definer, "%#s", definer.Name) if definer.Address != "" { - buf.astPrintf(definer, "@%s", definer.Address) + buf.astPrintf(definer, "@%#s", definer.Address) } } @@ -1703,7 +1707,7 @@ func (node *AlterTable) Format(buf *TrackedBuffer) { prefix := "" for i, option := range node.AlterOptions { if i != 0 { - buf.WriteString(",") + buf.WriteByte(',') } buf.astPrintf(node, " %v", option) if node.PartitionSpec != nil && node.PartitionSpec.Action != RemoveAction { @@ -1747,7 +1751,7 @@ func (node *AddColumns) Format(buf *TrackedBuffer) { buf.astPrintf(node, ", %v", col) } } - buf.WriteString(")") + buf.WriteByte(')') } } @@ -1790,18 +1794,18 @@ func (node *ModifyColumn) Format(buf *TrackedBuffer) { // Format formats the node func (node *AlterCharset) Format(buf *TrackedBuffer) { - buf.astPrintf(node, "convert to character set %s", node.CharacterSet) + buf.astPrintf(node, "convert to character set %#s", node.CharacterSet) if node.Collate != "" { - buf.astPrintf(node, " collate %s", node.Collate) + buf.astPrintf(node, " collate %#s", node.Collate) } } // Format formats the node func (node *KeyState) Format(buf *TrackedBuffer) { if node.Enable { - buf.WriteString("enable keys") + buf.literal("enable keys") } else { - buf.WriteString("disable keys") + buf.literal("disable keys") } } @@ -1809,9 +1813,9 @@ func (node *KeyState) Format(buf *TrackedBuffer) { // Format formats the node func (node *TablespaceOperation) Format(buf *TrackedBuffer) { if node.Import { - buf.WriteString("import tablespace") + buf.literal("import tablespace") } else { - buf.WriteString("discard tablespace") + buf.literal("discard tablespace") } } @@ -1830,7 +1834,7 @@ func (node *DropKey) Format(buf *TrackedBuffer) { // Format formats the node func (node *Force) Format(buf *TrackedBuffer) { - buf.WriteString("force") + buf.literal("force") } // Format formats the node @@ -1861,9 +1865,9 @@ func (node *RenameIndex) Format(buf *TrackedBuffer) { // Format formats the node func (node *Validation) Format(buf *TrackedBuffer) { if node.With { - buf.WriteString("with validation") + buf.literal("with validation") } else { - buf.WriteString("without validation") + buf.literal("without validation") } } @@ -1871,14 +1875,19 @@ func (node *Validation) Format(buf *TrackedBuffer) { func (node TableOptions) Format(buf *TrackedBuffer) { for i, option := range node { if i != 0 { - buf.WriteString(" ") + buf.WriteByte(' ') } buf.astPrintf(node, "%s", option.Name) - if option.String != "" { - buf.astPrintf(node, " %s", option.String) - } else if option.Value != nil { + switch { + case option.String != "": + if option.CaseSensitive { + buf.astPrintf(node, " %#s", option.String) + } else { + buf.astPrintf(node, " %s", option.String) + } + case option.Value != nil: buf.astPrintf(node, " %v", option.Value) - } else { + default: buf.astPrintf(node, " (%v)", option.Tables) } } @@ -1959,9 +1968,7 @@ func (node *JtOnResponse) Format(buf *TrackedBuffer) { // Format formats the node. func (node Offset) Format(buf *TrackedBuffer) { - buf.WriteString("[") - buf.WriteString(strconv.Itoa(int(node))) - buf.WriteString("]") + buf.astPrintf(node, "[%d]", int(node)) } // Format formats the node. @@ -1977,7 +1984,7 @@ func (node *JSONSchemaValidationReportFuncExpr) Format(buf *TrackedBuffer) { // Format formats the node. func (node *JSONArrayExpr) Format(buf *TrackedBuffer) { //buf.astPrintf(node,"%s(,"node.Name.Lowered()) - buf.WriteString("json_array(") + buf.literal("json_array(") if len(node.Params) > 0 { var prefix string for _, n := range node.Params { @@ -1985,13 +1992,13 @@ func (node *JSONArrayExpr) Format(buf *TrackedBuffer) { prefix = ", " } } - buf.WriteString(")") + buf.WriteByte(')') } // Format formats the node. func (node *JSONObjectExpr) Format(buf *TrackedBuffer) { //buf.astPrintf(node,"%s(,"node.Name.Lowered()) - buf.WriteString("json_object(") + buf.literal("json_object(") if len(node.Params) > 0 { for i, p := range node.Params { if i != 0 { @@ -2001,7 +2008,7 @@ func (node *JSONObjectExpr) Format(buf *TrackedBuffer) { buf.astPrintf(node, "%v", p) } } - buf.WriteString(")") + buf.WriteByte(')') } // Format formats the node. @@ -2018,14 +2025,14 @@ func (node *JSONQuoteExpr) Format(buf *TrackedBuffer) { func (node *JSONContainsExpr) Format(buf *TrackedBuffer) { buf.astPrintf(node, "json_contains(%v, %v", node.Target, node.Candidate) if len(node.PathList) > 0 { - buf.WriteString(", ") + buf.literal(", ") } var prefix string for _, n := range node.PathList { buf.astPrintf(node, "%s%v", prefix, n) prefix = ", " } - buf.WriteString(")") + buf.WriteByte(')') } // Format formats the node @@ -2036,7 +2043,7 @@ func (node *JSONContainsPathExpr) Format(buf *TrackedBuffer) { buf.astPrintf(node, "%s%v", prefix, n) prefix = ", " } - buf.WriteString(")") + buf.WriteByte(')') } // Format formats the node @@ -2047,21 +2054,21 @@ func (node *JSONExtractExpr) Format(buf *TrackedBuffer) { buf.astPrintf(node, "%s%v", prefix, n) prefix = ", " } - buf.WriteString(")") + buf.WriteByte(')') } // Format formats the node func (node *JSONKeysExpr) Format(buf *TrackedBuffer) { buf.astPrintf(node, "json_keys(%v", node.JSONDoc) if len(node.PathList) > 0 { - buf.WriteString(", ") + buf.literal(", ") } var prefix string for _, n := range node.PathList { buf.astPrintf(node, "%s%v", prefix, n) prefix = ", " } - buf.WriteString(")") + buf.WriteByte(')') } // Format formats the node @@ -2076,14 +2083,14 @@ func (node *JSONSearchExpr) Format(buf *TrackedBuffer) { buf.astPrintf(node, ", %v", node.EscapeChar) } if len(node.PathList) > 0 { - buf.WriteString(", ") + buf.literal(", ") } var prefix string for _, n := range node.PathList { buf.astPrintf(node, "%s%v", prefix, n) prefix = ", " } - buf.WriteString(")") + buf.WriteByte(')') } // Format formats the node diff --git a/go/vt/sqlparser/ast_format_fast.go b/go/vt/sqlparser/ast_format_fast.go index f3177906845..0af55b2fb93 100644 --- a/go/vt/sqlparser/ast_format_fast.go +++ b/go/vt/sqlparser/ast_format_fast.go @@ -19,7 +19,6 @@ package sqlparser import ( "fmt" - "strconv" "strings" "vitess.io/vitess/go/sqltypes" @@ -84,13 +83,13 @@ func (node *Union) formatFast(buf *TrackedBuffer) { node.Left.formatFast(buf) } - buf.WriteString(" ") + buf.WriteByte(' ') if node.Distinct { buf.WriteString(UnionStr) } else { buf.WriteString(UnionAllStr) } - buf.WriteString(" ") + buf.WriteByte(' ') if requiresParen(node.Right) { buf.WriteByte('(') @@ -662,7 +661,7 @@ func (node *PartitionOption) formatFast(buf *TrackedBuffer) { } pd.formatFast(buf) } - buf.WriteString(")") + buf.WriteByte(')') } } @@ -1246,7 +1245,7 @@ func (node Columns) formatFast(buf *TrackedBuffer) { n.formatFast(buf) prefix = ", " } - buf.WriteString(")") + buf.WriteByte(')') } // formatFast formats the node @@ -1260,7 +1259,7 @@ func (node Partitions) formatFast(buf *TrackedBuffer) { n.formatFast(buf) prefix = ", " } - buf.WriteString(")") + buf.WriteByte(')') } // formatFast formats the node. @@ -1617,7 +1616,7 @@ func (node *TrimFuncExpr) formatFast(buf *TrackedBuffer) { buf.WriteString("from ") } buf.printExpr(node, node.StringArg, true) - buf.WriteString(")") + buf.WriteByte(')') } // formatFast formats the node. @@ -1819,9 +1818,9 @@ func (node *CaseExpr) formatFast(buf *TrackedBuffer) { func (node *Default) formatFast(buf *TrackedBuffer) { buf.WriteString("default") if node.ColName != "" { - buf.WriteString("(") + buf.WriteByte('(') formatID(buf, node.ColName, NoAt) - buf.WriteString(")") + buf.WriteByte(')') } } @@ -1925,7 +1924,7 @@ func (node SetExprs) formatFast(buf *TrackedBuffer) { func (node *SetExpr) formatFast(buf *TrackedBuffer) { if node.Scope != ImplicitScope { buf.WriteString(node.Scope.ToString()) - buf.WriteString(" ") + buf.WriteByte(' ') } // We don't have to backtick set variable names. switch { @@ -1956,6 +1955,9 @@ func (node OnDup) formatFast(buf *TrackedBuffer) { // formatFast formats the node. func (node ColIdent) formatFast(buf *TrackedBuffer) { + if node.IsEmpty() { + return + } for i := NoAt; i < node.at; i++ { buf.WriteByte('@') } @@ -2061,7 +2063,8 @@ func (node *CreateDatabase) formatFast(buf *TrackedBuffer) { buf.WriteString(" default") } buf.WriteString(createOption.Type.ToString()) - buf.WriteString(" " + createOption.Value) + buf.WriteByte(' ') + buf.WriteString(createOption.Value) } } } @@ -2082,7 +2085,8 @@ func (node *AlterDatabase) formatFast(buf *TrackedBuffer) { buf.WriteString(" default") } buf.WriteString(createOption.Type.ToString()) - buf.WriteString(" " + createOption.Value) + buf.WriteByte(' ') + buf.WriteString(createOption.Value) } } } @@ -2245,7 +2249,7 @@ func (node *AlterTable) formatFast(buf *TrackedBuffer) { prefix := "" for i, option := range node.AlterOptions { if i != 0 { - buf.WriteString(",") + buf.WriteByte(',') } buf.WriteByte(' ') option.formatFast(buf) @@ -2300,7 +2304,7 @@ func (node *AddColumns) formatFast(buf *TrackedBuffer) { col.formatFast(buf) } } - buf.WriteString(")") + buf.WriteByte(')') } } @@ -2447,16 +2451,22 @@ func (node *Validation) formatFast(buf *TrackedBuffer) { func (node TableOptions) formatFast(buf *TrackedBuffer) { for i, option := range node { if i != 0 { - buf.WriteString(" ") + buf.WriteByte(' ') } buf.WriteString(option.Name) - if option.String != "" { - buf.WriteByte(' ') - buf.WriteString(option.String) - } else if option.Value != nil { + switch { + case option.String != "": + if option.CaseSensitive { + buf.WriteByte(' ') + buf.WriteString(option.String) + } else { + buf.WriteByte(' ') + buf.WriteString(option.String) + } + case option.Value != nil: buf.WriteByte(' ') option.Value.formatFast(buf) - } else { + default: buf.WriteString(" (") option.Tables.formatFast(buf) buf.WriteByte(')') @@ -2567,9 +2577,9 @@ func (node *JtOnResponse) formatFast(buf *TrackedBuffer) { // formatFast formats the node. func (node Offset) formatFast(buf *TrackedBuffer) { - buf.WriteString("[") - buf.WriteString(strconv.Itoa(int(node))) - buf.WriteString("]") + buf.WriteByte('[') + buf.WriteString(fmt.Sprintf("%d", int(node))) + buf.WriteByte(']') } // formatFast formats the node. @@ -2602,7 +2612,7 @@ func (node *JSONArrayExpr) formatFast(buf *TrackedBuffer) { prefix = ", " } } - buf.WriteString(")") + buf.WriteByte(')') } // formatFast formats the node. @@ -2618,7 +2628,7 @@ func (node *JSONObjectExpr) formatFast(buf *TrackedBuffer) { p.formatFast(buf) } } - buf.WriteString(")") + buf.WriteByte(')') } // formatFast formats the node. @@ -2650,7 +2660,7 @@ func (node *JSONContainsExpr) formatFast(buf *TrackedBuffer) { buf.printExpr(node, n, true) prefix = ", " } - buf.WriteString(")") + buf.WriteByte(')') } // formatFast formats the node @@ -2666,7 +2676,7 @@ func (node *JSONContainsPathExpr) formatFast(buf *TrackedBuffer) { buf.printExpr(node, n, true) prefix = ", " } - buf.WriteString(")") + buf.WriteByte(')') } // formatFast formats the node @@ -2680,7 +2690,7 @@ func (node *JSONExtractExpr) formatFast(buf *TrackedBuffer) { buf.printExpr(node, n, true) prefix = ", " } - buf.WriteString(")") + buf.WriteByte(')') } // formatFast formats the node @@ -2696,7 +2706,7 @@ func (node *JSONKeysExpr) formatFast(buf *TrackedBuffer) { buf.printExpr(node, n, true) prefix = ", " } - buf.WriteString(")") + buf.WriteByte(')') } // formatFast formats the node @@ -2729,7 +2739,7 @@ func (node *JSONSearchExpr) formatFast(buf *TrackedBuffer) { buf.printExpr(node, n, true) prefix = ", " } - buf.WriteString(")") + buf.WriteByte(')') } // formatFast formats the node diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index df0a92d2880..3e70655dd34 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -58,8 +58,9 @@ type Visit func(node SQLNode) (kontinue bool, err error) func Append(buf *strings.Builder, node SQLNode) { tbuf := &TrackedBuffer{ Builder: buf, + fast: true, } - node.Format(tbuf) + node.formatFast(tbuf) } // IndexColumn describes a column in an index definition with optional length @@ -85,10 +86,11 @@ type IndexOption struct { // TableOption is used for create table options like AUTO_INCREMENT, INSERT_METHOD, etc type TableOption struct { - Name string - Value *Literal - String string - Tables TableNames + Name string + Value *Literal + String string + Tables TableNames + CaseSensitive bool } // ColumnKeyOption indicates whether or not the given column is defined as an @@ -778,7 +780,7 @@ func containEscapableChars(s string, at AtCount) bool { func formatID(buf *TrackedBuffer, original string, at AtCount) { _, isKeyword := keywordLookupTable.LookupString(original) - if isKeyword || containEscapableChars(original, at) { + if buf.escape || isKeyword || containEscapableChars(original, at) { writeEscapedString(buf, original) } else { buf.WriteString(original) diff --git a/go/vt/sqlparser/cached_size.go b/go/vt/sqlparser/cached_size.go index d09c32ca78f..c31dd88f6b4 100644 --- a/go/vt/sqlparser/cached_size.go +++ b/go/vt/sqlparser/cached_size.go @@ -2681,7 +2681,7 @@ func (cached *TableOption) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(64) + size += int64(80) } // field Name string size += hack.RuntimeAllocSize(int64(len(cached.Name))) diff --git a/go/vt/sqlparser/parser.go b/go/vt/sqlparser/parser.go index 747f0419310..5f78b7d14b6 100644 --- a/go/vt/sqlparser/parser.go +++ b/go/vt/sqlparser/parser.go @@ -303,14 +303,3 @@ loop: err = tokenizer.LastError return } - -// String returns a string representation of an SQLNode. -func String(node SQLNode) string { - if node == nil { - return "" - } - - buf := NewTrackedBuffer(nil) - node.formatFast(buf) - return buf.String() -} diff --git a/go/vt/sqlparser/sql.go b/go/vt/sqlparser/sql.go index 664b95bd780..bf0e5ddb98a 100644 --- a/go/vt/sqlparser/sql.go +++ b/go/vt/sqlparser/sql.go @@ -10321,7 +10321,7 @@ yydefault: var yyLOCAL *TableOption //line sql.y:2357 { - yyLOCAL = &TableOption{Name: (string(yyDollar[2].str)), String: yyDollar[4].str} + yyLOCAL = &TableOption{Name: (string(yyDollar[2].str)), String: yyDollar[4].str, CaseSensitive: true} } yyVAL.union = yyLOCAL case 408: @@ -10329,7 +10329,7 @@ yydefault: var yyLOCAL *TableOption //line sql.y:2361 { - yyLOCAL = &TableOption{Name: string(yyDollar[2].str), String: yyDollar[4].str} + yyLOCAL = &TableOption{Name: string(yyDollar[2].str), String: yyDollar[4].str, CaseSensitive: true} } yyVAL.union = yyLOCAL case 409: diff --git a/go/vt/sqlparser/sql.y b/go/vt/sqlparser/sql.y index 04d348d6aa6..356cd6ee41e 100644 --- a/go/vt/sqlparser/sql.y +++ b/go/vt/sqlparser/sql.y @@ -2355,11 +2355,11 @@ table_option: } | default_optional charset_or_character_set equal_opt charset { - $$ = &TableOption{Name:(string($2)), String:$4} + $$ = &TableOption{Name:(string($2)), String:$4, CaseSensitive: true} } | default_optional COLLATE equal_opt charset { - $$ = &TableOption{Name:string($2), String:$4} + $$ = &TableOption{Name:string($2), String:$4, CaseSensitive: true} } | CHECKSUM equal_opt INTEGRAL { diff --git a/go/vt/sqlparser/tracked_buffer.go b/go/vt/sqlparser/tracked_buffer.go index b9211069580..6d332b870e4 100644 --- a/go/vt/sqlparser/tracked_buffer.go +++ b/go/vt/sqlparser/tracked_buffer.go @@ -36,14 +36,54 @@ type TrackedBuffer struct { *strings.Builder bindLocations []bindLocation nodeFormatter NodeFormatter + literal func(string) (int, error) + escape bool + fast bool } // NewTrackedBuffer creates a new TrackedBuffer. func NewTrackedBuffer(nodeFormatter NodeFormatter) *TrackedBuffer { - return &TrackedBuffer{ + buf := &TrackedBuffer{ Builder: new(strings.Builder), nodeFormatter: nodeFormatter, } + buf.literal = buf.WriteString + buf.fast = nodeFormatter == nil + return buf +} + +func (buf *TrackedBuffer) writeStringUpperCase(lit string) (int, error) { + // Upcasing is performed for ASCII only, following MySQL's behavior + buf.Grow(len(lit)) + for i := 0; i < len(lit); i++ { + c := lit[i] + if 'a' <= c && c <= 'z' { + c -= 'a' - 'A' + } + buf.WriteByte(c) + } + return len(lit), nil +} + +// SetUpperCase sets whether all SQL statements formatted by this TrackedBuffer will be normalized into +// uppercase. By default, formatted statements are normalized into lowercase. +// Enabling this option will prevent the optimized fastFormat routines from running. +func (buf *TrackedBuffer) SetUpperCase(enable bool) { + buf.fast = false + if enable { + buf.literal = buf.writeStringUpperCase + } else { + buf.literal = buf.WriteString + } +} + +// SetEscapeAllIdentifiers sets whether ALL identifiers in the serialized SQL query should be quoted +// and escaped. By default, identifiers are only escaped if they match the name of a SQL keyword or they +// contain characters that must be escaped. +// Enabling this option will prevent the optimized fastFormat routines from running. +func (buf *TrackedBuffer) SetEscapeAllIdentifiers(enable bool) { + buf.fast = false + buf.escape = enable } // WriteNode function, initiates the writing of a single SQLNode tree by passing @@ -98,14 +138,20 @@ func (buf *TrackedBuffer) astPrintf(currentNode SQLNode, format string, values . i++ } if i > lasti { - buf.WriteString(format[lasti:i]) + _, _ = buf.literal(format[lasti:i]) } if i >= end { break } i++ // '%' - token := format[i] - switch token { + + caseSensitive := false + if format[i] == '#' { + caseSensitive = true + i++ + } + + switch format[i] { case 'c': switch v := values[fieldnum].(type) { case byte: @@ -117,15 +163,17 @@ func (buf *TrackedBuffer) astPrintf(currentNode SQLNode, format string, values . } case 's': switch v := values[fieldnum].(type) { - case []byte: - buf.Write(v) case string: - buf.WriteString(v) + if caseSensitive { + buf.WriteString(v) + } else { + _, _ = buf.literal(v) + } default: panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v)) } case 'l', 'r', 'v': - left := token != 'r' + left := format[i] != 'r' value := values[fieldnum] expr := getExpressionForParensEval(checkParens, value) @@ -164,10 +212,13 @@ func getExpressionForParensEval(checkParens bool, value any) Expr { } func (buf *TrackedBuffer) formatter(node SQLNode) { - if buf.nodeFormatter == nil { + switch { + case buf.fast: node.formatFast(buf) - } else { + case buf.nodeFormatter != nil: buf.nodeFormatter(buf, node) + default: + node.Format(buf) } } @@ -241,3 +292,28 @@ func BuildParsedQuery(in string, vars ...any) *ParsedQuery { buf.Myprintf(in, vars...) return buf.ParsedQuery() } + +// String returns a string representation of an SQLNode. +func String(node SQLNode) string { + if node == nil { + return "" + } + + buf := NewTrackedBuffer(nil) + node.formatFast(buf) + return buf.String() +} + +// CanonicalString returns a canonical string representation of an SQLNode where all identifiers +// are always escaped and all SQL syntax is in uppercase. This matches the canonical output from MySQL. +func CanonicalString(node SQLNode) string { + if node == nil { + return "" // do not return '', which is Go syntax. + } + + buf := NewTrackedBuffer(nil) + buf.SetUpperCase(true) + buf.SetEscapeAllIdentifiers(true) + node.Format(buf) + return buf.String() +} diff --git a/go/vt/sqlparser/tracked_buffer_test.go b/go/vt/sqlparser/tracked_buffer_test.go index 279b98dd067..53cbe49088c 100644 --- a/go/vt/sqlparser/tracked_buffer_test.go +++ b/go/vt/sqlparser/tracked_buffer_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestBuildParsedQuery(t *testing.T) { @@ -49,3 +50,107 @@ func TestBuildParsedQuery(t *testing.T) { }) } } + +func TestCanonicalOutput(t *testing.T) { + testcases := []struct { + input string + canonical string + }{ + { + "create table t(id int)", + "CREATE TABLE `t` (\n\t`id` int\n)", + }, + { + "create algorithm = merge sql security definer view a (b,c,d) as select * from e with cascaded check option", + "CREATE ALGORITHM = MERGE SQL SECURITY DEFINER VIEW `a`(`b`, `c`, `d`) AS SELECT * FROM `e` WITH CASCADED CHECK OPTION", + }, + { + "create or replace algorithm = temptable definer = a@b.c.d sql security definer view a(b,c,d) as select * from e with local check option", + "CREATE OR REPLACE ALGORITHM = TEMPTABLE DEFINER = a@`b.c.d` SQL SECURITY DEFINER VIEW `a`(`b`, `c`, `d`) AS SELECT * FROM `e` WITH LOCAL CHECK OPTION", + }, + { + "create table `a`(`id` int, primary key(`id`))", + "CREATE TABLE `a` (\n\t`id` int,\n\tPRIMARY KEY (`id`)\n)", + }, + { + "create table `a`(`id` int primary key)", + "CREATE TABLE `a` (\n\t`id` int PRIMARY KEY\n)", + }, + { + "create table a (id int not null auto_increment, v varchar(32) default null, v2 varchar(62) charset utf8mb4 collate utf8mb4_0900_ai_ci, key v_idx(v(16)))", + "CREATE TABLE `a` (\n\t`id` int NOT NULL AUTO_INCREMENT,\n\t`v` varchar(32) DEFAULT NULL,\n\t`v2` varchar(62) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci,\n\tKEY `v_idx` (`v`(16))\n)", + }, + { + "create table a (id int not null primary key, dt datetime default current_timestamp)", + "CREATE TABLE `a` (\n\t`id` int NOT NULL PRIMARY KEY,\n\t`dt` datetime DEFAULT CURRENT_TIMESTAMP()\n)", + }, + { + "create table `insert`(`update` int, primary key(`delete`))", + "CREATE TABLE `insert` (\n\t`update` int,\n\tPRIMARY KEY (`delete`)\n)", + }, + { + "alter table a engine=innodb", + "ALTER TABLE `a` ENGINE INNODB", + }, + { + "alter table a comment='a b c'", + "ALTER TABLE `a` COMMENT 'a b c'", + }, + { + "alter table a add column c char not null default 'x'", + "ALTER TABLE `a` ADD COLUMN `c` char NOT NULL DEFAULT 'x'", + }, + { + "alter table t2 modify column id bigint unsigned primary key", + "ALTER TABLE `t2` MODIFY COLUMN `id` bigint UNSIGNED PRIMARY KEY", + }, + { + "alter table t1 modify column a int first, modify column b int after a", + "ALTER TABLE `t1` MODIFY COLUMN `a` int FIRST, MODIFY COLUMN `b` int AFTER `a`", + }, + { + "alter table t1 drop key `PRIMARY`, add primary key (id,n)", + "ALTER TABLE `t1` DROP KEY `PRIMARY`, ADD PRIMARY KEY (`id`, `n`)", + }, + { + "alter table t1 drop foreign key f", + "ALTER TABLE `t1` DROP FOREIGN KEY `f`", + }, + { + "alter table t1 add constraint f foreign key (i) references parent (id) on delete cascade on update set null", + "ALTER TABLE `t1` ADD CONSTRAINT `f` FOREIGN KEY (`i`) REFERENCES `parent` (`id`) ON DELETE CASCADE ON UPDATE SET NULL", + }, + { + "alter table t1 remove partitioning", + "ALTER TABLE `t1` REMOVE PARTITIONING", + }, + { + "alter table t1 partition by hash (id) partitions 5", + "ALTER TABLE `t1` PARTITION BY HASH (`id`) PARTITIONS 5", + }, + { + "alter table t1 partition by list (id) (partition p1 values in (11, 21), partition p2 values in (12, 22))", + "ALTER TABLE `t1` PARTITION BY LIST (`id`) (PARTITION `p1` VALUES IN (11, 21), PARTITION `p2` VALUES IN (12, 22))", + }, + { + "alter table t1 row_format=compressed, character set=utf8", + "ALTER TABLE `t1` ROW_FORMAT COMPRESSED, CHARSET utf8", + }, + } + + for _, tc := range testcases { + t.Run(tc.input, func(t *testing.T) { + tree, err := Parse(tc.input) + require.NoError(t, err, tc.input) + + out := CanonicalString(tree) + require.Equal(t, tc.canonical, out, "bad serialization") + + // Make sure we've generated a valid query! + rereadStmt, err := Parse(out) + require.NoError(t, err, out) + out = CanonicalString(rereadStmt) + require.Equal(t, tc.canonical, out, "bad serialization") + }) + } +}