From 0e4a3d8cad251162973283c986c164e48c2e384b Mon Sep 17 00:00:00 2001 From: LiZongbo Date: Sun, 5 May 2024 19:57:17 +0800 Subject: [PATCH] =?UTF-8?q?=E9=92=88=E5=AF=B9mysql=E5=A2=9E=E5=8A=A0with?= =?UTF-8?q?=E5=AD=90=E6=9F=A5=E8=AF=A2=E7=9A=84=E8=A7=A3=E6=9E=90=E6=94=AF?= =?UTF-8?q?=E6=8C=81=20#5761?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 针对mysql增加with子查询的解析支持 #5761 --- .../druid/sql/ast/expr/SQLSelectExpr.java | 114 ++++++++ .../druid/sql/parser/SQLExprParser.java | 3 - .../druid/sql/parser/SQLSelectParser.java | 27 ++ .../sql/visitor/SQLASTOutputVisitor.java | 17 +- .../druid/sql/visitor/SQLASTVisitor.java | 7 + .../druid/bvt/sql/mysql/issues/Issue5761.java | 275 ++++++++++++++++++ .../druid/bvt/sql/mysql/issues/Issue5797.java | 1 - .../druid/bvt/sql/mysql/issues/Issue5803.java | 1 - .../druid/bvt/sql/odps/issues/Issue5791.java | 1 - 9 files changed, 439 insertions(+), 7 deletions(-) create mode 100644 core/src/main/java/com/alibaba/druid/sql/ast/expr/SQLSelectExpr.java create mode 100644 core/src/test/java/com/alibaba/druid/bvt/sql/mysql/issues/Issue5761.java diff --git a/core/src/main/java/com/alibaba/druid/sql/ast/expr/SQLSelectExpr.java b/core/src/main/java/com/alibaba/druid/sql/ast/expr/SQLSelectExpr.java new file mode 100644 index 0000000000..a2357db06c --- /dev/null +++ b/core/src/main/java/com/alibaba/druid/sql/ast/expr/SQLSelectExpr.java @@ -0,0 +1,114 @@ +/* + * Copyright 1999-2024 Alibaba Group Holding Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.druid.sql.ast.expr; + +import com.alibaba.druid.sql.ast.*; +import com.alibaba.druid.sql.ast.statement.SQLSelect; +import com.alibaba.druid.sql.visitor.SQLASTVisitor; + +import java.io.Serializable; + +/** + * @author lizongbo + */ +public class SQLSelectExpr extends SQLExprImpl implements SQLReplaceable, Serializable { + private static final long serialVersionUID = 1L; + protected SQLSelect sqlSelect; + protected SQLExpr as; + + public SQLSelectExpr() { + } + + public SQLSelectExpr(SQLSelect sqlSelect) { + this.sqlSelect = sqlSelect; + } + + public SQLSelect getSqlSelect() { + return sqlSelect; + } + + public void setSqlSelect(SQLSelect sqlSelect) { + this.sqlSelect = sqlSelect; + } + + public SQLExpr getAs() { + return as; + } + + public void setAs(SQLExpr as) { + this.as = as; + } + + public SQLSelectExpr clone() { + SQLSelectExpr x = new SQLSelectExpr(); + if (sqlSelect != null) { + x.setSqlSelect(sqlSelect.clone()); + } + if (as != null) { + x.setAs(as.clone()); + } + x.setParenthesized(parenthesized); + return x; + } + + protected void accept0(SQLASTVisitor visitor) { + if (visitor.visit(this)) { + if (this.sqlSelect != null) { + this.sqlSelect.accept(visitor); + } + } + visitor.endVisit(this); + } + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((sqlSelect == null) ? 0 : sqlSelect.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + SQLSelectExpr other = (SQLSelectExpr) obj; + if (sqlSelect == null) { + if (other.sqlSelect != null) { + return false; + } + } else if (!sqlSelect.equals(other.sqlSelect)) { + return false; + } + return true; + } + + @Override + public SQLDataType computeDataType() { + return SQLBooleanExpr.DATA_TYPE; + } + + @Override + public boolean replace(SQLExpr expr, SQLExpr target) { + return false; + } +} diff --git a/core/src/main/java/com/alibaba/druid/sql/parser/SQLExprParser.java b/core/src/main/java/com/alibaba/druid/sql/parser/SQLExprParser.java index c0c84ceb9a..ee44eb107f 100644 --- a/core/src/main/java/com/alibaba/druid/sql/parser/SQLExprParser.java +++ b/core/src/main/java/com/alibaba/druid/sql/parser/SQLExprParser.java @@ -395,7 +395,6 @@ public SQLExpr primary() { sqlExpr = new SQLMethodInvokeExpr(); break; } - sqlExpr = expr(); if (lexer.token == Token.COMMA) { @@ -5912,7 +5911,6 @@ public SQLSelectItem parseSelectItem() { boolean connectByRoot = false; Token token = lexer.token; int startPos = lexer.startPos; - if (token == Token.IDENTIFIER && !(lexer.hashLCase() == -5808529385363204345L && lexer.charAt(lexer.pos) == '\'' && dbType == DbType.mysql) // x'123' X'123' ) { @@ -6474,7 +6472,6 @@ public SQLSelectItem parseSelectItem() { while (lexer.token == Token.HINT) { lexer.nextToken(); } - expr = expr(); } diff --git a/core/src/main/java/com/alibaba/druid/sql/parser/SQLSelectParser.java b/core/src/main/java/com/alibaba/druid/sql/parser/SQLSelectParser.java index 5a4a3e2bc6..e6f014ee0e 100644 --- a/core/src/main/java/com/alibaba/druid/sql/parser/SQLSelectParser.java +++ b/core/src/main/java/com/alibaba/druid/sql/parser/SQLSelectParser.java @@ -24,6 +24,7 @@ import com.alibaba.druid.sql.dialect.hive.parser.HiveCreateTableParser; import com.alibaba.druid.sql.dialect.hive.stmt.HiveCreateTableStatement; import com.alibaba.druid.sql.dialect.mysql.ast.expr.MySqlOrderingExpr; +import com.alibaba.druid.sql.parser.Lexer.SavePoint; import com.alibaba.druid.util.FnvHash; import com.alibaba.druid.util.StringUtils; @@ -1074,6 +1075,32 @@ protected SQLExpr parseGroupByItem() { protected void parseSelectList(SQLSelectQueryBlock queryBlock) { final List selectList = queryBlock.getSelectList(); for (; ; ) { + SavePoint savePoint = lexer.markOut(); + if (lexer.token() == Token.LPAREN) { + lexer.nextToken(); + if (lexer.token() == Token.WITH) { + String alias = null; + SQLSelect select = select(); + SQLSelectExpr sqlSelectExpr = new SQLSelectExpr(select); + sqlSelectExpr.setParenthesized(true); + SQLSelectItem selectItem = new SQLSelectItem(sqlSelectExpr, alias, false); + selectList.add(selectItem); + selectItem.setParent(queryBlock); + accept(Token.RPAREN); + if (lexer.token() == Token.AS) { + accept(Token.AS); + sqlSelectExpr.setAs(new SQLIdentifierExpr(lexer.stringVal())); + lexer.nextToken(); + if (lexer.token != Token.COMMA) { + break; + } + lexer.nextToken(); + } + } else { + lexer.reset(savePoint); + } + } + final SQLSelectItem selectItem = this.exprParser.parseSelectItem(); selectList.add(selectItem); selectItem.setParent(queryBlock); diff --git a/core/src/main/java/com/alibaba/druid/sql/visitor/SQLASTOutputVisitor.java b/core/src/main/java/com/alibaba/druid/sql/visitor/SQLASTOutputVisitor.java index 77c21f33a8..c5b1ebb717 100644 --- a/core/src/main/java/com/alibaba/druid/sql/visitor/SQLASTOutputVisitor.java +++ b/core/src/main/java/com/alibaba/druid/sql/visitor/SQLASTOutputVisitor.java @@ -1152,11 +1152,26 @@ protected final void printExpr(SQLExpr x, boolean parameterized) { visit((SQLInListExpr) x); } else if (clazz == SQLNotExpr.class) { visit((SQLNotExpr) x); + } else if (clazz == SQLSelectExpr.class) { + visit((SQLSelectExpr) x); } else { x.accept(this); } } - + public boolean visit(SQLSelectExpr x) { + if (x.isParenthesized()) { + print('('); + } + visit(x.getSqlSelect()); + if (x.isParenthesized()) { + print(')'); + } + if (x.getAs() != null) { + print0(ucase ? " AS " : " as "); + printExpr(x.getAs()); + } + return false; + } public boolean visit(SQLCaseExpr x) { if (x.isParenthesized()) { print('('); diff --git a/core/src/main/java/com/alibaba/druid/sql/visitor/SQLASTVisitor.java b/core/src/main/java/com/alibaba/druid/sql/visitor/SQLASTVisitor.java index 2b1a204687..8bc7db7ea6 100644 --- a/core/src/main/java/com/alibaba/druid/sql/visitor/SQLASTVisitor.java +++ b/core/src/main/java/com/alibaba/druid/sql/visitor/SQLASTVisitor.java @@ -2564,4 +2564,11 @@ default boolean visit(SQLCostStatement x) { default void endVisit(SQLCostStatement x) { } + default boolean visit(SQLSelectExpr x) { + return true; + } + + default void endVisit(SQLSelectExpr x) { + } + } diff --git a/core/src/test/java/com/alibaba/druid/bvt/sql/mysql/issues/Issue5761.java b/core/src/test/java/com/alibaba/druid/bvt/sql/mysql/issues/Issue5761.java new file mode 100644 index 0000000000..00700dc69e --- /dev/null +++ b/core/src/test/java/com/alibaba/druid/bvt/sql/mysql/issues/Issue5761.java @@ -0,0 +1,275 @@ +package com.alibaba.druid.bvt.sql.mysql.issues; + +import java.util.List; + +import com.alibaba.druid.DbType; +import com.alibaba.druid.sql.ast.SQLStatement; +import com.alibaba.druid.sql.parser.SQLParserUtils; +import com.alibaba.druid.sql.parser.SQLStatementParser; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author lizongbo + * @see WITH (Common Table Expressions) + */ +public class Issue5761 { + + @Test + public void test_parse_with() { + for (DbType dbType : new DbType[]{ + DbType.mysql, + DbType.mariadb, + + }) { + + for (String sql : new String[]{ + "WITH\n" + + " cte1 AS (SELECT a, b FROM table1),\n" + + " cte2 AS (SELECT c, d FROM table2)\n" + + "SELECT b, d FROM cte1 JOIN cte2\n" + + "WHERE cte1.a = cte2.c;", + "WITH cte (col1, col2) AS\n" + + "(\n" + + " SELECT 1, 2\n" + + " UNION ALL\n" + + " SELECT 3, 4\n" + + ")\n" + + "SELECT col1, col2 FROM cte;", + "WITH cte AS\n" + + "(\n" + + " SELECT 1 AS col1, 2 AS col2\n" + + " UNION ALL\n" + + " SELECT 3, 4\n" + + ")\n" + + "SELECT col1, col2 FROM cte;", + "WITH cte1 AS (SELECT 1)\n" + + "SELECT * FROM (WITH cte2 AS (SELECT 2) SELECT * FROM cte2 JOIN cte1) AS dt;", + "WITH RECURSIVE cte (n) AS\n" + + "(\n" + + " SELECT 1\n" + + " UNION ALL\n" + + " SELECT n + 1 FROM cte WHERE n < 5\n" + + ")\n" + + "SELECT * FROM cte;", + "WITH RECURSIVE cte AS\n" + + "(\n" + + " SELECT 1 AS n, 'abc' AS str\n" + + " UNION ALL\n" + + " SELECT n + 1, CONCAT(str, str) FROM cte WHERE n < 3\n" + + ")\n" + + "SELECT * FROM cte;", + "WITH RECURSIVE cte AS\n" + + "(\n" + + " SELECT 1 AS n, CAST('abc' AS CHAR(20)) AS str\n" + + " UNION ALL\n" + + " SELECT n + 1, CONCAT(str, str) FROM cte WHERE n < 3\n" + + ")\n" + + "SELECT * FROM cte;", + "WITH RECURSIVE cte AS\n" + + "(\n" + + " SELECT 1 AS n, 1 AS p, -1 AS q\n" + + " UNION ALL\n" + + " SELECT n + 1, q * 2, p * 2 FROM cte WHERE n < 5\n" + + ")\n" + + "SELECT * FROM cte;", + "WITH RECURSIVE cte (n) AS\n" + + "(\n" + + " SELECT 1\n" + + " UNION ALL\n" + + " SELECT n + 1 FROM cte\n" + + ")\n" + + "SELECT * FROM cte;", + "WITH RECURSIVE cte (n) AS\n" + + "(\n" + + " SELECT 1\n" + + " UNION ALL\n" + + " SELECT n + 1 FROM cte\n" + + ")\n" + + "SELECT /*+ SET_VAR(cte_max_recursion_depth = 1M) */ * FROM cte;", + "WITH RECURSIVE cte (n) AS\n" + + "(\n" + + " SELECT 1\n" + + " UNION ALL\n" + + " SELECT n + 1 FROM cte\n" + + ")\n" + + "SELECT /*+ MAX_EXECUTION_TIME(1000) */ * FROM cte;", + "WITH RECURSIVE cte (n) AS\n" + + "(\n" + + " SELECT 1\n" + + " UNION ALL\n" + + " SELECT n + 1 FROM cte LIMIT 10000\n" + + ")\n" + + "SELECT * FROM cte;", + "WITH RECURSIVE cte (n) AS\n" + + "(\n" + + " SELECT 1\n" + + " UNION ALL\n" + + " SELECT n + 1 FROM cte LIMIT 10000\n" + + ")\n" + + "SELECT /*+ MAX_EXECUTION_TIME(1000) */ * FROM cte;", + "WITH RECURSIVE fibonacci (n, fib_n, next_fib_n) AS\n" + + "(\n" + + " SELECT 1, 0, 1\n" + + " UNION ALL\n" + + " SELECT n + 1, next_fib_n, fib_n + next_fib_n\n" + + " FROM fibonacci WHERE n < 10\n" + + ")\n" + + "SELECT * FROM fibonacci;", + "WITH RECURSIVE dates (date) AS\n" + + "(\n" + + " SELECT MIN(date) FROM sales\n" + + " UNION ALL\n" + + " SELECT date + INTERVAL 1 DAY FROM dates\n" + + " WHERE date + INTERVAL 1 DAY <= (SELECT MAX(date) FROM sales)\n" + + ")\n" + + "SELECT * FROM dates;", + "WITH RECURSIVE dates (date) AS\n" + + "(\n" + + " SELECT MIN(date) FROM sales\n" + + " UNION ALL\n" + + " SELECT date + INTERVAL 1 DAY FROM dates\n" + + " WHERE date + INTERVAL 1 DAY <= (SELECT MAX(date) FROM sales)\n" + + ")\n" + + "SELECT dates.date, COALESCE(SUM(price), 0) AS sum_price\n" + + "FROM dates LEFT JOIN sales ON dates.date = sales.date\n" + + "GROUP BY dates.date\n" + + "ORDER BY dates.date;", + "WITH RECURSIVE employee_paths (id, name, path) AS\n" + + "(\n" + + " SELECT id, name, CAST(id AS CHAR(200))\n" + + " FROM employees\n" + + " WHERE manager_id IS NULL\n" + + " UNION ALL\n" + + " SELECT e.id, e.name, CONCAT(ep.path, ',', e.id)\n" + + " FROM employee_paths AS ep JOIN employees AS e\n" + + " ON ep.id = e.manager_id\n" + + ")\n" + + "SELECT * FROM employee_paths ORDER BY path;", + "WITH cte AS (SELECT 1) SELECT * FROM cte;", + "WITH RECURSIVE cte (n) AS\n" + + "(\n" + + " SELECT 1\n" + + " UNION ALL\n" + + " SELECT n + 1 FROM cte WHERE n < 5\n" + + ")\n" + + "SELECT * FROM cte;", + "select\n" + + " id,\n" + + " (aa),\n" + + " (bb) cc,\n" + + " (\n" + + " WITH RECURSIVE link_hierarchy AS (\n" + + " SELECT id, parent_id\n" + + " FROM tmp_link\n" + + " WHERE id = ?\n" + + "\n" + + " UNION ALL\n" + + "\n" + + " SELECT tl.id, tl.parent_id\n" + + " FROM tmp_link tl\n" + + " INNER JOIN link_hierarchy lh ON tl.id = lh.parent_id\n" + + " )\n" + + " SELECT CONCAT('/', GROUP_CONCAT(id ORDER BY id ASC SEPARATOR '/')) AS pathaaa\n" + + " FROM link_hierarchy\n" + + " ) as pathbbb , qqqq\n" + + " from tmp_link;", + + "select\n" + + " (\n" + + " WITH RECURSIVE link_hierarchy AS (\n" + + " SELECT id, parent_id\n" + + " FROM tmp_link\n" + + " WHERE id = ?\n" + + "\n" + + " UNION ALL\n" + + "\n" + + " SELECT tl.id, tl.parent_id\n" + + " FROM tmp_link tl\n" + + " INNER JOIN link_hierarchy lh ON tl.id = lh.parent_id\n" + + " )\n" + + " SELECT CONCAT('/', GROUP_CONCAT(id ORDER BY id ASC SEPARATOR '/')) AS pathaaa\n" + + " FROM link_hierarchy\n" + + " ) as pathbbb \n" + + " from tmp_link;", + + "select\n" + + " (\n" + + " WITH RECURSIVE link_hierarchy AS (\n" + + " SELECT id, parent_id\n" + + " FROM tmp_link\n" + + " WHERE id = ?\n" + + "\n" + + " UNION ALL\n" + + "\n" + + " SELECT tl.id, tl.parent_id\n" + + " FROM tmp_link tl\n" + + " INNER JOIN link_hierarchy lh ON tl.id = lh.parent_id\n" + + " )\n" + + " SELECT CONCAT('/', GROUP_CONCAT(id ORDER BY id ASC SEPARATOR '/')) AS pathaaa\n" + + " FROM link_hierarchy\n" + + " ) as pathbbb , qqqq\n" + + " from tmp_link;", + + "select\n" + + " (\n" + + " WITH link_hierarchy AS (\n" + + " SELECT id, parent_id\n" + + " FROM tmp_link\n" + + " WHERE id = ?\n" + + "\n" + + " UNION ALL\n" + + "\n" + + " SELECT tl.id, tl.parent_id\n" + + " FROM tmp_link tl\n" + + " INNER JOIN link_hierarchy lh ON tl.id = lh.parent_id\n" + + " )\n" + + " SELECT CONCAT('/', GROUP_CONCAT(id ORDER BY id ASC SEPARATOR '/')) AS pathaaa\n" + + " FROM link_hierarchy\n" + + " ) as pathbbb , qwerty\n" + + " from tmp_link;", + + + "select\n" + + " id,\n" + + " (aa),\n" + + " (bb) cc,\n" + + " (\n" + + " WITH RECURSIVE link_hierarchy AS (\n" + + " SELECT id, parent_id\n" + + " FROM tmp_link\n" + + " WHERE id = ?\n" + + "\n" + + " UNION ALL\n" + + "\n" + + " SELECT tl.id, tl.parent_id\n" + + " FROM tmp_link tl\n" + + " INNER JOIN link_hierarchy lh ON tl.id = lh.parent_id\n" + + " )\n" + + " SELECT CONCAT('/', GROUP_CONCAT(id ORDER BY id ASC SEPARATOR '/')) AS pathaaa\n" + + " FROM link_hierarchy\n" + + " ) as pathbbb\n" + + " from tmp_link;", + }) { + System.out.println(dbType + "原始的sql===" + sql); + SQLStatementParser parser = SQLParserUtils.createSQLStatementParser(sql, dbType); + List statementList = parser.parseStatementList(); + String sqlGen = statementList.toString(); + System.out.println(dbType + "首次解析生成的sql===" + sqlGen); + StringBuilder sb = new StringBuilder(); + for (SQLStatement statement : statementList) { + sb.append(statement.toString()).append(";"); + } + sb.deleteCharAt(sb.length() - 1); + parser = SQLParserUtils.createSQLStatementParser(sb.toString(), dbType); + List statementListNew = parser.parseStatementList(); + String sqlGenNew = statementList.toString(); + System.out.println(dbType + "再次解析生成的sql===" + sqlGenNew); + assertEquals(statementList.toString(), statementListNew.toString()); + } + } + } +} diff --git a/core/src/test/java/com/alibaba/druid/bvt/sql/mysql/issues/Issue5797.java b/core/src/test/java/com/alibaba/druid/bvt/sql/mysql/issues/Issue5797.java index 0ce16d9cb4..4882977d91 100644 --- a/core/src/test/java/com/alibaba/druid/bvt/sql/mysql/issues/Issue5797.java +++ b/core/src/test/java/com/alibaba/druid/bvt/sql/mysql/issues/Issue5797.java @@ -43,7 +43,6 @@ public void test_parse_create_table() { System.out.println(dbType + "原始的sql===" + sql); SQLStatementParser parser = SQLParserUtils.createSQLStatementParser(sql, dbType); List statementList = parser.parseStatementList(); - com.alibaba.druid.sql.ast.statement.SQLJoinTableSource ggg; String sqlGen = statementList.toString(); System.out.println(dbType + "生成的sql===" + sqlGen); StringBuilder sb = new StringBuilder(); diff --git a/core/src/test/java/com/alibaba/druid/bvt/sql/mysql/issues/Issue5803.java b/core/src/test/java/com/alibaba/druid/bvt/sql/mysql/issues/Issue5803.java index 876650b7cb..607917b43e 100644 --- a/core/src/test/java/com/alibaba/druid/bvt/sql/mysql/issues/Issue5803.java +++ b/core/src/test/java/com/alibaba/druid/bvt/sql/mysql/issues/Issue5803.java @@ -38,7 +38,6 @@ public void test_parse_alter_table() { System.out.println(dbType + "原始的sql===" + sql); SQLStatementParser parser = SQLParserUtils.createSQLStatementParser(sql, dbType); List statementList = parser.parseStatementList(); - com.alibaba.druid.sql.ast.statement.SQLJoinTableSource ggg; String sqlGen = statementList.toString(); System.out.println(dbType + "生成的sql===" + sqlGen); StringBuilder sb = new StringBuilder(); diff --git a/core/src/test/java/com/alibaba/druid/bvt/sql/odps/issues/Issue5791.java b/core/src/test/java/com/alibaba/druid/bvt/sql/odps/issues/Issue5791.java index 0ea3c7e095..5b5dfce8bf 100644 --- a/core/src/test/java/com/alibaba/druid/bvt/sql/odps/issues/Issue5791.java +++ b/core/src/test/java/com/alibaba/druid/bvt/sql/odps/issues/Issue5791.java @@ -45,7 +45,6 @@ public void test_parse_comment() { System.out.println(dbType + "原始的sql===" + sql); SQLStatementParser parser = SQLParserUtils.createSQLStatementParser(sql, dbType); List statementList = parser.parseStatementList(); - com.alibaba.druid.sql.ast.statement.SQLJoinTableSource ggg; String sqlGen = statementList.toString(); System.out.println(dbType + "生成的sql===" + sqlGen); assertTrue(sqlGen.contains(" C2-2"));