Skip to content

Commit

Permalink
修复MySQL create user的sql解析bug并增强部分解析逻辑 alibaba#5774
Browse files Browse the repository at this point in the history
先前解析是有误的,因此修复,并根据官方语法,增加了几种情况的识别解析
  • Loading branch information
lizongbo committed Mar 15, 2024
1 parent 7cf57c3 commit 2ada350
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ public void accept0(MySqlASTVisitor visitor) {
public static class UserSpecification extends MySqlObjectImpl {
private SQLExpr user;
private boolean passwordHash;
private boolean randomPassword;
private SQLExpr password;
private SQLExpr authPlugin;
private boolean pluginAs;
Expand All @@ -71,6 +72,14 @@ public void setUser(SQLExpr user) {
this.user = (SQLName) user;
}

public boolean isRandomPassword() {
return randomPassword;
}

public void setRandomPassword(boolean randomPassword) {
this.randomPassword = randomPassword;
}

public boolean isPasswordHash() {
return passwordHash;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -911,40 +911,41 @@ public SQLStatement parseCreateUser() {
stmt.setIfNotExists(true);
}

SQLExpr expr = exprParser.primary();
if (expr instanceof SQLCharExpr) {
expr = new SQLIdentifierExpr(((SQLCharExpr) expr).getText());
}

if (expr instanceof SQLIdentifierExpr
&& lexer.token() == Token.VARIANT
&& lexer.stringVal().charAt(0) == '@'
) {
String str = lexer.stringVal();
MySqlUserName mySqlUserName = new MySqlUserName();
mySqlUserName.setUserName(((SQLIdentifierExpr) expr).getName());
mySqlUserName.setHost(str.substring(1));
expr = mySqlUserName;
MySqlUserName mySqlUserName = new MySqlUserName();
mySqlUserName.setUserName(trimQuotesBeginAndEnd(lexer.stringVal()));
lexer.nextToken();
String maybeHost=lexer.stringVal();
if ("@".equals(maybeHost)) {
lexer.nextToken();
mySqlUserName.setHost(trimQuotesBeginAndEnd(lexer.stringVal()));
lexer.nextToken();
} else if (maybeHost.startsWith("@")) { // eg: @localhost
maybeHost = maybeHost.substring(1);
mySqlUserName.setHost(trimQuotesBeginAndEnd(maybeHost));
lexer.nextToken();
}

userSpec.setUser(expr);
userSpec.setUser(mySqlUserName);

if (lexer.identifierEquals(FnvHash.Constants.IDENTIFIED)) {
lexer.nextToken();
if (lexer.token() == Token.BY) {
lexer.nextToken();

if (lexer.identifierEquals("PASSWORD")) {
if (lexer.identifierEquals("RANDOM")) {
lexer.nextToken();
userSpec.setPasswordHash(true);
}

SQLExpr password = this.exprParser.expr();
if (password instanceof SQLIdentifierExpr || password instanceof SQLCharExpr) {
userSpec.setPassword(password);
acceptIdentifier("PASSWORD");
userSpec.setRandomPassword(true);
} else {
throw new ParserException("syntax error. invalid " + password + " expression.");
if (lexer.identifierEquals("PASSWORD")) {
lexer.nextToken();
userSpec.setPasswordHash(true);
}
SQLExpr password = this.exprParser.expr();
if (password instanceof SQLIdentifierExpr || password instanceof SQLCharExpr) {
userSpec.setPassword(password);
} else {
throw new ParserException("syntax error. invalid " + password + " expression.");
}
}

} else if (lexer.token() == Token.WITH) {
Expand All @@ -961,11 +962,7 @@ public SQLStatement parseCreateUser() {
if (userSpec.isPluginAs()) {
// Remove ' because lexer don't remove it when token after as.
String psw = lexer.stringVal();
if (psw.length() >= 2 && '\'' == psw.charAt(0) && '\'' == psw.charAt(psw.length() - 1)) {
userSpec.setPassword(new SQLCharExpr(psw.substring(1, psw.length() - 1)));
} else {
userSpec.setPassword(new SQLCharExpr(psw));
}
userSpec.setPassword(new SQLCharExpr(trimQuotesBeginAndEnd(psw)));
lexer.nextToken();
} else {
userSpec.setPassword(this.exprParser.charExpr());
Expand All @@ -987,6 +984,17 @@ public SQLStatement parseCreateUser() {
return stmt;
}

static String trimQuotesBeginAndEnd(String str) {
if (str == null || str.length() < 2) {
return str;
}
char beginChar = str.charAt(0);
char endChar = str.charAt(str.length() - 1);
if ((beginChar == '\'' && endChar == '\'') || (beginChar == '\"' && endChar == '\"')) {
return str.substring(1, str.length() - 1);
}
return str;
}
public SQLStatement parseKill() {
accept(Token.KILL);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2012,6 +2012,11 @@ public boolean visit(UserSpecification x) {
print0(ucase ? "PASSWORD " : "password ");
}
x.getPassword().accept(this);
} else {
if (x.isRandomPassword()) {
print0(ucase ? " IDENTIFIED BY " : " identified by ");
print0(ucase ? "RANDOM PASSWORD " : "random password ");
}
}
return false;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package com.alibaba.druid.bvt.sql.mysql.issues;

import java.util.List;

import com.alibaba.druid.DbType;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.parser.SQLParserUtils;
import com.alibaba.druid.sql.parser.SQLStatementParser;

import org.junit.Test;

import static org.junit.Assert.assertEquals;

/**
* @author lizongbo
* @see <a href="https://github.com/alibaba/druid/issues/5774">Issue来源</a>
* @see <a href="https://dev.mysql.com/doc/refman/8.0/en/create-user.html">CREATE USER Statement</a>
* @see <a href="https://dev.mysql.com/doc/refman/8.0/en/account-names.html">Specifying Account Names</a>
*/
public class Issue5774 {


@Test
public void test_createuser_sql() {
for (String sql : new String[]{
"create user IF NOT EXISTS \"ptscr-2kaq\"@\"%\" identified by \"asdasdasdasd\";",
"create user IF NOT EXISTS \"ptscr-2kaq\" identified by \"asdasdasdasd\";",
"create user \"ptscr-2kaq\"@\"%\" identified by \"asdasdasdasd\";",
"create user \"ptscr-2kaq\"@\"%\" identified by RANDOM PASSWORD;",
"CREATE USER 'jeffrey'@'localhost' IDENTIFIED BY 'password';",
"CREATE USER 'jeffrey'@'localhost'\n"
+ " IDENTIFIED BY 'password';",
"CREATE USER 'jeffrey'@localhost IDENTIFIED BY 'password';",
// "CREATE USER 'jeffrey'@'localhost'\n"
// + " IDENTIFIED BY 'new_password' PASSWORD EXPIRE;",
"CREATE USER 'jeffrey'@'localhost'\n"
+ " IDENTIFIED WITH mysql_native_password BY 'password';",
// "CREATE USER 'u1'@'localhost'\n"
// + " IDENTIFIED WITH caching_sha2_password\n"
// + " BY 'sha2_password'\n"
// + " AND IDENTIFIED WITH authentication_ldap_sasl\n"
// + " AS 'uid=u1_ldap,ou=People,dc=example,dc=com';",
// "CREATE USER 'jeffrey'@'localhost' PASSWORD EXPIRE;",
// "CREATE USER 'jeffrey'@'localhost' PASSWORD EXPIRE DEFAULT;",
// "CREATE USER 'jeffrey'@'localhost' PASSWORD EXPIRE NEVER;",
}) {
DbType dbType = DbType.mysql;
System.out.println("原始的sql===" + sql);
SQLStatementParser parser = SQLParserUtils.createSQLStatementParser(sql, dbType);
List<SQLStatement> statementList = parser.parseStatementList();
System.out.println("生成的sql===" + statementList);
StringBuilder sb = new StringBuilder();
for (SQLStatement statement : statementList) {
sb.append(statement.toString()).append(";");
}
sb.deleteCharAt(sb.length() - 1);
parser = SQLParserUtils.createSQLStatementParser(sb.toString(), dbType);
List<SQLStatement> statementListNew = parser.parseStatementList();
System.out.println("再生成sql===" + statementListNew);
assertEquals(statementList.toString(), statementListNew.toString());
}
}
}

0 comments on commit 2ada350

Please sign in to comment.