Skip to content

Commit

Permalink
优化merge sql 输出逻辑,确保解析后输出顺序保持不变 alibaba#5631
Browse files Browse the repository at this point in the history
优化merge sql 输出逻辑,确保解析后输出顺序保持不变 alibaba#5631
  • Loading branch information
lizongbo committed Dec 21, 2023
1 parent a1236b1 commit b2558ae
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public class SQLMergeStatement extends SQLStatementImpl {
private SQLExpr on;
private MergeUpdateClause updateClause;
private MergeInsertClause insertClause;
private boolean insertClauseFirst;
private SQLErrorLoggingClause errorLoggingClause;

public void accept0(SQLASTVisitor visitor) {
Expand Down Expand Up @@ -107,6 +108,14 @@ public List<SQLHint> getHints() {
return hints;
}

public boolean isInsertClauseFirst() {
return insertClauseFirst;
}

public void setInsertClauseFirst(boolean insertClauseFirst) {
this.insertClauseFirst = insertClauseFirst;
}

public static class MergeUpdateClause extends SQLObjectImpl {
private List<SQLUpdateSetItem> items = new ArrayList<SQLUpdateSetItem>();
private SQLExpr where;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5402,6 +5402,7 @@ public SQLStatement parseMerge() {
}

stmt.setInsertClause(insertClause);
stmt.setInsertClauseFirst(stmt.getUpdateClause() == null);
}

if (lexer.token == Token.WHEN) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7961,14 +7961,24 @@ public boolean visit(SQLMergeStatement x) {
x.getOn().accept(this);
print0(") ");

if (x.getUpdateClause() != null) {
println();
x.getUpdateClause().accept(this);
}

if (x.getInsertClause() != null) {
println();
x.getInsertClause().accept(this);
if (x.isInsertClauseFirst()) {
if (x.getInsertClause() != null) {
println();
x.getInsertClause().accept(this);
}
if (x.getUpdateClause() != null) {
println();
x.getUpdateClause().accept(this);
}
} else {
if (x.getUpdateClause() != null) {
println();
x.getUpdateClause().accept(this);
}
if (x.getInsertClause() != null) {
println();
x.getInsertClause().accept(this);
}
}

if (x.getErrorLoggingClause() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ public void test_0() throws Exception {
SQLMergeStatement mergeStatement = (SQLMergeStatement) stmtList.get(0);
String result = SQLUtils.toOracleString(mergeStatement);
Assert.assertEquals("MERGE INTO bonuses d\n" +
"USING (\n" +
"\tSELECT employee_id.*\n" +
"\tFROM employees\n" +
") s ON (employee_id = a) \n" +
"WHEN MATCHED THEN UPDATE SET d.bonus = bonus\n" +
"\tDELETE WHERE salary > 8000\n" +
"WHEN NOT MATCHED THEN INSERT (d.employee_id, d.bonus) VALUES (s.employee_id, s.salary)\n" +
"\tWHERE s.salary <= 8000",
result);
"USING (\n" +
"\tSELECT employee_id.*\n" +
"\tFROM employees\n" +
") s ON (employee_id = a) \n" +
"WHEN NOT MATCHED THEN INSERT (d.employee_id, d.bonus) VALUES (s.employee_id, s.salary)\n" +
"\tWHERE s.salary <= 8000\n" +
"WHEN MATCHED THEN UPDATE SET d.bonus = bonus\n" +
"\tDELETE WHERE salary > 8000",
result);
// Assert.assertTrue(visitor.getColumns().contains(new TableStat.Column("employees", "employee_id")));
// Assert.assertTrue(visitor.getColumns().contains(new TableStat.Column("employees", "salary")));
// Assert.assertTrue(visitor.getColumns().contains(new TableStat.Column("employees", "department_id")));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package com.alibaba.druid.bvt.sql.oracle.issues;

import com.alibaba.druid.DbType;
import com.alibaba.druid.bvt.sql.mysql.issues.Issue5421;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.parser.SQLParserUtils;
import com.alibaba.druid.sql.parser.SQLStatementParser;

import org.junit.Test;

import static org.junit.Assert.assertTrue;

/**
* 验证 Oracle merge sql的顺序问题
*
* @author lizongbo
* @see <a href="https://github.com/alibaba/druid/issues/5631">Issue来源</a>
*/
public class Issue5631 {

@Test
public void test_merge_into() throws Exception {
for (DbType dbType : new DbType[]{DbType.oracle}) {
for (String sql : new String[]{
"MERGE INTO target_table\n"
+ "USING source_table ON (target_table.id = source_table.id)\n"
+ "WHEN NOT MATCHED THEN INSERT (id, column1) VALUES (source_table.id, source_table.column1)\n"
+ "WHEN MATCHED THEN UPDATE SET target_table.column1 = source_table.column1",
"MERGE INTO target_table\n"
+ "USING source_table ON (target_table.id = source_table.id)\n"
+ "WHEN MATCHED THEN UPDATE SET target_table.column1 = source_table.column1\n"
+ "WHEN NOT MATCHED THEN INSERT (id, column1) VALUES (source_table.id, source_table.column1)",
}) {
SQLStatementParser parser = SQLParserUtils.createSQLStatementParser(sql, dbType);
SQLStatement statement = parser.parseStatement();
System.out.println(dbType + "原始的sql===" + sql);
System.out.println(dbType + "原始sql归一化===" + Issue5421.normalizeSql(sql));
String newSql = statement.toString();
System.out.println(dbType + "初次解析生成的sql===" + newSql);
System.out.println(dbType + "初次解析生成的sql归一化===" + Issue5421.normalizeSql(newSql));
parser = SQLParserUtils.createSQLStatementParser(newSql, dbType);
statement = parser.parseStatement();
System.out.println(dbType + "重新解析sql归一化===" + Issue5421.normalizeSql(statement.toString()));
assertTrue(Issue5421.normalizeSql(sql).equalsIgnoreCase(Issue5421.normalizeSql(statement.toString())));
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ public void test_0() throws Exception {
SQLStatement stmt = statementList.get(0);

assertEquals("MERGE INTO CustomerAccount CA\n" +
"USING (\n" +
"\tSELECT CustomerId, TransactionValue\n" +
"\tFROM RecentTransactions\n" +
") T ON (CA.CustomerId = T.CustomerId) \n" +
"WHEN MATCHED THEN UPDATE SET Balance = Balance + TransactionValue\n" +
"WHEN NOT MATCHED THEN INSERT (CustomerId, Balance) VALUES (T.CustomerId, T.TransactionValue);", stmt.toString());
"USING (\n" +
"\tSELECT CustomerId, TransactionValue\n" +
"\tFROM RecentTransactions\n" +
") T ON (CA.CustomerId = T.CustomerId) \n" +
"WHEN NOT MATCHED THEN INSERT (CustomerId, Balance) VALUES (T.CustomerId, T.TransactionValue)\n" +
"WHEN MATCHED THEN UPDATE SET Balance = Balance + TransactionValue;" , stmt.toString());

assertEquals(1, statementList.size());

Expand Down

0 comments on commit b2558ae

Please sign in to comment.