Skip to content

Commit

Permalink
增加支持PostgreSQL的TABLESAMPLE语法解析 alibaba#5844
Browse files Browse the repository at this point in the history
增加支持PostgreSQL的TABLESAMPLE语法解析 alibaba#5844
  • Loading branch information
lizongbo committed May 3, 2024
1 parent b5d8cba commit ff64bab
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@

import com.alibaba.druid.sql.ast.*;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.expr.SQLSizeExpr;
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.ast.statement.SQLSelectQuery;
import com.alibaba.druid.sql.ast.statement.SQLSelectQueryBlock;
import com.alibaba.druid.sql.ast.statement.SQLTableSampling;
import com.alibaba.druid.sql.ast.statement.SQLTableSource;
import com.alibaba.druid.sql.dialect.postgresql.ast.stmt.PGFunctionTableSource;
import com.alibaba.druid.sql.dialect.postgresql.ast.stmt.PGSelectQueryBlock;
Expand Down Expand Up @@ -272,6 +274,99 @@ public SQLTableSource parseTableSourceRest(SQLTableSource tableSource) {
}
}

return this.parseTableSourceTableSample(tableSource);
}

/**
 * Parses an optional {@code TABLESAMPLE} clause that follows a table reference,
 * e.g. {@code TABLESAMPLE SYSTEM (5)}, {@code TABLESAMPLE BERNOULLI (0.01)},
 * or the Hive-style {@code TABLESAMPLE (BUCKET x OUT OF y [ON expr])}.
 *
 * @param tableSource the table source parsed so far; sampling info is attached
 *                    to it when it is a {@link SQLExprTableSource}
 * @return the table source after any remaining table-source suffixes are parsed
 */
