diff --git a/core/src/main/java/com/alibaba/druid/sql/ast/statement/SQLMergeStatement.java b/core/src/main/java/com/alibaba/druid/sql/ast/statement/SQLMergeStatement.java index 946948ab19..a430e9596d 100644 --- a/core/src/main/java/com/alibaba/druid/sql/ast/statement/SQLMergeStatement.java +++ b/core/src/main/java/com/alibaba/druid/sql/ast/statement/SQLMergeStatement.java @@ -30,6 +30,7 @@ public class SQLMergeStatement extends SQLStatementImpl { private SQLExpr on; private MergeUpdateClause updateClause; private MergeInsertClause insertClause; + private boolean insertClauseFirst; private SQLErrorLoggingClause errorLoggingClause; public void accept0(SQLASTVisitor visitor) { @@ -107,6 +108,14 @@ public List getHints() { return hints; } + public boolean isInsertClauseFirst() { + return insertClauseFirst; + } + + public void setInsertClauseFirst(boolean insertClauseFirst) { + this.insertClauseFirst = insertClauseFirst; + } + public static class MergeUpdateClause extends SQLObjectImpl { private List items = new ArrayList(); private SQLExpr where; diff --git a/core/src/main/java/com/alibaba/druid/sql/parser/SQLStatementParser.java b/core/src/main/java/com/alibaba/druid/sql/parser/SQLStatementParser.java index 84ff10d98a..958c03f6b8 100644 --- a/core/src/main/java/com/alibaba/druid/sql/parser/SQLStatementParser.java +++ b/core/src/main/java/com/alibaba/druid/sql/parser/SQLStatementParser.java @@ -5402,6 +5402,7 @@ public SQLStatement parseMerge() { } stmt.setInsertClause(insertClause); + stmt.setInsertClauseFirst(stmt.getUpdateClause() == null); } if (lexer.token == Token.WHEN) { diff --git a/core/src/main/java/com/alibaba/druid/sql/visitor/SQLASTOutputVisitor.java b/core/src/main/java/com/alibaba/druid/sql/visitor/SQLASTOutputVisitor.java index edf418445c..d228ae3efc 100644 --- a/core/src/main/java/com/alibaba/druid/sql/visitor/SQLASTOutputVisitor.java +++ b/core/src/main/java/com/alibaba/druid/sql/visitor/SQLASTOutputVisitor.java @@ -7961,14 +7961,24 @@ public boolean visit(SQLMergeStatement x) { x.getOn().accept(this); print0(") "); - if (x.getUpdateClause() != null) { - println(); - x.getUpdateClause().accept(this); - } - - if (x.getInsertClause() != null) { - println(); - x.getInsertClause().accept(this); + if (x.isInsertClauseFirst()) { + if (x.getInsertClause() != null) { + println(); + x.getInsertClause().accept(this); + } + if (x.getUpdateClause() != null) { + println(); + x.getUpdateClause().accept(this); + } + } else { + if (x.getUpdateClause() != null) { + println(); + x.getUpdateClause().accept(this); + } + if (x.getInsertClause() != null) { + println(); + x.getInsertClause().accept(this); + } } if (x.getErrorLoggingClause() != null) { diff --git a/core/src/test/java/com/alibaba/druid/bvt/sql/oracle/OracleMergeTest10.java b/core/src/test/java/com/alibaba/druid/bvt/sql/oracle/OracleMergeTest10.java index 1e2bc3df9a..51ae1f270c 100644 --- a/core/src/test/java/com/alibaba/druid/bvt/sql/oracle/OracleMergeTest10.java +++ b/core/src/test/java/com/alibaba/druid/bvt/sql/oracle/OracleMergeTest10.java @@ -41,15 +41,15 @@ public void test_0() throws Exception { SQLMergeStatement mergeStatement = (SQLMergeStatement) stmtList.get(0); String result = SQLUtils.toOracleString(mergeStatement); Assert.assertEquals("MERGE INTO bonuses d\n" + - "USING (\n" + - "\tSELECT employee_id.*\n" + - "\tFROM employees\n" + - ") s ON (employee_id = a) \n" + - "WHEN MATCHED THEN UPDATE SET d.bonus = bonus\n" + - "\tDELETE WHERE salary > 8000\n" + - "WHEN NOT MATCHED THEN INSERT (d.employee_id, d.bonus) VALUES (s.employee_id, s.salary)\n" + - "\tWHERE s.salary <= 8000", - result); + "USING (\n" + + "\tSELECT employee_id.*\n" + + "\tFROM employees\n" + + ") s ON (employee_id = a) \n" + + "WHEN NOT MATCHED THEN INSERT (d.employee_id, d.bonus) VALUES (s.employee_id, s.salary)\n" + + "\tWHERE s.salary <= 8000\n" + + "WHEN MATCHED THEN UPDATE SET d.bonus = bonus\n" + + "\tDELETE WHERE salary > 8000", + result); // Assert.assertTrue(visitor.getColumns().contains(new TableStat.Column("employees", "employee_id"))); // Assert.assertTrue(visitor.getColumns().contains(new TableStat.Column("employees", "salary"))); // Assert.assertTrue(visitor.getColumns().contains(new TableStat.Column("employees", "department_id"))); diff --git a/core/src/test/java/com/alibaba/druid/bvt/sql/oracle/issues/Issue5631.java b/core/src/test/java/com/alibaba/druid/bvt/sql/oracle/issues/Issue5631.java new file mode 100644 index 0000000000..eb841762b8 --- /dev/null +++ b/core/src/test/java/com/alibaba/druid/bvt/sql/oracle/issues/Issue5631.java @@ -0,0 +1,48 @@ +package com.alibaba.druid.bvt.sql.oracle.issues; + +import com.alibaba.druid.DbType; +import com.alibaba.druid.bvt.sql.mysql.issues.Issue5421; +import com.alibaba.druid.sql.ast.SQLStatement; +import com.alibaba.druid.sql.parser.SQLParserUtils; +import com.alibaba.druid.sql.parser.SQLStatementParser; + +import org.junit.Test; + +import static org.junit.Assert.assertTrue; + +/** + * 验证 Oracle merge sql的顺序问题 + * + * @author lizongbo + * @see Issue来源 + */ +public class Issue5631 { + + @Test + public void test_merge_into() throws Exception { + for (DbType dbType : new DbType[]{DbType.oracle}) { + for (String sql : new String[]{ + "MERGE INTO target_table\n" + + "USING source_table ON (target_table.id = source_table.id)\n" + + "WHEN NOT MATCHED THEN INSERT (id, column1) VALUES (source_table.id, source_table.column1)\n" + + "WHEN MATCHED THEN UPDATE SET target_table.column1 = source_table.column1", + "MERGE INTO target_table\n" + + "USING source_table ON (target_table.id = source_table.id)\n" + + "WHEN MATCHED THEN UPDATE SET target_table.column1 = source_table.column1\n" + + "WHEN NOT MATCHED THEN INSERT (id, column1) VALUES (source_table.id, source_table.column1)", + }) { + SQLStatementParser parser = SQLParserUtils.createSQLStatementParser(sql, dbType); + SQLStatement statement = parser.parseStatement(); + System.out.println(dbType + "原始的sql===" + sql); + System.out.println(dbType + "原始sql归一化===" + Issue5421.normalizeSql(sql)); + String newSql = statement.toString(); + System.out.println(dbType + "初次解析生成的sql===" + newSql); + System.out.println(dbType + "初次解析生成的sql归一化===" + Issue5421.normalizeSql(newSql)); + parser = SQLParserUtils.createSQLStatementParser(newSql, dbType); + statement = parser.parseStatement(); + System.out.println(dbType + "重新解析sql归一化===" + Issue5421.normalizeSql(statement.toString())); + assertTrue(Issue5421.normalizeSql(sql).equalsIgnoreCase(Issue5421.normalizeSql(statement.toString()))); + } + } + } +} diff --git a/core/src/test/java/com/alibaba/druid/bvt/sql/postgresql/PGMergeIntoTest0.java b/core/src/test/java/com/alibaba/druid/bvt/sql/postgresql/PGMergeIntoTest0.java index c89cde63f2..a23b14c305 100644 --- a/core/src/test/java/com/alibaba/druid/bvt/sql/postgresql/PGMergeIntoTest0.java +++ b/core/src/test/java/com/alibaba/druid/bvt/sql/postgresql/PGMergeIntoTest0.java @@ -39,12 +39,12 @@ public void test_0() throws Exception { SQLStatement stmt = statementList.get(0); assertEquals("MERGE INTO CustomerAccount CA\n" + - "USING (\n" + - "\tSELECT CustomerId, TransactionValue\n" + - "\tFROM RecentTransactions\n" + - ") T ON (CA.CustomerId = T.CustomerId) \n" + - "WHEN MATCHED THEN UPDATE SET Balance = Balance + TransactionValue\n" + - "WHEN NOT MATCHED THEN INSERT (CustomerId, Balance) VALUES (T.CustomerId, T.TransactionValue);", stmt.toString()); + "USING (\n" + + "\tSELECT CustomerId, TransactionValue\n" + + "\tFROM RecentTransactions\n" + + ") T ON (CA.CustomerId = T.CustomerId) \n" + + "WHEN NOT MATCHED THEN INSERT (CustomerId, Balance) VALUES (T.CustomerId, T.TransactionValue)\n" + + "WHEN MATCHED THEN UPDATE SET Balance = Balance + TransactionValue;" , stmt.toString()); assertEquals(1, statementList.size());