public SQLTableSource parseTableSourceTableSample(SQLTableSource tableSource) {
    if (lexer.identifierEquals(FnvHash.Constants.TABLESAMPLE) && tableSource instanceof SQLExprTableSource) {
        // Remember the position so we can back out if no '(' follows TABLESAMPLE.
        Lexer.SavePoint mark = lexer.mark();
        lexer.nextToken();
        SQLTableSampling sampling = new SQLTableSampling();

        // Optional sampling method keyword.
        if (lexer.identifierEquals(FnvHash.Constants.BERNOULLI)) {
            lexer.nextToken();
            sampling.setBernoulli(true);
        } else if (lexer.identifierEquals(FnvHash.Constants.SYSTEM)) {
            lexer.nextToken();
            sampling.setSystem(true);
        }

        if (lexer.token() == Token.LPAREN) {
            lexer.nextToken();

            // Hive-style: BUCKET x OUT OF y [ON expr]
            if (lexer.identifierEquals(FnvHash.Constants.BUCKET)) {
                lexer.nextToken();
                SQLExpr bucket = this.exprParser.primary();
                sampling.setBucket(bucket);

                if (lexer.token() == Token.OUT) {
                    lexer.nextToken();
                    accept(Token.OF);
                    SQLExpr outOf = this.exprParser.primary();
                    sampling.setOutOf(outOf);
                }

                if (lexer.token() == Token.ON) {
                    lexer.nextToken();
                    SQLExpr on = this.exprParser.expr();
                    sampling.setOn(on);
                }
            }

            // Numeric argument: "<n> ROWS", "<n> PERCENT", or a bare "<n>"
            // (PostgreSQL's SYSTEM(n)/BERNOULLI(n)) which is stored as rows.
            if (lexer.token() == Token.LITERAL_INT || lexer.token() == Token.LITERAL_FLOAT) {
                SQLExpr val = this.exprParser.primary();

                if (lexer.identifierEquals(FnvHash.Constants.ROWS)) {
                    lexer.nextToken();
                    sampling.setRows(val);
                } else if (lexer.token() == Token.RPAREN) {
                    sampling.setRows(val);
                } else {
                    acceptIdentifier("PERCENT");
                    sampling.setPercent(val);
                }
            }

            // Size literal such as "100M": digits (or a leading '.') followed
            // by a unit suffix B/K/M/G/T/P.
            if (lexer.token() == Token.IDENTIFIER) {
                String strVal = lexer.stringVal();
                char first = strVal.charAt(0);
                char last = strVal.charAt(strVal.length() - 1);
                if (last >= 'a' && last <= 'z') {
                    last -= 32; // normalize the unit suffix to upper case
                }

                boolean match = false;
                if ((first == '.' || (first >= '0' && first <= '9'))) {
                    switch (last) {
                        case 'B':
                        case 'K':
                        case 'M':
                        case 'G':
                        case 'T':
                        case 'P':
                            match = true;
                            break;
                        default:
                            break;
                    }
                }

                // Fix: 'match' was computed but never consulted, so any
                // identifier here was recorded as a byte length. Only treat
                // the token as a size when it actually looks like one.
                if (match) {
                    // NOTE(review): length() - 2 assumes a two-character
                    // suffix (e.g. "MB"); for "100M" it drops a digit —
                    // TODO confirm the intended literal format.
                    SQLSizeExpr size = new SQLSizeExpr(strVal.substring(0, strVal.length() - 2), last);
                    sampling.setByteLength(size);
                }
                lexer.nextToken();
            }

            final SQLExprTableSource table = (SQLExprTableSource) tableSource;
            table.setSampling(sampling);

            accept(Token.RPAREN);
        } else {
            // No '(' after TABLESAMPLE: not a sampling clause, rewind.
            lexer.reset(mark);
        }
    }

    // USING starts a different clause; let the caller handle it.
    if (lexer.identifierEquals(FnvHash.Constants.USING)) {
        return tableSource;
    }
    return super.parseTableSourceRest(tableSource);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10519,7 +10519,7 @@ public boolean visit(SQLTableSampling x) {
if (rows != null) {
rows.accept(this);

if (dbType != DbType.mysql) {
if (!JdbcUtils.isMysqlDbType(dbType) && !JdbcUtils.isPgsqlDbType(dbType)) {
print0(ucase ? " ROWS" : " rows");
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package com.alibaba.druid.bvt.sql.postgresql.issues;

import java.util.List;

import com.alibaba.druid.DbType;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.parser.SQLParserUtils;
import com.alibaba.druid.sql.parser.SQLStatementParser;

import org.junit.Test;

import static org.junit.Assert.assertEquals;

/**
 * PostgreSQL TABLESAMPLE parsing regression test.
 *
 * <p>Each SQL statement is parsed, regenerated from the AST, then the
 * regenerated SQL is parsed and regenerated again; both regenerated forms
 * must be identical (round-trip stability).
 *
 * @author lizongbo
 * @see <a href="https://github.com/alibaba/druid/issues/5844">Issue来源</a>
 * @see <a href="https://www.postgresql.org/docs/current/tsm-system-rows.html"> the SYSTEM_ROWS sampling method for TABLESAMPLE </a>
 * @see <a href="https://www.postgresql.org/docs/current/sql-select.html#SQL-TABLESAMPLE">TABLESAMPLE</a>
 */
public class Issue5844 {

    @Test
    public void test_parse_postgresql_tablesample() {
        for (String sql : new String[]{
            "SELECT * FROM app_qxx_zh TABLESAMPLE SYSTEM ( 5 )\n"
                + "WHERE random( ) < 0.01\n"
                + "ORDER BY show_count LIMIT 20",
            "SELECT * FROM app_qxx_zh TABLESAMPLE BERNOULLI ( 0.01 )\n"
                + "WHERE random( ) < 0.01\n"
                + "ORDER BY show_count LIMIT 20",
        }) {
            // First pass: original SQL -> AST -> regenerated SQL.
            SQLStatementParser parser1 = SQLParserUtils.createSQLStatementParser(sql, DbType.postgresql);
            List<SQLStatement> statementList1 = parser1.parseStatementList();
            String regenerated = statementList1.get(0).toString();

            // Second pass: the regenerated SQL must round-trip to itself.
            SQLStatementParser parser2 = SQLParserUtils.createSQLStatementParser(regenerated, DbType.postgresql);
            List<SQLStatement> statementList2 = parser2.parseStatementList();
            String regenerated2 = statementList2.get(0).toString();
            assertEquals(regenerated, regenerated2);
        }
    }
}

0 comments on commit ff64bab

Please sign in to comment.