From 4522bb824d600181317c5631103a9fc4b9ca1503 Mon Sep 17 00:00:00 2001 From: mcy Date: Sat, 2 Nov 2024 14:42:42 +0800 Subject: [PATCH] optimize code --- .../common/configuration/CheckResult.java | 6 +- .../ares/common/configuration/Options.java | 2 +- .../ares/common/exceptions/AresErrorCode.java | 4 +- .../ares/common/exceptions/CommonError.java | 43 +- .../ares/common/utils/PasswordUtil.java | 72 --- .../github/ares/common/utils/TimeUtils.java | 4 +- .../com/github/ares/common/utils/Tuple2.java | 20 +- .../java/com/github/ares/parser/PlParser.java | 4 +- .../parser/config/ParserServiceModule.java | 3 - .../github/ares/parser/model/TableWith.java | 2 +- .../ares/parser/plan/LogicalOperation.java | 2 +- .../sqlparser/sparksql/CommonParser.java | 53 ++ .../sqlparser/sparksql/CriteriaParser.java | 183 +++--- .../sqlparser/sparksql/DeleteSqlParser.java | 105 ++++ .../parser/sqlparser/sparksql/HintParser.java | 80 +++ .../sqlparser/sparksql/InsertSqlParser.java | 134 +++++ .../sqlparser/sparksql/MergeSqlParser.java | 201 +++++++ .../sqlparser/sparksql/SelectSqlParser.java | 68 +++ .../sqlparser/sparksql/SparkSqlParser.java | 533 +----------------- .../sqlparser/sparksql/UpdateSqlParser.java | 119 ++++ .../ares/parser/utils/PLParserUtil.java | 3 - .../ares/parser/visitor/PlBaseVisitor.java | 296 +++++----- .../ares/parser/visitor/PlBodyVisitor.java | 173 +++--- .../visitor/PlCallStatementVisitor.java | 152 +---- .../visitor/PlCreateTableWithVisitor.java | 4 +- .../parser/visitor/PlFunctionBodyVisitor.java | 129 +++-- .../github/ares/parser/test/PlParserTest.java | 102 +++- .../ares/spark/starter/SparkStarter.java | 27 +- .../ares/spark/starter/SparkStarter.java | 2 +- 29 files changed, 1375 insertions(+), 1151 deletions(-) delete mode 100644 ares-common/src/main/java/com/github/ares/common/utils/PasswordUtil.java create mode 100644 ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/CommonParser.java create mode 100644 ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/DeleteSqlParser.java create mode 100644 ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/HintParser.java create mode 100644 ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/InsertSqlParser.java create mode 100644 ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/MergeSqlParser.java create mode 100644 ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/SelectSqlParser.java create mode 100644 ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/UpdateSqlParser.java diff --git a/ares-common/src/main/java/com/github/ares/common/configuration/CheckResult.java b/ares-common/src/main/java/com/github/ares/common/configuration/CheckResult.java index 207185a..dc49716 100644 --- a/ares-common/src/main/java/com/github/ares/common/configuration/CheckResult.java +++ b/ares-common/src/main/java/com/github/ares/common/configuration/CheckResult.java @@ -7,12 +7,12 @@ public class CheckResult { private static final CheckResult SUCCESS = new CheckResult(true, ""); - private boolean success; + private boolean isSuccess; private String msg; - private CheckResult(boolean success, String msg) { - this.success = success; + private CheckResult(boolean isSuccess, String msg) { + this.isSuccess = isSuccess; this.msg = msg; } diff --git a/ares-common/src/main/java/com/github/ares/common/configuration/Options.java b/ares-common/src/main/java/com/github/ares/common/configuration/Options.java index 5756813..1053909 100644 --- a/ares-common/src/main/java/com/github/ares/common/configuration/Options.java +++ b/ares-common/src/main/java/com/github/ares/common/configuration/Options.java @@ -248,7 +248,7 @@ public static class SingleChoiceOptionBuilder { private final String key; private final TypeReference typeReference; - SingleChoiceOptionBuilder(String key, TypeReference typeReference, List optionValues) { + SingleChoiceOptionBuilder(String key, TypeReference typeReference, List optionValues) { this.optionValues = optionValues; this.key = key; this.typeReference = typeReference; diff --git a/ares-common/src/main/java/com/github/ares/common/exceptions/AresErrorCode.java b/ares-common/src/main/java/com/github/ares/common/exceptions/AresErrorCode.java index d944499..53c0403 100644 --- a/ares-common/src/main/java/com/github/ares/common/exceptions/AresErrorCode.java +++ b/ares-common/src/main/java/com/github/ares/common/exceptions/AresErrorCode.java @@ -17,8 +17,10 @@ package com.github.ares.common.exceptions; +import java.io.Serializable; + /** Ares connector error code interface */ -public interface AresErrorCode { +public interface AresErrorCode extends Serializable { /** * Get error code * diff --git a/ares-common/src/main/java/com/github/ares/common/exceptions/CommonError.java b/ares-common/src/main/java/com/github/ares/common/exceptions/CommonError.java index 9f90a38..0375326 100644 --- a/ares-common/src/main/java/com/github/ares/common/exceptions/CommonError.java +++ b/ares-common/src/main/java/com/github/ares/common/exceptions/CommonError.java @@ -35,33 +35,38 @@ import static com.github.ares.common.exceptions.CommonErrorCode.WRITE_ARES_ROW_ERROR; public class CommonError { + private static final String KEY_IDENTIFIER = "identifier"; + private static final String KEY_OPERATION = "operation"; + private static final String KEY_FILE_NAME = "fileName"; + private static final String KEY_DATA_TYPE = "dataType"; + private static final String KEY_FIELD = "field"; private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); public static AresRuntimeException fileOperationFailed( String identifier, String operation, String fileName, Throwable cause) { Map params = new HashMap<>(); - params.put("identifier", identifier); - params.put("operation", operation); - params.put("fileName", fileName); + params.put(KEY_IDENTIFIER, identifier); + params.put(KEY_OPERATION, operation); + params.put(KEY_FILE_NAME, fileName); return new AresRuntimeException(FILE_OPERATION_FAILED, params, cause); } public static AresRuntimeException fileOperationFailed( String identifier, String operation, String fileName) { Map params = new HashMap<>(); - params.put("identifier", identifier); - params.put("operation", operation); - params.put("fileName", fileName); + params.put(KEY_IDENTIFIER, identifier); + params.put(KEY_OPERATION, operation); + params.put(KEY_FILE_NAME, fileName); return new AresRuntimeException(FILE_OPERATION_FAILED, params); } public static AresRuntimeException fileNotExistFailed( String identifier, String operation, String fileName) { Map params = new HashMap<>(); - params.put("identifier", identifier); - params.put("operation", operation); - params.put("fileName", fileName); + params.put(KEY_IDENTIFIER, identifier); + params.put(KEY_OPERATION, operation); + params.put(KEY_FILE_NAME, fileName); return new AresRuntimeException(FILE_NOT_EXISTED, params); } @@ -76,18 +81,18 @@ public static AresRuntimeException writeAresRowFailed( public static AresRuntimeException unsupportedDataType( String identifier, String dataType, String field) { Map params = new HashMap<>(); - params.put("identifier", identifier); - params.put("dataType", dataType); - params.put("field", field); + params.put(KEY_IDENTIFIER, identifier); + params.put(KEY_DATA_TYPE, dataType); + params.put(KEY_FIELD, field); return new AresRuntimeException(UNSUPPORTED_DATA_TYPE, params); } public static AresRuntimeException convertToAresTypeError( String identifier, String dataType, String field) { Map params = new HashMap<>(); - params.put("identifier", identifier); - params.put("dataType", dataType); - params.put("field", field); + params.put(KEY_IDENTIFIER, identifier); + params.put(KEY_DATA_TYPE, dataType); + params.put(KEY_FIELD, field); return new AresRuntimeException(CONVERT_TO_ARES_TYPE_ERROR_SIMPLE, params); } @@ -96,9 +101,9 @@ public static AresRuntimeException convertToAresTypeError( public static AresRuntimeException convertToConnectorTypeError( String identifier, String dataType, String field) { Map params = new HashMap<>(); - params.put("identifier", identifier); - params.put("dataType", dataType); - params.put("field", field); + params.put(KEY_IDENTIFIER, identifier); + params.put(KEY_DATA_TYPE, dataType); + params.put(KEY_FIELD, field); return new AresRuntimeException(CONVERT_TO_CONNECTOR_TYPE_ERROR_SIMPLE, params); } @@ -137,7 +142,7 @@ public static AresRuntimeException jsonOperationError(String identifier, String public static AresRuntimeException jsonOperationError( String identifier, String payload, Throwable cause) { Map params = new HashMap<>(); - params.put("identifier", identifier); + params.put(KEY_IDENTIFIER, identifier); params.put("payload", payload); AresErrorCode code = JSON_OPERATION_FAILED; diff --git a/ares-common/src/main/java/com/github/ares/common/utils/PasswordUtil.java b/ares-common/src/main/java/com/github/ares/common/utils/PasswordUtil.java deleted file mode 100644 index 6059db7..0000000 --- a/ares-common/src/main/java/com/github/ares/common/utils/PasswordUtil.java +++ /dev/null @@ -1,72 +0,0 @@ -package com.github.ares.common.utils; - -import javax.crypto.Cipher; -import javax.crypto.SecretKeyFactory; -import javax.crypto.spec.DESKeySpec; -import javax.crypto.spec.IvParameterSpec; -import java.security.Key; -import java.util.Base64; - -public class PasswordUtil { - public static String DEFAULT_PASSWORDS = "23rfv(U*VSnwaf:fh"; - - private static final String IV_PARAMETER = "12345678"; - - private static final String ALGORITHM = "DES"; - - private static final String CIPHER_ALGORITHM = "DES/CBC/PKCS5Padding"; - - private static final String CHARSET = "utf-8"; - - public static String encrypt(String data) { - return encrypt(DEFAULT_PASSWORDS, data); - } - - public static String decrypt(String data) { - return decrypt(DEFAULT_PASSWORDS, data); - } - - private static Key generateKey(String password) throws Exception { - DESKeySpec dks = new DESKeySpec(password.getBytes(CHARSET)); - SecretKeyFactory keyFactory = SecretKeyFactory.getInstance(ALGORITHM); - return keyFactory.generateSecret(dks); - } - - public static String encrypt(String password, String data) { - if (password == null || password.length() < 8) { - throw new RuntimeException("Encrypt failed, password length must greater than 8"); - } - if (data == null) return null; - try { - Key secretKey = generateKey(password); - Cipher cipher = Cipher.getInstance(CIPHER_ALGORITHM); - IvParameterSpec iv = new IvParameterSpec(IV_PARAMETER.getBytes(CHARSET)); - cipher.init(Cipher.ENCRYPT_MODE, secretKey, iv); - byte[] bytes = cipher.doFinal(data.getBytes(CHARSET)); - - return new String(Base64.getEncoder().encode(bytes)); - - } catch (Exception e) { - e.printStackTrace(); - return data; - } - } - - public static String decrypt(String password, String data) { - if (password == null || password.length() < 8) { - throw new RuntimeException("Encrypt failed, password length must greater than 8"); - } - if (data == null) return null; - try { - Key secretKey = generateKey(password); - Cipher cipher = Cipher.getInstance(CIPHER_ALGORITHM); - IvParameterSpec iv = new IvParameterSpec(IV_PARAMETER.getBytes(CHARSET)); - cipher.init(Cipher.DECRYPT_MODE, secretKey, iv); - return new String( - cipher.doFinal(Base64.getDecoder().decode(data.getBytes(CHARSET))), CHARSET); - } catch (Exception e) { - e.printStackTrace(); - return data; - } - } -} diff --git a/ares-common/src/main/java/com/github/ares/common/utils/TimeUtils.java b/ares-common/src/main/java/com/github/ares/common/utils/TimeUtils.java index c078aa2..c80ac9d 100644 --- a/ares-common/src/main/java/com/github/ares/common/utils/TimeUtils.java +++ b/ares-common/src/main/java/com/github/ares/common/utils/TimeUtils.java @@ -2,12 +2,12 @@ import java.time.LocalTime; import java.time.format.DateTimeFormatter; +import java.util.EnumMap; import java.util.HashMap; import java.util.Map; public class TimeUtils { - private static final Map FORMATTER_MAP = - new HashMap(); + private static final Map FORMATTER_MAP = new EnumMap<>(Formatter.class); static { FORMATTER_MAP.put( diff --git a/ares-common/src/main/java/com/github/ares/common/utils/Tuple2.java b/ares-common/src/main/java/com/github/ares/common/utils/Tuple2.java index 8d636cc..c9ed120 100644 --- a/ares-common/src/main/java/com/github/ares/common/utils/Tuple2.java +++ b/ares-common/src/main/java/com/github/ares/common/utils/Tuple2.java @@ -2,26 +2,26 @@ import java.io.Serializable; -public class Tuple2 implements Serializable { +public class Tuple2 implements Serializable { private static final long serialVersionUID = 1L; - private final T1 _1; - private final T2 _2; + private final T1 value1; + private final T2 value2; - public Tuple2(T1 _1, T2 _2) { - this._1 = _1; - this._2 = _2; + public Tuple2(T1 value1, T2 value2) { + this.value1 = value1; + this.value2 = value2; } - public static Tuple2 of(T1 _1, T2 _2) { - return new Tuple2<>(_1, _2); + public static Tuple2 of(T1 value1, T2 value2) { + return new Tuple2<>(value1, value2); } public T1 _1() { - return _1; + return value1; } public T2 _2() { - return _2; + return value2; } } diff --git a/ares-parser/src/main/java/com/github/ares/parser/PlParser.java b/ares-parser/src/main/java/com/github/ares/parser/PlParser.java index a0b2fac..ad5acd5 100644 --- a/ares-parser/src/main/java/com/github/ares/parser/PlParser.java +++ b/ares-parser/src/main/java/com/github/ares/parser/PlParser.java @@ -48,7 +48,7 @@ private Sql_scriptContext parse(InputStream in) { parser.removeErrorListeners(); parser.addErrorListener(parserErrorListener); - Sql_scriptContext sql_scriptContext = parser.sql_script(); + Sql_scriptContext sqlScriptContext = parser.sql_script(); if (!lexerErrorListener.getErrors().isEmpty()) { throw new ParseException(String.join("\n", lexerErrorListener.getErrors())); @@ -56,7 +56,7 @@ private Sql_scriptContext parse(InputStream in) { if (!parserErrorListener.getErrors().isEmpty()) { throw new ParseException(String.join("\n", parserErrorListener.getErrors())); } - return sql_scriptContext; + return sqlScriptContext; } catch (Exception e) { throw new ParseException(e.getMessage(), e); } diff --git a/ares-parser/src/main/java/com/github/ares/parser/config/ParserServiceModule.java b/ares-parser/src/main/java/com/github/ares/parser/config/ParserServiceModule.java index 24e416c..19c46f8 100644 --- a/ares-parser/src/main/java/com/github/ares/parser/config/ParserServiceModule.java +++ b/ares-parser/src/main/java/com/github/ares/parser/config/ParserServiceModule.java @@ -5,7 +5,4 @@ import com.github.ares.com.google.inject.name.Names; public class ParserServiceModule extends AbstractModule { - @Override - protected void configure() { - } } diff --git a/ares-parser/src/main/java/com/github/ares/parser/model/TableWith.java b/ares-parser/src/main/java/com/github/ares/parser/model/TableWith.java index 1b9afff..3bfd40d 100644 --- a/ares-parser/src/main/java/com/github/ares/parser/model/TableWith.java +++ b/ares-parser/src/main/java/com/github/ares/parser/model/TableWith.java @@ -18,7 +18,7 @@ public abstract class TableWith extends LogicalOperation implements Serializable protected Map options = new LinkedHashMap<>(); - public TableWith(OperationType plainType) { + protected TableWith(OperationType plainType) { super(plainType); } } diff --git a/ares-parser/src/main/java/com/github/ares/parser/plan/LogicalOperation.java b/ares-parser/src/main/java/com/github/ares/parser/plan/LogicalOperation.java index 0eb4246..0fde157 100644 --- a/ares-parser/src/main/java/com/github/ares/parser/plan/LogicalOperation.java +++ b/ares-parser/src/main/java/com/github/ares/parser/plan/LogicalOperation.java @@ -11,7 +11,7 @@ public abstract class LogicalOperation implements Serializable { private OperationType operationType; - public LogicalOperation(OperationType operationType) { + protected LogicalOperation(OperationType operationType) { this.operationType = operationType; } diff --git a/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/CommonParser.java b/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/CommonParser.java new file mode 100644 index 0000000..ed424e4 --- /dev/null +++ b/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/CommonParser.java @@ -0,0 +1,53 @@ +package com.github.ares.parser.sqlparser.sparksql; + +import com.github.ares.api.common.CriteriaClause; +import com.github.ares.org.antlr.v4.runtime.CharStream; +import com.github.ares.org.antlr.v4.runtime.CharStreams; +import com.github.ares.org.antlr.v4.runtime.CommonTokenStream; +import com.github.ares.parser.antlr4.CaseChangingCharStream; +import com.github.ares.parser.antlr4.CustomErrorListener; +import com.github.ares.parser.antlr4.sparksql.SqlBaseLexer; +import com.github.ares.parser.antlr4.sparksql.SqlBaseParser; + +import java.io.IOException; +import java.io.InputStream; +import java.util.List; + +public class CommonParser { + + public static final String UNSUPPORTED_EXP_MSG = "unsupported syntax: "; + public static final String UNSUPPORTED_EXP_MSG_WITH_PARAM = UNSUPPORTED_EXP_MSG + " %s"; + + public static final String SQL_SELECT_PREFIX = "SELECT "; + + public static SqlBaseParser parseSql(InputStream in) throws IOException { + CharStream s = CharStreams.fromStream(in); + CaseChangingCharStream upper = new CaseChangingCharStream(s, true); + + CustomErrorListener lexerErrorListener = new CustomErrorListener(); + SqlBaseLexer lexer = new SqlBaseLexer(upper); + lexer.removeErrorListeners(); + lexer.addErrorListener(lexerErrorListener); + CommonTokenStream tokens = new CommonTokenStream(lexer); + SqlBaseParser parser = new SqlBaseParser(tokens); + CustomErrorListener parserErrorListener = new CustomErrorListener(); + parser.removeErrorListeners(); + parser.addErrorListener(parserErrorListener); + return parser; + } + + public static void visitCriteriaClause(CriteriaClause criteriaClause, List items) { + if ("AND".equalsIgnoreCase(criteriaClause.getOperator()) || "OR".equalsIgnoreCase(criteriaClause.getOperator())) { + visitCriteriaClause(criteriaClause.getLeftCriteria(), items); + visitCriteriaClause(criteriaClause.getRightCriteria(), items); + } else { + if ("IN".equalsIgnoreCase(criteriaClause.getOperator())) { + if (criteriaClause.getInItems() != null) { + items.addAll(criteriaClause.getInItems()); + } + } else { + items.add(criteriaClause.getRightExpr()); + } + } + } +} diff --git a/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/CriteriaParser.java b/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/CriteriaParser.java index 2a403d5..451a377 100644 --- a/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/CriteriaParser.java +++ b/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/CriteriaParser.java @@ -9,56 +9,44 @@ import java.util.List; import java.util.StringJoiner; +import static com.github.ares.parser.sqlparser.sparksql.CommonParser.UNSUPPORTED_EXP_MSG; import static com.github.ares.parser.utils.PLParserUtil.getFullText; public class CriteriaParser { - public static void parseWhereClause(SqlBaseParser.BooleanExpressionContext booleanExpressionContext, CriteriaClause criteriaClause, String targetTableAlias) { - if (booleanExpressionContext instanceof SqlBaseParser.LogicalBinaryContext) { - SqlBaseParser.LogicalBinaryContext logicalBinaryContext = (SqlBaseParser.LogicalBinaryContext) booleanExpressionContext; - if ("AND".equalsIgnoreCase(logicalBinaryContext.operator.getText())) { - criteriaClause.setOperator("AND"); - } else if ("OR".equalsIgnoreCase(logicalBinaryContext.operator.getText())) { - criteriaClause.setOperator("OR"); - } else { - throw new ParseException("unsupported syntax: " + getFullText(booleanExpressionContext) + " in WHERE clause"); - } - CriteriaClause leftClause = new CriteriaClause(); - SqlBaseParser.BooleanExpressionContext leftExpressionContext = logicalBinaryContext.left; - parseWhereClause(leftExpressionContext, leftClause, targetTableAlias); - criteriaClause.setLeftCriteria(leftClause); + private static final String UNSUPPORTED_EXP_MSG = "unsupported syntax: %s in WHERE clause"; + + private static final String OP_AND = "AND"; + private static final String OP_OR = "OR"; + private static final String OP_EQ = "="; + private static final String OP_NE = "!="; + private static final String OP_NE2 = "<>"; + private static final String OP_GT = ">"; + private static final String OP_LT = "<"; + private static final String OP_GE = ">="; + private static final String OP_LE = "<="; + private static final String OP_IN = "IN"; + private static final String OP_LIKE = "LIKE"; + private static final String OP_NOT_LIKE = "NOT LIKE"; - CriteriaClause rightClause = new CriteriaClause(); - SqlBaseParser.BooleanExpressionContext rightExpressionContext = logicalBinaryContext.right; - parseWhereClause(rightExpressionContext, rightClause, targetTableAlias); - criteriaClause.setRightCriteria(rightClause); + private CriteriaParser() { + } + + /** + * parse where clause + * + * @param booleanExpressionContext boolean expression context + * @param criteriaClause criteria clause for result + * @param targetTableAlias target table alias + */ + public static void parseWhereClause(SqlBaseParser.BooleanExpressionContext booleanExpressionContext, CriteriaClause criteriaClause, String targetTableAlias) { + if (booleanExpressionContext instanceof SqlBaseParser.LogicalBinaryContext) { + parseLogicalBinaryContext((SqlBaseParser.LogicalBinaryContext) booleanExpressionContext, criteriaClause, targetTableAlias); } else if (booleanExpressionContext instanceof SqlBaseParser.PredicatedContext) { SqlBaseParser.PredicatedContext predicatedContext = (SqlBaseParser.PredicatedContext) booleanExpressionContext; - String operator; SqlBaseParser.PrimaryExpressionContext primaryExpressionContext; if (predicatedContext.valueExpression() instanceof SqlBaseParser.ComparisonContext) { - SqlBaseParser.ComparisonContext comparisonContext = (SqlBaseParser.ComparisonContext) predicatedContext.valueExpression(); - operator = comparisonContext.comparisonOperator().getText(); - if ("=".equals(operator)) { - criteriaClause.setOperator("="); - } else if ("!=".equals(operator) || "<>".equals(operator)) { - criteriaClause.setOperator("!="); - } else if (">".equals(operator)) { - criteriaClause.setOperator(">"); - } else if ("<".equals(operator)) { - criteriaClause.setOperator("<"); - } else if (">=".equals(operator)) { - criteriaClause.setOperator(">="); - } else if ("<=".equals(operator)) { - criteriaClause.setOperator("<="); - } else { - throw new ParseException(String.format("unsupported operator %s in WHERE clause: %s: ", operator, getFullText(booleanExpressionContext))); - } - if (!(comparisonContext.left instanceof SqlBaseParser.ValueExpressionDefaultContext)) { - throw new ParseException("unsupported syntax: " + getFullText(booleanExpressionContext) + " in WHERE clause"); - } - primaryExpressionContext = ((SqlBaseParser.ValueExpressionDefaultContext) comparisonContext.left).primaryExpression(); - criteriaClause.setRightExpr(getFullText(comparisonContext.right)); + primaryExpressionContext = parseComparisonContext((SqlBaseParser.ComparisonContext) predicatedContext.valueExpression(), criteriaClause); } else if (predicatedContext.valueExpression() instanceof SqlBaseParser.ValueExpressionDefaultContext) { SqlBaseParser.ValueExpressionDefaultContext valueExpressionDefaultContext = (SqlBaseParser.ValueExpressionDefaultContext) predicatedContext.valueExpression(); primaryExpressionContext = valueExpressionDefaultContext.primaryExpression(); @@ -69,37 +57,10 @@ public static void parseWhereClause(SqlBaseParser.BooleanExpressionContext boole parseWhereClause(expressionContext, criteriaClause, targetTableAlias); return; } else { - if (predicatedContext.predicate().IN() != null) { - criteriaClause.setOperator("IN"); - } else if (predicatedContext.predicate().NOT() != null && predicatedContext.predicate().LIKE() != null) { - criteriaClause.setOperator("NOT LIKE"); - } else if (predicatedContext.predicate().LIKE() != null) { - criteriaClause.setOperator("LIKE"); - } else { - throw new ParseException("unsupported syntax: " + getFullText(booleanExpressionContext) + " in WHERE clause"); - } - - if ("IN".equalsIgnoreCase(criteriaClause.getOperator())) { - if (predicatedContext.predicate().expression().isEmpty()) { - throw new ParseException("unsupported syntax: " + getFullText(booleanExpressionContext) + " in WHERE clause"); - } - criteriaClause.setInItems(new ArrayList<>()); - predicatedContext.predicate().expression().forEach(expr -> - criteriaClause.getInItems().add(getFullText(expr))); - } else { - if (predicatedContext.predicate().valueExpression().isEmpty()) { - throw new ParseException("unsupported syntax: " + getFullText(booleanExpressionContext) + " in WHERE clause"); - } - StringJoiner joiner = new StringJoiner(" "); - for (SqlBaseParser.ValueExpressionContext expressionContext : predicatedContext.predicate().valueExpression()) { - joiner.add(getFullText(expressionContext)); - } - String rightExpr = joiner.toString(); - criteriaClause.setRightExpr(rightExpr); - } + parseOtherConditionContext(predicatedContext, criteriaClause); } } else { - throw new ParseException("unsupported syntax: " + getFullText(booleanExpressionContext) + " in WHERE clause"); + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG, getFullText(booleanExpressionContext))); } if (primaryExpressionContext instanceof SqlBaseParser.ColumnReferenceContext) { @@ -113,19 +74,19 @@ public static void parseWhereClause(SqlBaseParser.BooleanExpressionContext boole throw new ParseException(String.format("cannot found alias '%s' of field: '%s' in WHERE clause: %s", alias, field, getFullText(booleanExpressionContext))); } } else { - throw new ParseException("unsupported syntax: " + getFullText(booleanExpressionContext) + " in WHERE clause"); + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG, getFullText(booleanExpressionContext))); } } } public static void visitOnWhereClause(CriteriaClause clause, StringBuilder conditionSql) { - if ("AND".equalsIgnoreCase(clause.getOperator())) { + if (OP_AND.equalsIgnoreCase(clause.getOperator())) { conditionSql.append(" ( "); visitOnWhereClause(clause.getLeftCriteria(), conditionSql); conditionSql.append(" AND "); visitOnWhereClause(clause.getRightCriteria(), conditionSql); conditionSql.append(" ) "); - } else if ("OR".equalsIgnoreCase(clause.getOperator())) { + } else if (OP_OR.equalsIgnoreCase(clause.getOperator())) { conditionSql.append(" ( "); visitOnWhereClause(clause.getLeftCriteria(), conditionSql); conditionSql.append(" OR "); @@ -145,4 +106,80 @@ public static void visitOnWhereClause(CriteriaClause clause, StringBuilder condi throw new AresException("Unsupported where clause: " + clause); } } + + private static void parseLogicalBinaryContext(SqlBaseParser.LogicalBinaryContext logicalBinaryContext, CriteriaClause criteriaClause, String targetTableAlias) { + if (OP_AND.equalsIgnoreCase(logicalBinaryContext.operator.getText())) { + criteriaClause.setOperator(OP_AND); + } else if (OP_OR.equalsIgnoreCase(logicalBinaryContext.operator.getText())) { + criteriaClause.setOperator(OP_OR); + } else { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG, getFullText(logicalBinaryContext))); + } + + CriteriaClause leftClause = new CriteriaClause(); + SqlBaseParser.BooleanExpressionContext leftExpressionContext = logicalBinaryContext.left; + parseWhereClause(leftExpressionContext, leftClause, targetTableAlias); + criteriaClause.setLeftCriteria(leftClause); + + CriteriaClause rightClause = new CriteriaClause(); + SqlBaseParser.BooleanExpressionContext rightExpressionContext = logicalBinaryContext.right; + parseWhereClause(rightExpressionContext, rightClause, targetTableAlias); + criteriaClause.setRightCriteria(rightClause); + } + + private static SqlBaseParser.PrimaryExpressionContext parseComparisonContext( + SqlBaseParser.ComparisonContext comparisonContext, CriteriaClause criteriaClause) { + String operator = comparisonContext.comparisonOperator().getText(); + if (OP_EQ.equals(operator)) { + criteriaClause.setOperator(OP_EQ); + } else if (OP_NE.equals(operator) || OP_NE2.equals(operator)) { + criteriaClause.setOperator(OP_NE); + } else if (OP_GT.equals(operator)) { + criteriaClause.setOperator(OP_GT); + } else if (OP_LT.equals(operator)) { + criteriaClause.setOperator(OP_LT); + } else if (OP_GE.equals(operator)) { + criteriaClause.setOperator(OP_GE); + } else if (OP_LE.equals(operator)) { + criteriaClause.setOperator(OP_LE); + } else { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG, getFullText(comparisonContext))); + } + if (!(comparisonContext.left instanceof SqlBaseParser.ValueExpressionDefaultContext)) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG, getFullText(comparisonContext))); + } + criteriaClause.setRightExpr(getFullText(comparisonContext.right)); + return ((SqlBaseParser.ValueExpressionDefaultContext) comparisonContext.left).primaryExpression(); + } + + private static void parseOtherConditionContext(SqlBaseParser.PredicatedContext predicatedContext, CriteriaClause criteriaClause) { + if (predicatedContext.predicate().IN() != null) { + criteriaClause.setOperator(OP_IN); + } else if (predicatedContext.predicate().NOT() != null && predicatedContext.predicate().LIKE() != null) { + criteriaClause.setOperator(OP_NOT_LIKE); + } else if (predicatedContext.predicate().LIKE() != null) { + criteriaClause.setOperator(OP_LIKE); + } else { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG, getFullText(predicatedContext))); + } + + if (OP_IN.equalsIgnoreCase(criteriaClause.getOperator())) { + if (predicatedContext.predicate().expression().isEmpty()) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG, getFullText(predicatedContext))); + } + criteriaClause.setInItems(new ArrayList<>()); + predicatedContext.predicate().expression().forEach(expr -> + criteriaClause.getInItems().add(getFullText(expr))); + } else { + if (predicatedContext.predicate().valueExpression().isEmpty()) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG, getFullText(predicatedContext))); + } + StringJoiner joiner = new StringJoiner(" "); + for (SqlBaseParser.ValueExpressionContext expressionContext : predicatedContext.predicate().valueExpression()) { + joiner.add(getFullText(expressionContext)); + } + String rightExpr = joiner.toString(); + criteriaClause.setRightExpr(rightExpr); + } + } } diff --git a/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/DeleteSqlParser.java b/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/DeleteSqlParser.java new file mode 100644 index 0000000..0cb2a60 --- /dev/null +++ b/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/DeleteSqlParser.java @@ -0,0 +1,105 @@ +package com.github.ares.parser.sqlparser.sparksql; + +import com.github.ares.api.common.CriteriaClause; +import com.github.ares.common.exceptions.ParseException; +import com.github.ares.parser.antlr4.sparksql.SqlBaseParser; +import com.github.ares.parser.sqlparser.model.SQLDelete; +import com.github.ares.parser.sqlparser.model.SQLHint; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +import static com.github.ares.parser.sqlparser.sparksql.CommonParser.SQL_SELECT_PREFIX; +import static com.github.ares.parser.sqlparser.sparksql.CommonParser.UNSUPPORTED_EXP_MSG_WITH_PARAM; + +public class DeleteSqlParser { + private DeleteSqlParser() { + } + + /** + * Parse delete SQL and return SQLDelete object. + * + * @param sql delete SQL + * @return sqlDelete object + */ + public static SQLDelete parseDelete(String sql) { + SQLDelete sqlDelete = new SQLDelete(); + try (InputStream in = new ByteArrayInputStream(sql.getBytes(StandardCharsets.UTF_8))) { + SqlBaseParser parser = CommonParser.parseSql(in); + SqlBaseParser.DmlStatementNoWithContext dmlStatementNoWithContext = parser.dmlStatementNoWith(); + + if (!(dmlStatementNoWithContext instanceof SqlBaseParser.DeleteFromTableContext)) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + + SqlBaseParser.DeleteFromTableContext deleteFromTableContext = (SqlBaseParser.DeleteFromTableContext) dmlStatementNoWithContext; + + SqlBaseParser.MultipartIdentifierContext mappingTable = deleteFromTableContext.multipartIdentifier(0); + sqlDelete.setTable(mappingTable.getText()); + + if (deleteFromTableContext.source != null || deleteFromTableContext.sourceQuery != null) { + parseSourceQuery(deleteFromTableContext, sqlDelete, sql); + } else if (!deleteFromTableContext.tableAlias().isEmpty()) { + sqlDelete.setAlias(deleteFromTableContext.tableAlias().get(0).getText()); + } + + if (deleteFromTableContext.whereClause() == null) { + throw new ParseException("delete SQL must have WHERE clause: " + sql); + } + CriteriaClause criteriaClause = new CriteriaClause(); + SqlBaseParser.BooleanExpressionContext expressionContext = deleteFromTableContext.whereClause().booleanExpression(); + CriteriaParser.parseWhereClause(expressionContext, criteriaClause, sqlDelete.getAlias()); + sqlDelete.setWhereClause(criteriaClause); + + List selectItems = new ArrayList<>(); + CommonParser.visitCriteriaClause(criteriaClause, selectItems); + + StringBuilder selectSql = new StringBuilder(); + selectSql.append(SQL_SELECT_PREFIX); + selectSql.append(String.join(", ", selectItems)); + if (StringUtils.isNotBlank(sqlDelete.getJoinTable())) { + selectSql.append(" FROM ").append(sqlDelete.getJoinTable()).append(" ").append(sqlDelete.getJoinAlias()); + } else if (StringUtils.isNotBlank(sqlDelete.getJoinSql())) { + selectSql.append(" FROM (").append(sqlDelete.getJoinSql()).append(") ").append(sqlDelete.getJoinAlias()); + } + sqlDelete.setSourceSql(selectSql.toString()); + } catch (ParseException e) { + throw e; + } catch (Exception e) { + throw new ParseException(e.getMessage(), e); + } + return sqlDelete; + } + + private static void parseSourceQuery(SqlBaseParser.DeleteFromTableContext deleteFromTableContext, SQLDelete sqlDelete, String sql) { + if (deleteFromTableContext.tableAlias().size() != 2) { + throw new ParseException(String.format("Alias not defined for source table or target table: %s", sql)); + } + if (deleteFromTableContext.source != null) { + sqlDelete.setJoinTable(deleteFromTableContext.source.getText()); + String sourceTableAlias = deleteFromTableContext.tableAlias().get(1).getText(); + sqlDelete.setJoinAlias(sourceTableAlias); + sqlDelete.setAlias(deleteFromTableContext.tableAlias().get(0).getText()); + } else if (deleteFromTableContext.sourceQuery != null) { + if (!(deleteFromTableContext.sourceQuery.queryTerm() instanceof SqlBaseParser.QueryTermDefaultContext)) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + SqlBaseParser.QueryPrimaryContext queryPrimaryContext = ((SqlBaseParser.QueryTermDefaultContext) deleteFromTableContext.sourceQuery.queryTerm()).queryPrimary(); + if (!(queryPrimaryContext instanceof SqlBaseParser.QueryPrimaryDefaultContext)) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + String sourceTableAlias = deleteFromTableContext.tableAlias().get(1).getText(); + sqlDelete.setJoinAlias(sourceTableAlias); + sqlDelete.setAlias(deleteFromTableContext.tableAlias().get(0).getText()); + Pair, String> hintsWithSql = HintParser.parseSelectHints(sql, + (SqlBaseParser.QueryPrimaryDefaultContext) queryPrimaryContext); + sqlDelete.setHints(hintsWithSql.getLeft()); + sqlDelete.setJoinSql(hintsWithSql.getRight()); + } + } +} diff --git a/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/HintParser.java b/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/HintParser.java new file mode 100644 index 0000000..3e51593 --- /dev/null +++ b/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/HintParser.java @@ -0,0 +1,80 @@ +package com.github.ares.parser.sqlparser.sparksql; + +import com.github.ares.api.common.EngineType; +import com.github.ares.api.common.ExecutionEngineType; +import com.github.ares.common.exceptions.ParseException; +import com.github.ares.org.antlr.v4.runtime.tree.ParseTree; +import com.github.ares.parser.antlr4.sparksql.SqlBaseParser; +import com.github.ares.parser.sqlparser.model.SQLHint; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; + +import java.util.ArrayList; +import java.util.List; + +import static com.github.ares.parser.sqlparser.sparksql.CommonParser.UNSUPPORTED_EXP_MSG_WITH_PARAM; +import static com.github.ares.parser.utils.PLParserUtil.getFullText; + +public class HintParser { + private static final String INNER_HINT_MAPJOIN = "mapjoin"; + private static final String INNER_HINT_BROADCAST = "broadcast"; + + private HintParser() { + } + + public static Pair, String> parseSelectHints(String sql, SqlBaseParser.QueryPrimaryDefaultContext queryPrimaryDefaultContext) { + if (!(queryPrimaryDefaultContext.querySpecification() instanceof SqlBaseParser.RegularQuerySpecificationContext)) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + SqlBaseParser.RegularQuerySpecificationContext regularQuerySpecificationContext = (SqlBaseParser.RegularQuerySpecificationContext) queryPrimaryDefaultContext.querySpecification(); + if (regularQuerySpecificationContext.selectClause() == null) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + + List sqlHints = new ArrayList<>(); + SqlBaseParser.SelectClauseContext selectClauseContext = regularQuerySpecificationContext.selectClause(); + if (!selectClauseContext.hints.isEmpty()) { + for (int i = 0; i < selectClauseContext.hints.size(); i++) { + SQLHint sqlHint = new SQLHint(); + SqlBaseParser.HintContext hintContext = selectClauseContext.hint(i); + String hintName = hintContext.hintStatement.hintName.getText(); + sqlHint.setHintName(hintName); + for (SqlBaseParser.PrimaryExpressionContext primaryExpressionContext : hintContext.hintStatement.parameters) { + String parameter = getFullText(primaryExpressionContext); + if (StringUtils.isNotBlank(parameter)) { + sqlHint.getArguments().add(parameter); + } + } + sqlHints.add(sqlHint); + } + } + + // filter out hints and into clause from select clause + StringBuilder selectSql = new StringBuilder(); + for (ParseTree child : regularQuerySpecificationContext.children) { + if (child instanceof SqlBaseParser.SelectClauseContext) { + filterHints((SqlBaseParser.SelectClauseContext) child, selectSql); + } else { + selectSql.append(getFullText(child)).append(" "); + } + } + + return Pair.of(sqlHints, selectSql.toString()); + } + + private static void filterHints(SqlBaseParser.SelectClauseContext selectClauseContext, StringBuilder selectSql) { + for (ParseTree grandChild : selectClauseContext.children) { + if (grandChild instanceof SqlBaseParser.IntoClauseContext || grandChild instanceof SqlBaseParser.HintContext) { + if ((grandChild instanceof SqlBaseParser.HintContext) && ExecutionEngineType.engineType == EngineType.SPARK) { + String hint = getFullText(grandChild); + String hintLower = hint.toLowerCase(); + if (hintLower.contains(INNER_HINT_MAPJOIN) || hintLower.contains(INNER_HINT_BROADCAST)) { + selectSql.append(hint).append(" "); + } + } + continue; + } + selectSql.append(getFullText(grandChild)).append(" "); + } + } +} diff --git a/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/InsertSqlParser.java b/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/InsertSqlParser.java new file mode 100644 index 0000000..a8c0e48 --- /dev/null +++ b/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/InsertSqlParser.java @@ -0,0 +1,134 @@ +package com.github.ares.parser.sqlparser.sparksql; + +import com.github.ares.common.exceptions.ParseException; +import com.github.ares.org.antlr.v4.runtime.tree.ParseTree; +import com.github.ares.org.antlr.v4.runtime.tree.TerminalNodeImpl; +import com.github.ares.parser.antlr4.sparksql.SqlBaseParser; +import com.github.ares.parser.sqlparser.model.SQLHint; +import com.github.ares.parser.sqlparser.model.SQLInsert; +import org.apache.commons.lang3.tuple.Pair; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.StringJoiner; + +import static com.github.ares.parser.sqlparser.sparksql.CommonParser.SQL_SELECT_PREFIX; +import static com.github.ares.parser.sqlparser.sparksql.CommonParser.UNSUPPORTED_EXP_MSG_WITH_PARAM; +import static com.github.ares.parser.utils.PLParserUtil.getFullText; + +public class InsertSqlParser { + private static final String SQL_INTO = "INTO"; + + private InsertSqlParser() { + } + + /** + * Parse insert sql and return SQLInsert object. + * + * @param sql insert sql + * @return SQLInsert object + */ + public static SQLInsert parseInsert(String sql) { + SQLInsert sqlInsert = new SQLInsert(); + + try (InputStream in = new ByteArrayInputStream(sql.getBytes(StandardCharsets.UTF_8))) { + + SqlBaseParser parser = CommonParser.parseSql(in); + SqlBaseParser.DmlStatementNoWithContext dmlStatementNoWithContext = parser.dmlStatementNoWith(); + + if (!(dmlStatementNoWithContext instanceof SqlBaseParser.SingleInsertQueryContext)) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + SqlBaseParser.SingleInsertQueryContext singleInsertQueryContext = (SqlBaseParser.SingleInsertQueryContext) dmlStatementNoWithContext; + SqlBaseParser.InsertIntoTableContext insertIntoContext = (SqlBaseParser.InsertIntoTableContext) singleInsertQueryContext.insertInto(); + if (insertIntoContext.getChildCount() < 3 || + !(insertIntoContext.getChild(1) instanceof TerminalNodeImpl) || + !SQL_INTO.equalsIgnoreCase(insertIntoContext.getChild(1).getText()) || + !(insertIntoContext.getChild(2) instanceof SqlBaseParser.MultipartIdentifierContext)) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + SqlBaseParser.MultipartIdentifierContext multipartIdentifierContext = insertIntoContext.multipartIdentifier(); + sqlInsert.setTable(multipartIdentifierContext.getText()); + if (insertIntoContext.identifierList() != null) { + SqlBaseParser.IdentifierListContext identifierListContext = insertIntoContext.identifierList(); + if (identifierListContext.identifierSeq() == null) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + + SqlBaseParser.IdentifierSeqContext identifierSeqContext = identifierListContext.identifierSeq(); + for (SqlBaseParser.ErrorCapturingIdentifierContext errorCapturingIdentifierContext : identifierSeqContext.errorCapturingIdentifier()) { + sqlInsert.getColumns().add(errorCapturingIdentifierContext.getText()); + } + } + + String selectSql; + SqlBaseParser.QueryContext queryContext = singleInsertQueryContext.query(); + SqlBaseParser.QueryPrimaryContext queryPrimaryContext = ((SqlBaseParser.QueryTermDefaultContext) queryContext.queryTerm()).queryPrimary(); + if (queryPrimaryContext instanceof SqlBaseParser.InlineTableDefault1Context) { + selectSql = parseInlineTableContext(queryContext, sql, sqlInsert); + } else if (queryPrimaryContext instanceof SqlBaseParser.QueryPrimaryDefaultContext) { + Pair, String> hintsWithSql = HintParser.parseSelectHints(sql, + (SqlBaseParser.QueryPrimaryDefaultContext) queryPrimaryContext); + sqlInsert.setHints(hintsWithSql.getLeft()); + selectSql = hintsWithSql.getRight(); + } else { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + sqlInsert.setSourceSql(selectSql); + } catch (ParseException e) { + throw e; + } catch (Exception e) { + throw new ParseException(e.getMessage(), e); + } + return sqlInsert; + } + + /** + * Parse inline table context. + * + * @param queryContext query context object + * @param sql sql string + * @param sqlInsert SQLInsert object + * @return select sql string + */ + private static String parseInlineTableContext(SqlBaseParser.QueryContext queryContext, String sql, SQLInsert sqlInsert) { + List expressionContexts = ((SqlBaseParser.InlineTableDefault1Context) + ((SqlBaseParser.QueryTermDefaultContext) queryContext.queryTerm()).queryPrimary()).inlineTable().expression(); + List valuesExpressions = new ArrayList<>(); + List> valuesArray = new ArrayList<>(); + for (SqlBaseParser.ExpressionContext expressionContext : expressionContexts) { + if (expressionContext.getChildCount() < 1 || !(expressionContext.getChild(0) instanceof SqlBaseParser.PredicatedContext)) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + SqlBaseParser.PredicatedContext predicatedContext = (SqlBaseParser.PredicatedContext) expressionContext.getChild(0); + if (predicatedContext.valueExpression().getChildCount() < 1 || !(predicatedContext.valueExpression().getChild(0) instanceof SqlBaseParser.RowConstructorContext)) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + List values = new ArrayList<>(); + SqlBaseParser.RowConstructorContext rowConstructorContext = (SqlBaseParser.RowConstructorContext) predicatedContext.valueExpression().getChild(0); + for (int i = 0; i < rowConstructorContext.getChildCount(); i++) { + ParseTree item = rowConstructorContext.getChild(i); + if (item instanceof TerminalNodeImpl) { + continue; + } + values.add(getFullText(item)); + } + valuesArray.add(values); + sqlInsert.setValuesArray(valuesArray); + + String selectExpression = getFullText(expressionContext); + if (!selectExpression.startsWith("(") && !selectExpression.endsWith(")")) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + selectExpression = selectExpression.substring(0, selectExpression.length() - 1).substring(1); + valuesExpressions.add(selectExpression); + } + StringJoiner stringJoiner = new StringJoiner(" UNION ALL "); + valuesExpressions.forEach(valuesExpression -> stringJoiner.add(SQL_SELECT_PREFIX + valuesExpression)); + + return stringJoiner.toString(); + } +} diff --git a/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/MergeSqlParser.java b/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/MergeSqlParser.java new file mode 100644 index 0000000..b9159d1 --- /dev/null +++ b/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/MergeSqlParser.java @@ -0,0 +1,201 @@ +package com.github.ares.parser.sqlparser.sparksql; + +import com.github.ares.api.common.CriteriaClause; +import com.github.ares.common.exceptions.ParseException; +import com.github.ares.parser.antlr4.sparksql.SqlBaseParser; +import com.github.ares.parser.sqlparser.model.SQLHint; +import com.github.ares.parser.sqlparser.model.SQLInsert; +import com.github.ares.parser.sqlparser.model.SQLMerge; +import com.github.ares.parser.sqlparser.model.SQLUpdate; +import org.apache.commons.lang3.tuple.Pair; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +import static com.github.ares.parser.sqlparser.sparksql.CommonParser.SQL_SELECT_PREFIX; +import static com.github.ares.parser.sqlparser.sparksql.CommonParser.UNSUPPORTED_EXP_MSG_WITH_PARAM; +import static com.github.ares.parser.sqlparser.sparksql.CriteriaParser.parseWhereClause; +import static com.github.ares.parser.sqlparser.sparksql.CriteriaParser.visitOnWhereClause; +import static com.github.ares.parser.utils.PLParserUtil.getFullText; + +public class MergeSqlParser { + private MergeSqlParser() { + } + + /** + * Parse merge into SQL and return SQLMerge object. + * + * @param sql merge into SQL + * @return SQLMerge object + */ + public static SQLMerge parseMerge(String sql) { + SQLMerge sqlMerge = new SQLMerge(); + try (InputStream in = new ByteArrayInputStream(sql.getBytes(StandardCharsets.UTF_8))) { + SqlBaseParser parser = CommonParser.parseSql(in); + SqlBaseParser.DmlStatementNoWithContext dmlStatementNoWithContext = parser.dmlStatementNoWith(); + + if (!(dmlStatementNoWithContext instanceof SqlBaseParser.MergeIntoTableContext)) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + getFullText(dmlStatementNoWithContext); + + SqlBaseParser.MergeIntoTableContext mergeIntoTableContext = (SqlBaseParser.MergeIntoTableContext) dmlStatementNoWithContext; + SqlBaseParser.MultipartIdentifierContext mappingTable = mergeIntoTableContext.multipartIdentifier(0); + sqlMerge.setTable(mappingTable.getText()); + + if (mergeIntoTableContext.source != null || mergeIntoTableContext.sourceQuery != null) { + parseSourceQuery(mergeIntoTableContext, sqlMerge, sql); + } else if (!mergeIntoTableContext.tableAlias().isEmpty()) { + sqlMerge.setAlias(mergeIntoTableContext.tableAlias().get(0).getText()); + } + + CriteriaClause onClause = new CriteriaClause(); + SqlBaseParser.BooleanExpressionContext onExpressionContext = mergeIntoTableContext.mergeCondition; + CriteriaParser.parseWhereClause(onExpressionContext, onClause, sqlMerge.getAlias()); + List onSelectItems = new ArrayList<>(); + CommonParser.visitCriteriaClause(onClause, onSelectItems); + sqlMerge.setOnSelectItems(onSelectItems); + + String usingSQL; + if (sqlMerge.getUsingTable() != null) { + usingSQL = "SELECT * FROM " + sqlMerge.getUsingTable(); + } else { + usingSQL = sqlMerge.getUsingSql(); + } + + StringBuilder conditionSql = new StringBuilder(); + visitOnWhereClause(onClause, conditionSql); + + if (mergeIntoTableContext.matchedClause().size() > 1 || mergeIntoTableContext.notMatchedClause().size() > 1) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + if (!mergeIntoTableContext.notMatchedClause().isEmpty()) { + parseNotMatchedClause(mergeIntoTableContext, sqlMerge, usingSQL, conditionSql.toString()); + } + + if (!mergeIntoTableContext.matchedClause().isEmpty()) { + parseMatchedClause(mergeIntoTableContext, sqlMerge, usingSQL, onClause, conditionSql.toString()); + } + } catch (ParseException e) { + throw e; + } catch (Exception e) { + throw new ParseException(e.getMessage(), e); + } + return sqlMerge; + } + + private static void parseSourceQuery(SqlBaseParser.MergeIntoTableContext mergeIntoTableContext, SQLMerge sqlMerge, String sql) { + if (mergeIntoTableContext.tableAlias().size() != 2) { + throw new ParseException(String.format("Alias not defined for source table or target table: %s", sql)); + } + if (mergeIntoTableContext.source != null) { + sqlMerge.setUsingTable(mergeIntoTableContext.source.getText()); + String sourceTableAlias = mergeIntoTableContext.tableAlias().get(1).getText(); + sqlMerge.setUsingAlias(sourceTableAlias); + sqlMerge.setAlias(mergeIntoTableContext.tableAlias().get(0).getText()); + } else if (mergeIntoTableContext.sourceQuery != null) { + if (!(mergeIntoTableContext.sourceQuery.queryTerm() instanceof SqlBaseParser.QueryTermDefaultContext)) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + SqlBaseParser.QueryPrimaryContext queryPrimaryContext = ((SqlBaseParser.QueryTermDefaultContext) mergeIntoTableContext.sourceQuery.queryTerm()).queryPrimary(); + if (!(queryPrimaryContext instanceof SqlBaseParser.QueryPrimaryDefaultContext)) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + String sourceTableAlias = mergeIntoTableContext.tableAlias().get(1).getText(); + sqlMerge.setUsingAlias(sourceTableAlias); + sqlMerge.setAlias(mergeIntoTableContext.tableAlias().get(0).getText()); + Pair, String> hintsWithSql = HintParser.parseSelectHints(sql, + (SqlBaseParser.QueryPrimaryDefaultContext) queryPrimaryContext); + sqlMerge.setHints(hintsWithSql.getLeft()); + sqlMerge.setUsingSql(hintsWithSql.getRight()); + } + } + + private static void parseNotMatchedClause(SqlBaseParser.MergeIntoTableContext mergeIntoTableContext, SQLMerge sqlMerge, + String usingSQL, String conditionSql) { + SQLInsert sqlInsert = new SQLInsert(); + sqlInsert.setTable(sqlMerge.getTable()); + + SqlBaseParser.NotMatchedActionContext notMatchedAction = mergeIntoTableContext.notMatchedClause().get(0).notMatchedAction(); + + for (SqlBaseParser.MultipartIdentifierContext multipartIdentifierContext : notMatchedAction.multipartIdentifierList().multipartIdentifier()) { + sqlInsert.getColumns().add(multipartIdentifierContext.errorCapturingIdentifier.identifier().getText()); + } + + List> valuesArray = new ArrayList<>(); + List values = new ArrayList<>(); + for (SqlBaseParser.ExpressionContext expressionContext : notMatchedAction.expression()) { + values.add(getFullText(expressionContext)); + } + valuesArray.add(values); + sqlInsert.setValuesArray(valuesArray); + sqlMerge.setSqlInsert(sqlInsert); + + StringBuilder sourceSql = new StringBuilder(); + sourceSql.append(SQL_SELECT_PREFIX); + sourceSql.append(String.join(", ", sqlInsert.getValuesArray().get(0))); + sourceSql.append(" FROM ("); + sourceSql.append(usingSQL).append(") ").append(sqlMerge.getUsingAlias()).append(" "); + sourceSql.append("WHERE NOT EXISTS (SELECT 1 FROM "); + sourceSql.append(sqlMerge.getTable()); + sourceSql.append(" WHERE ").append(conditionSql).append(" )"); + sqlInsert.setSourceSql(sourceSql.toString()); + } + + private static void parseMatchedClause(SqlBaseParser.MergeIntoTableContext mergeIntoTableContext, SQLMerge sqlMerge, + String usingSQL, CriteriaClause onClause, String conditionSql) { + + SQLUpdate sqlUpdate = new SQLUpdate(); + sqlUpdate.setTable(sqlMerge.getTable()); + + sqlUpdate.setAlias(sqlMerge.getAlias()); + sqlUpdate.setJoinAlias(sqlMerge.getUsingAlias()); + if (sqlMerge.getUsingTable() != null) { + sqlUpdate.setJoinTable(sqlMerge.getUsingTable()); + } else if (sqlMerge.getUsingSql() != null) { + sqlUpdate.setJoinSql(sqlMerge.getUsingSql()); + } + + SqlBaseParser.MatchedActionContext matchedActionContext = mergeIntoTableContext.matchedClause().get(0).matchedAction(); + + for (SqlBaseParser.AssignmentContext assignmentContext : matchedActionContext.assignmentList().assignment()) { + sqlUpdate.getUpdateColumns().add(assignmentContext.key.errorCapturingIdentifier.identifier().getText()); + sqlUpdate.getUpdateValues().add(getFullText(assignmentContext.value)); + } + + if (matchedActionContext.booleanExpression() != null) { + CriteriaClause whereClause = new CriteriaClause(); + parseWhereClause(matchedActionContext.booleanExpression(), whereClause, sqlUpdate.getAlias()); + + sqlUpdate.setWhereClause(whereClause); + List selectItems = new ArrayList<>(); + CommonParser.visitCriteriaClause(whereClause, selectItems); + sqlUpdate.setSelectWhereItems(selectItems); + + CriteriaClause allWhereClause = new CriteriaClause(); + allWhereClause.setLeftCriteria(onClause); + allWhereClause.setOperator("AND"); + allWhereClause.setRightCriteria(whereClause); + sqlMerge.setAllWhereClause(allWhereClause); + } + sqlMerge.setSqlUpdate(sqlUpdate); + + StringBuilder sourceSql = new StringBuilder(); + sourceSql.append(SQL_SELECT_PREFIX); + sourceSql.append(String.join(", ", sqlUpdate.getUpdateValues())); + sourceSql.append(", ").append(String.join(", ", sqlMerge.getOnSelectItems())); + if (sqlUpdate.getSelectWhereItems() != null) { + sourceSql.append(", ").append(String.join(", ", sqlUpdate.getSelectWhereItems())); + } + sourceSql.append(" FROM ("); + sourceSql.append(usingSQL).append(") ").append(sqlMerge.getUsingAlias()).append(" "); + sourceSql.append("WHERE EXISTS (SELECT 1 FROM "); + sourceSql.append(sqlMerge.getTable()); + sourceSql.append(" WHERE ").append(conditionSql).append(" )"); + + sqlUpdate.setSourceSql(sourceSql.toString()); + } +} diff --git a/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/SelectSqlParser.java b/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/SelectSqlParser.java new file mode 100644 index 0000000..2861c27 --- /dev/null +++ b/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/SelectSqlParser.java @@ -0,0 +1,68 @@ +package com.github.ares.parser.sqlparser.sparksql; + +import com.github.ares.common.exceptions.ParseException; +import com.github.ares.parser.antlr4.sparksql.SqlBaseParser; +import com.github.ares.parser.sqlparser.model.SQLHint; +import com.github.ares.parser.sqlparser.model.SQLSelect; +import org.apache.commons.lang3.tuple.Pair; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +import static com.github.ares.parser.sqlparser.sparksql.CommonParser.UNSUPPORTED_EXP_MSG_WITH_PARAM; +import static com.github.ares.parser.utils.PLParserUtil.clearParam; + +public class SelectSqlParser { + private SelectSqlParser() { + } + + public static SQLSelect parseSelect(String sql) { + SQLSelect sqlSelect = new SQLSelect(); + try (InputStream in = new ByteArrayInputStream(sql.getBytes(StandardCharsets.UTF_8))) { + SqlBaseParser parser = CommonParser.parseSql(in); + SqlBaseParser.QueryContext queryContext = parser.query(); + + if (!(queryContext.queryTerm() instanceof SqlBaseParser.QueryTermDefaultContext)) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + SqlBaseParser.QueryTermDefaultContext queryTermDefaultContext = (SqlBaseParser.QueryTermDefaultContext) queryContext.queryTerm(); + if (!(queryTermDefaultContext.queryPrimary() instanceof SqlBaseParser.QueryPrimaryDefaultContext)) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + SqlBaseParser.QueryPrimaryDefaultContext queryPrimaryDefaultContext = (SqlBaseParser.QueryPrimaryDefaultContext) queryTermDefaultContext.queryPrimary(); + if (!(queryPrimaryDefaultContext.querySpecification() instanceof SqlBaseParser.RegularQuerySpecificationContext)) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + SqlBaseParser.RegularQuerySpecificationContext regularQuerySpecificationContext = (SqlBaseParser.RegularQuerySpecificationContext) queryPrimaryDefaultContext.querySpecification(); + SqlBaseParser.IntoClauseContext intoClauseContext = regularQuerySpecificationContext.selectClause().intoClause(); + if (intoClauseContext == null) { + Pair, String> hintsWithSql = HintParser.parseSelectHints(sql, + queryPrimaryDefaultContext); + sqlSelect.setHints(hintsWithSql.getLeft()); + sqlSelect.setSourceSql(hintsWithSql.getRight()); + } else { + List intoParams = new ArrayList<>(); + intoClauseContext.expression().forEach(expressionContext -> intoParams.add(expressionContext.getText())); + + sqlSelect.setIntoParams(new ArrayList<>()); + for (String intoParam : intoParams) { + intoParam = clearParam(intoParam); + sqlSelect.getIntoParams().add(intoParam); + } + + Pair, String> hintsWithSql = HintParser.parseSelectHints(sql, + queryPrimaryDefaultContext); + sqlSelect.setHints(hintsWithSql.getLeft()); + sqlSelect.setSourceSql(hintsWithSql.getRight()); + } + } catch (ParseException e) { + throw e; + } catch (Exception e) { + throw new ParseException(e.getMessage(), e); + } + return sqlSelect; + } +} diff --git a/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/SparkSqlParser.java b/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/SparkSqlParser.java index 3f42659..ed4e5f4 100644 --- a/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/SparkSqlParser.java +++ b/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/SparkSqlParser.java @@ -1,499 +1,57 @@ package com.github.ares.parser.sqlparser.sparksql; -import com.github.ares.api.common.CriteriaClause; -import com.github.ares.api.common.EngineType; -import com.github.ares.api.common.ExecutionEngineType; import com.github.ares.common.exceptions.ParseException; -import com.github.ares.org.antlr.v4.runtime.CharStream; -import com.github.ares.org.antlr.v4.runtime.CharStreams; -import com.github.ares.org.antlr.v4.runtime.CommonTokenStream; -import com.github.ares.org.antlr.v4.runtime.tree.ParseTree; -import com.github.ares.org.antlr.v4.runtime.tree.TerminalNodeImpl; -import com.github.ares.parser.antlr4.CaseChangingCharStream; -import com.github.ares.parser.antlr4.CustomErrorListener; -import com.github.ares.parser.antlr4.sparksql.SqlBaseLexer; import com.github.ares.parser.antlr4.sparksql.SqlBaseParser; import com.github.ares.parser.sqlparser.SQLParser; import com.github.ares.parser.sqlparser.model.SQLDelete; -import com.github.ares.parser.sqlparser.model.SQLHint; import com.github.ares.parser.sqlparser.model.SQLInsert; import com.github.ares.parser.sqlparser.model.SQLMerge; import com.github.ares.parser.sqlparser.model.SQLSelect; import com.github.ares.parser.sqlparser.model.SQLTruncate; import com.github.ares.parser.sqlparser.model.SQLUpdate; -import org.apache.commons.lang3.StringUtils; -import org.apache.commons.lang3.tuple.Pair; import java.io.ByteArrayInputStream; -import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.List; -import java.util.StringJoiner; -import static com.github.ares.parser.sqlparser.sparksql.CriteriaParser.parseWhereClause; -import static com.github.ares.parser.sqlparser.sparksql.CriteriaParser.visitOnWhereClause; -import static com.github.ares.parser.utils.PLParserUtil.clearParam; -import static com.github.ares.parser.utils.PLParserUtil.getFullText; +import static com.github.ares.parser.sqlparser.sparksql.CommonParser.UNSUPPORTED_EXP_MSG_WITH_PARAM; public class SparkSqlParser implements SQLParser { - @Override - public SQLSelect parseSelect(String sql) { - SQLSelect sqlSelect = new SQLSelect(); - try (InputStream in = new ByteArrayInputStream(sql.getBytes(StandardCharsets.UTF_8))) { - SqlBaseParser parser = parseSql(in); - SqlBaseParser.QueryContext queryContext = parser.query(); - - if (!(queryContext.queryTerm() instanceof SqlBaseParser.QueryTermDefaultContext)) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - SqlBaseParser.QueryTermDefaultContext queryTermDefaultContext = (SqlBaseParser.QueryTermDefaultContext) queryContext.queryTerm(); - if (!(queryTermDefaultContext.queryPrimary() instanceof SqlBaseParser.QueryPrimaryDefaultContext)) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - SqlBaseParser.QueryPrimaryDefaultContext queryPrimaryDefaultContext = (SqlBaseParser.QueryPrimaryDefaultContext) queryTermDefaultContext.queryPrimary(); - if (!(queryPrimaryDefaultContext.querySpecification() instanceof SqlBaseParser.RegularQuerySpecificationContext)) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - SqlBaseParser.RegularQuerySpecificationContext regularQuerySpecificationContext = (SqlBaseParser.RegularQuerySpecificationContext) queryPrimaryDefaultContext.querySpecification(); - SqlBaseParser.IntoClauseContext intoClauseContext = regularQuerySpecificationContext.selectClause().intoClause(); - if (intoClauseContext == null) { - Pair, String> hintsWithSql = parseSelectHints(sql, queryPrimaryDefaultContext); - sqlSelect.setHints(hintsWithSql.getLeft()); - sqlSelect.setSourceSql(hintsWithSql.getRight()); - } else { - List intoParams = new ArrayList<>(); - intoClauseContext.expression().forEach(expressionContext -> intoParams.add(expressionContext.getText())); - sqlSelect.setIntoParams(new ArrayList<>()); - for (String intoParam : intoParams) { - intoParam = clearParam(intoParam); - sqlSelect.getIntoParams().add(intoParam); - } - Pair, String> hintsWithSql = parseSelectHints(sql, queryPrimaryDefaultContext); - sqlSelect.setHints(hintsWithSql.getLeft()); - sqlSelect.setSourceSql(hintsWithSql.getRight()); - } - } catch (ParseException e) { - throw e; - } catch (Exception e) { - throw new ParseException(e.getMessage(), e); - } - return sqlSelect; + @Override + public SQLSelect parseSelect(String sql) { + return SelectSqlParser.parseSelect(sql); } @Override public SQLInsert parseInsert(String sql) { - SQLInsert sqlInsert = new SQLInsert(); - - try (InputStream in = new ByteArrayInputStream(sql.getBytes(StandardCharsets.UTF_8))) { - - SqlBaseParser parser = parseSql(in); - SqlBaseParser.DmlStatementNoWithContext dmlStatementNoWithContext = parser.dmlStatementNoWith(); - - if (!(dmlStatementNoWithContext instanceof SqlBaseParser.SingleInsertQueryContext)) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - SqlBaseParser.SingleInsertQueryContext singleInsertQueryContext = (SqlBaseParser.SingleInsertQueryContext) dmlStatementNoWithContext; - SqlBaseParser.InsertIntoTableContext insertIntoContext = (SqlBaseParser.InsertIntoTableContext) singleInsertQueryContext.insertInto(); - if (insertIntoContext.getChildCount() < 3 || - !(insertIntoContext.getChild(1) instanceof TerminalNodeImpl) || - !"INTO".equalsIgnoreCase(insertIntoContext.getChild(1).getText()) || - !(insertIntoContext.getChild(2) instanceof SqlBaseParser.MultipartIdentifierContext)) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - SqlBaseParser.MultipartIdentifierContext multipartIdentifierContext = insertIntoContext.multipartIdentifier(); - sqlInsert.setTable(multipartIdentifierContext.getText()); - if (insertIntoContext.identifierList() != null) { - SqlBaseParser.IdentifierListContext identifierListContext = insertIntoContext.identifierList(); - if (identifierListContext.identifierSeq() == null) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - - SqlBaseParser.IdentifierSeqContext identifierSeqContext = identifierListContext.identifierSeq(); - for (SqlBaseParser.ErrorCapturingIdentifierContext errorCapturingIdentifierContext : identifierSeqContext.errorCapturingIdentifier()) { - sqlInsert.getColumns().add(errorCapturingIdentifierContext.getText()); - } - } - - String selectSql; - SqlBaseParser.QueryContext queryContext = singleInsertQueryContext.query(); - SqlBaseParser.QueryPrimaryContext queryPrimaryContext = ((SqlBaseParser.QueryTermDefaultContext) queryContext.queryTerm()).queryPrimary(); - if (queryPrimaryContext instanceof SqlBaseParser.InlineTableDefault1Context) { - List expressionContexts = ((SqlBaseParser.InlineTableDefault1Context) - ((SqlBaseParser.QueryTermDefaultContext) queryContext.queryTerm()).queryPrimary()).inlineTable().expression(); - List valuesExpressions = new ArrayList<>(); - List> valuesArray = new ArrayList<>(); - for (SqlBaseParser.ExpressionContext expressionContext : expressionContexts) { - if (expressionContext.getChildCount() < 1 || !(expressionContext.getChild(0) instanceof SqlBaseParser.PredicatedContext)) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - SqlBaseParser.PredicatedContext predicatedContext = (SqlBaseParser.PredicatedContext) expressionContext.getChild(0); - if (predicatedContext.valueExpression().getChildCount() < 1 || !(predicatedContext.valueExpression().getChild(0) instanceof SqlBaseParser.RowConstructorContext)) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - List values = new ArrayList<>(); - SqlBaseParser.RowConstructorContext rowConstructorContext = (SqlBaseParser.RowConstructorContext) predicatedContext.valueExpression().getChild(0); - for (int i = 0; i < rowConstructorContext.getChildCount(); i++) { - ParseTree item = rowConstructorContext.getChild(i); - if (item instanceof TerminalNodeImpl) { - continue; - } - values.add(getFullText(item)); - } - valuesArray.add(values); - sqlInsert.setValuesArray(valuesArray); - - String selectExpression = getFullText(expressionContext); - if (!selectExpression.startsWith("(") && !selectExpression.endsWith(")")) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - selectExpression = selectExpression.substring(0, selectExpression.length() - 1).substring(1); - valuesExpressions.add(selectExpression); - } - StringJoiner stringJoiner = new StringJoiner(" UNION ALL "); - valuesExpressions.forEach(valuesExpression -> { - stringJoiner.add("SELECT " + valuesExpression); - }); - - selectSql = stringJoiner.toString(); - } else if (queryPrimaryContext instanceof SqlBaseParser.QueryPrimaryDefaultContext) { - Pair, String> hintsWithSql = parseSelectHints(sql, (SqlBaseParser.QueryPrimaryDefaultContext) queryPrimaryContext); - sqlInsert.setHints(hintsWithSql.getLeft()); - selectSql = hintsWithSql.getRight(); - } else { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - sqlInsert.setSourceSql(selectSql); - } catch (ParseException e) { - throw e; - } catch (Exception e) { - throw new ParseException(e.getMessage(), e); - } - return sqlInsert; + return InsertSqlParser.parseInsert(sql); } @Override public SQLUpdate parseUpdate(String sql) { - SQLUpdate sqlUpdate = new SQLUpdate(); - try (InputStream in = new ByteArrayInputStream(sql.getBytes(StandardCharsets.UTF_8))) { - SqlBaseParser parser = parseSql(in); - SqlBaseParser.DmlStatementNoWithContext dmlStatementNoWithContext = parser.dmlStatementNoWith(); - - if (!(dmlStatementNoWithContext instanceof SqlBaseParser.UpdateTableContext)) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - - SqlBaseParser.UpdateTableContext updateTableContext = (SqlBaseParser.UpdateTableContext) dmlStatementNoWithContext; - - SqlBaseParser.MultipartIdentifierContext mappingTable = updateTableContext.multipartIdentifier(0); - sqlUpdate.setTable(mappingTable.getText()); - - if (updateTableContext.source != null || updateTableContext.sourceQuery != null) { - if (updateTableContext.tableAlias().size() != 2) { - throw new ParseException(String.format("Alias not defined for source table or target table: %s", sql)); - } - if (updateTableContext.source != null) { - sqlUpdate.setJoinTable(updateTableContext.source.getText()); - String sourceTableAlias = updateTableContext.tableAlias().get(1).getText(); - sqlUpdate.setJoinAlias(sourceTableAlias); - sqlUpdate.setAlias(updateTableContext.tableAlias().get(0).getText()); - } else if (updateTableContext.sourceQuery != null) { - if (!(updateTableContext.sourceQuery.queryTerm() instanceof SqlBaseParser.QueryTermDefaultContext)) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - SqlBaseParser.QueryPrimaryContext queryPrimaryContext = ((SqlBaseParser.QueryTermDefaultContext) updateTableContext.sourceQuery.queryTerm()).queryPrimary(); - if (!(queryPrimaryContext instanceof SqlBaseParser.QueryPrimaryDefaultContext)) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - String sourceTableAlias = updateTableContext.tableAlias().get(1).getText(); - sqlUpdate.setJoinAlias(sourceTableAlias); - sqlUpdate.setAlias(updateTableContext.tableAlias().get(0).getText()); - Pair, String> hintsWithSql = parseSelectHints(sql, (SqlBaseParser.QueryPrimaryDefaultContext) queryPrimaryContext); - sqlUpdate.setHints(hintsWithSql.getLeft()); - sqlUpdate.setJoinSql(hintsWithSql.getRight()); - } - } else if (!updateTableContext.tableAlias().isEmpty()) { - sqlUpdate.setAlias(updateTableContext.tableAlias().get(0).getText()); - } - - List assignmentContexts = updateTableContext.setClause().assignmentList().assignment(); - for (SqlBaseParser.AssignmentContext assignmentContext : assignmentContexts) { - List identifierContexts = assignmentContext.multipartIdentifier().errorCapturingIdentifier(); - if (StringUtils.isBlank(sqlUpdate.getAlias())) { - if (identifierContexts.size() == 2 && identifierContexts.get(0).getText().equalsIgnoreCase(sqlUpdate.getAlias())) { - throw new ParseException("column owner must be same as table alias in update statement: " + sql); - } - } - String targetCol = assignmentContext.multipartIdentifier().errorCapturingIdentifier.getText(); - String sourceExpression = getFullText(assignmentContext.expression()); - sqlUpdate.getUpdateColumns().add(targetCol); - sqlUpdate.getUpdateValues().add(sourceExpression); - } - if (updateTableContext.whereClause() == null) { - throw new ParseException("update SQL must have WHERE clause: " + sql); - } - CriteriaClause criteriaClause = new CriteriaClause(); - SqlBaseParser.BooleanExpressionContext expressionContext = updateTableContext.whereClause().booleanExpression(); - CriteriaParser.parseWhereClause(expressionContext, criteriaClause, sqlUpdate.getAlias()); - sqlUpdate.setWhereClause(criteriaClause); - - List selectItems = new ArrayList<>(); - visitCriteriaClause(criteriaClause, selectItems); - - StringBuilder selectSql = new StringBuilder(); - selectSql.append("SELECT "); - selectSql.append(String.join(", ", sqlUpdate.getUpdateValues())); - selectSql.append(", ").append(String.join(", ", selectItems)); - if (StringUtils.isNotBlank(sqlUpdate.getJoinTable())) { - selectSql.append(" FROM ").append(sqlUpdate.getJoinTable()).append(" ").append(sqlUpdate.getJoinAlias()); - } else if (StringUtils.isNotBlank(sqlUpdate.getJoinSql())) { - selectSql.append(" FROM (").append(sqlUpdate.getJoinSql()).append(") ").append(sqlUpdate.getJoinAlias()); - } - sqlUpdate.setSourceSql(selectSql.toString()); - } catch (ParseException e) { - throw e; - } catch (Exception e) { - throw new ParseException(e.getMessage(), e); - } - return sqlUpdate; + return UpdateSqlParser.parseUpdate(sql); } @Override public SQLDelete parseDelete(String sql) { - SQLDelete sqlDelete = new SQLDelete(); - try (InputStream in = new ByteArrayInputStream(sql.getBytes(StandardCharsets.UTF_8))) { - SqlBaseParser parser = parseSql(in); - SqlBaseParser.DmlStatementNoWithContext dmlStatementNoWithContext = parser.dmlStatementNoWith(); - - if (!(dmlStatementNoWithContext instanceof SqlBaseParser.DeleteFromTableContext)) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - - SqlBaseParser.DeleteFromTableContext deleteFromTableContext = (SqlBaseParser.DeleteFromTableContext) dmlStatementNoWithContext; - - SqlBaseParser.MultipartIdentifierContext mappingTable = deleteFromTableContext.multipartIdentifier(0); - sqlDelete.setTable(mappingTable.getText()); - - if (deleteFromTableContext.source != null || deleteFromTableContext.sourceQuery != null) { - if (deleteFromTableContext.tableAlias().size() != 2) { - throw new ParseException(String.format("Alias not defined for source table or target table: %s", sql)); - } - if (deleteFromTableContext.source != null) { - sqlDelete.setJoinTable(deleteFromTableContext.source.getText()); - String sourceTableAlias = deleteFromTableContext.tableAlias().get(1).getText(); - sqlDelete.setJoinAlias(sourceTableAlias); - sqlDelete.setAlias(deleteFromTableContext.tableAlias().get(0).getText()); - } else if (deleteFromTableContext.sourceQuery != null) { - if (!(deleteFromTableContext.sourceQuery.queryTerm() instanceof SqlBaseParser.QueryTermDefaultContext)) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - SqlBaseParser.QueryPrimaryContext queryPrimaryContext = ((SqlBaseParser.QueryTermDefaultContext) deleteFromTableContext.sourceQuery.queryTerm()).queryPrimary(); - if (!(queryPrimaryContext instanceof SqlBaseParser.QueryPrimaryDefaultContext)) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - String sourceTableAlias = deleteFromTableContext.tableAlias().get(1).getText(); - sqlDelete.setJoinAlias(sourceTableAlias); - sqlDelete.setAlias(deleteFromTableContext.tableAlias().get(0).getText()); - Pair, String> hintsWithSql = parseSelectHints(sql, (SqlBaseParser.QueryPrimaryDefaultContext) queryPrimaryContext); - sqlDelete.setHints(hintsWithSql.getLeft()); - sqlDelete.setJoinSql(hintsWithSql.getRight()); - } - } else if (!deleteFromTableContext.tableAlias().isEmpty()) { - sqlDelete.setAlias(deleteFromTableContext.tableAlias().get(0).getText()); - } - - if (deleteFromTableContext.whereClause() == null) { - throw new ParseException("delete SQL must have WHERE clause: " + sql); - } - CriteriaClause criteriaClause = new CriteriaClause(); - SqlBaseParser.BooleanExpressionContext expressionContext = deleteFromTableContext.whereClause().booleanExpression(); - CriteriaParser.parseWhereClause(expressionContext, criteriaClause, sqlDelete.getAlias()); - sqlDelete.setWhereClause(criteriaClause); - - List selectItems = new ArrayList<>(); - visitCriteriaClause(criteriaClause, selectItems); - - StringBuilder selectSql = new StringBuilder(); - selectSql.append("SELECT "); - selectSql.append(String.join(", ", selectItems)); - if (StringUtils.isNotBlank(sqlDelete.getJoinTable())) { - selectSql.append(" FROM ").append(sqlDelete.getJoinTable()).append(" ").append(sqlDelete.getJoinAlias()); - } else if (StringUtils.isNotBlank(sqlDelete.getJoinSql())) { - selectSql.append(" FROM (").append(sqlDelete.getJoinSql()).append(") ").append(sqlDelete.getJoinAlias()); - } - sqlDelete.setSourceSql(selectSql.toString()); - } catch (ParseException e) { - throw e; - } catch (Exception e) { - throw new ParseException(e.getMessage(), e); - } - return sqlDelete; + return DeleteSqlParser.parseDelete(sql); } @Override public SQLMerge parseMerge(String sql) { - SQLMerge sqlMerge = new SQLMerge(); - try (InputStream in = new ByteArrayInputStream(sql.getBytes(StandardCharsets.UTF_8))) { - SqlBaseParser parser = parseSql(in); - SqlBaseParser.DmlStatementNoWithContext dmlStatementNoWithContext = parser.dmlStatementNoWith(); - - if (!(dmlStatementNoWithContext instanceof SqlBaseParser.MergeIntoTableContext)) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - getFullText(dmlStatementNoWithContext); - - SqlBaseParser.MergeIntoTableContext mergeIntoTableContext = (SqlBaseParser.MergeIntoTableContext) dmlStatementNoWithContext; - SqlBaseParser.MultipartIdentifierContext mappingTable = mergeIntoTableContext.multipartIdentifier(0); - sqlMerge.setTable(mappingTable.getText()); - - if (mergeIntoTableContext.source != null || mergeIntoTableContext.sourceQuery != null) { - if (mergeIntoTableContext.tableAlias().size() != 2) { - throw new ParseException(String.format("Alias not defined for source table or target table: %s", sql)); - } - if (mergeIntoTableContext.source != null) { - sqlMerge.setUsingTable(mergeIntoTableContext.source.getText()); - String sourceTableAlias = mergeIntoTableContext.tableAlias().get(1).getText(); - sqlMerge.setUsingAlias(sourceTableAlias); - sqlMerge.setAlias(mergeIntoTableContext.tableAlias().get(0).getText()); - } else if (mergeIntoTableContext.sourceQuery != null) { - if (!(mergeIntoTableContext.sourceQuery.queryTerm() instanceof SqlBaseParser.QueryTermDefaultContext)) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - SqlBaseParser.QueryPrimaryContext queryPrimaryContext = ((SqlBaseParser.QueryTermDefaultContext) mergeIntoTableContext.sourceQuery.queryTerm()).queryPrimary(); - if (!(queryPrimaryContext instanceof SqlBaseParser.QueryPrimaryDefaultContext)) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - String sourceTableAlias = mergeIntoTableContext.tableAlias().get(1).getText(); - sqlMerge.setUsingAlias(sourceTableAlias); - sqlMerge.setAlias(mergeIntoTableContext.tableAlias().get(0).getText()); - Pair, String> hintsWithSql = parseSelectHints(sql, (SqlBaseParser.QueryPrimaryDefaultContext) queryPrimaryContext); - sqlMerge.setHints(hintsWithSql.getLeft()); - sqlMerge.setUsingSql(hintsWithSql.getRight()); - } - } else if (!mergeIntoTableContext.tableAlias().isEmpty()) { - sqlMerge.setAlias(mergeIntoTableContext.tableAlias().get(0).getText()); - } - - CriteriaClause onClause = new CriteriaClause(); - SqlBaseParser.BooleanExpressionContext onExpressionContext = mergeIntoTableContext.mergeCondition; - CriteriaParser.parseWhereClause(onExpressionContext, onClause, sqlMerge.getAlias()); - List onSelectItems = new ArrayList<>(); - visitCriteriaClause(onClause, onSelectItems); - sqlMerge.setOnSelectItems(onSelectItems); - - String usingSQL; - if (sqlMerge.getUsingTable() != null) { - usingSQL = "SELECT * FROM " + sqlMerge.getUsingTable(); - } else { - usingSQL = sqlMerge.getUsingSql(); - } - - StringBuilder conditionSql = new StringBuilder(); - visitOnWhereClause(onClause, conditionSql); - - if (mergeIntoTableContext.matchedClause().size() > 1 || mergeIntoTableContext.notMatchedClause().size() > 1) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - if (!mergeIntoTableContext.notMatchedClause().isEmpty()) { - SQLInsert sqlInsert = new SQLInsert(); - sqlInsert.setTable(mappingTable.getText()); - - SqlBaseParser.NotMatchedActionContext notMatchedAction = mergeIntoTableContext.notMatchedClause().get(0).notMatchedAction(); - - for (SqlBaseParser.MultipartIdentifierContext multipartIdentifierContext : notMatchedAction.multipartIdentifierList().multipartIdentifier()) { - sqlInsert.getColumns().add(multipartIdentifierContext.errorCapturingIdentifier.identifier().getText()); - } - - List> valuesArray = new ArrayList<>(); - List values = new ArrayList<>(); - for (SqlBaseParser.ExpressionContext expressionContext : notMatchedAction.expression()) { - values.add(getFullText(expressionContext)); - } - valuesArray.add(values); - sqlInsert.setValuesArray(valuesArray); - sqlMerge.setSqlInsert(sqlInsert); - - StringBuilder sourceSql = new StringBuilder(); - sourceSql.append("SELECT "); - sourceSql.append(String.join(", ", sqlInsert.getValuesArray().get(0))); - sourceSql.append(" FROM ("); - sourceSql.append(usingSQL).append(") ").append(sqlMerge.getUsingAlias()).append(" "); - sourceSql.append("WHERE NOT EXISTS (SELECT 1 FROM "); - sourceSql.append(sqlMerge.getTable()); - sourceSql.append(" WHERE ").append(conditionSql).append(" )"); - sqlInsert.setSourceSql(sourceSql.toString()); - } - - if (!mergeIntoTableContext.matchedClause().isEmpty()) { - SQLUpdate sqlUpdate = new SQLUpdate(); - sqlUpdate.setTable(mappingTable.getText()); - - sqlUpdate.setAlias(sqlMerge.getAlias()); - sqlUpdate.setJoinAlias(sqlMerge.getUsingAlias()); - if (sqlMerge.getUsingTable() != null) { - sqlUpdate.setJoinTable(sqlMerge.getUsingTable()); - } else if (sqlMerge.getUsingSql() != null) { - sqlUpdate.setJoinSql(sqlMerge.getUsingSql()); - } - - SqlBaseParser.MatchedActionContext matchedActionContext = mergeIntoTableContext.matchedClause().get(0).matchedAction(); - - for (SqlBaseParser.AssignmentContext assignmentContext : matchedActionContext.assignmentList().assignment()) { - sqlUpdate.getUpdateColumns().add(assignmentContext.key.errorCapturingIdentifier.identifier().getText()); - sqlUpdate.getUpdateValues().add(getFullText(assignmentContext.value)); - } - - if (matchedActionContext.booleanExpression() != null) { - CriteriaClause whereClause = new CriteriaClause(); - parseWhereClause(matchedActionContext.booleanExpression(), whereClause, sqlUpdate.getAlias()); - - sqlUpdate.setWhereClause(whereClause); - List selectItems = new ArrayList<>(); - visitCriteriaClause(whereClause, selectItems); - sqlUpdate.setSelectWhereItems(selectItems); - - CriteriaClause allWhereClause = new CriteriaClause(); - allWhereClause.setLeftCriteria(onClause); - allWhereClause.setOperator("AND"); - allWhereClause.setRightCriteria(whereClause); - sqlMerge.setAllWhereClause(allWhereClause); - } - sqlMerge.setSqlUpdate(sqlUpdate); - - StringBuilder sourceSql = new StringBuilder(); - sourceSql.append("SELECT "); - sourceSql.append(String.join(", ", sqlUpdate.getUpdateValues())); - sourceSql.append(", ").append(String.join(", ", sqlMerge.getOnSelectItems())); - if (sqlUpdate.getSelectWhereItems() != null) { - sourceSql.append(", ").append(String.join(", ", sqlUpdate.getSelectWhereItems())); - } - sourceSql.append(" FROM ("); - sourceSql.append(usingSQL).append(") ").append(sqlMerge.getUsingAlias()).append(" "); - sourceSql.append("WHERE EXISTS (SELECT 1 FROM "); - sourceSql.append(sqlMerge.getTable()); - sourceSql.append(" WHERE ").append(conditionSql).append(" )"); - - sqlUpdate.setSourceSql(sourceSql.toString()); - } - } catch (ParseException e) { - throw e; - } catch (Exception e) { - throw new ParseException(e.getMessage(), e); - } - return sqlMerge; + return MergeSqlParser.parseMerge(sql); } @Override public SQLTruncate parseTruncate(String sql) { SQLTruncate sqlTruncate = new SQLTruncate(); try (InputStream in = new ByteArrayInputStream(sql.getBytes(StandardCharsets.UTF_8))) { - SqlBaseParser parser = parseSql(in); + SqlBaseParser parser = CommonParser.parseSql(in); SqlBaseParser.StatementContext statementContext = parser.statement(); if (!(statementContext instanceof SqlBaseParser.TruncateTableContext)) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); } SqlBaseParser.TruncateTableContext truncateTableContext = (SqlBaseParser.TruncateTableContext) statementContext; @@ -505,75 +63,4 @@ public SQLTruncate parseTruncate(String sql) { } return sqlTruncate; } - - private SqlBaseParser parseSql(InputStream in) throws IOException { - CharStream s = CharStreams.fromStream(in); - CaseChangingCharStream upper = new CaseChangingCharStream(s, true); - - CustomErrorListener lexerErrorListener = new CustomErrorListener(); - SqlBaseLexer lexer = new SqlBaseLexer(upper); - lexer.removeErrorListeners(); - lexer.addErrorListener(lexerErrorListener); - CommonTokenStream tokens = new CommonTokenStream(lexer); - SqlBaseParser parser = new SqlBaseParser(tokens); - CustomErrorListener parserErrorListener = new CustomErrorListener(); - parser.removeErrorListeners(); - parser.addErrorListener(parserErrorListener); - return parser; - } - - private Pair, String> parseSelectHints(String sql, SqlBaseParser.QueryPrimaryDefaultContext queryPrimaryDefaultContext) { - if (!(queryPrimaryDefaultContext.querySpecification() instanceof SqlBaseParser.RegularQuerySpecificationContext)) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - SqlBaseParser.RegularQuerySpecificationContext regularQuerySpecificationContext = (SqlBaseParser.RegularQuerySpecificationContext) queryPrimaryDefaultContext.querySpecification(); - if (regularQuerySpecificationContext.selectClause() == null) { - throw new ParseException(String.format("unsupported syntax: %s", sql)); - } - - List sqlHints = new ArrayList<>(); - SqlBaseParser.SelectClauseContext selectClauseContext = regularQuerySpecificationContext.selectClause(); - if (!selectClauseContext.hints.isEmpty()) { - for (int i = 0; i < selectClauseContext.hints.size(); i++) { - SQLHint sqlHint = new SQLHint(); - SqlBaseParser.HintContext hintContext = selectClauseContext.hint(i); - String hintName = hintContext.hintStatement.hintName.getText(); - sqlHint.setHintName(hintName); - for (SqlBaseParser.PrimaryExpressionContext primaryExpressionContext : hintContext.hintStatement.parameters) { - String parameter = getFullText(primaryExpressionContext); - if (StringUtils.isNotBlank(parameter)) { - sqlHint.getArguments().add(parameter); - } - } - sqlHints.add(sqlHint); - } - } - - // filter out hints and into clause from select clause - StringBuilder selectSql = new StringBuilder(); - for (ParseTree child : regularQuerySpecificationContext.children) { - if (child instanceof SqlBaseParser.SelectClauseContext) { - SqlBaseParser.SelectClauseContext selectClauseContext1 = (SqlBaseParser.SelectClauseContext) child; - for (ParseTree grandChild : selectClauseContext1.children) { - if (grandChild instanceof SqlBaseParser.IntoClauseContext) { - continue; - } else if (grandChild instanceof SqlBaseParser.HintContext) { - if (ExecutionEngineType.engineType == EngineType.SPARK) { - String hint = getFullText(grandChild); - String hintLower = hint.toLowerCase(); - if (hintLower.contains("mapjoin") || hintLower.contains("broadcast")) { - selectSql.append(hint).append(" "); - } - } - continue; - } - selectSql.append(getFullText(grandChild)).append(" "); - } - } else { - selectSql.append(getFullText(child)).append(" "); - } - } - - return Pair.of(sqlHints, selectSql.toString()); - } } diff --git a/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/UpdateSqlParser.java b/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/UpdateSqlParser.java new file mode 100644 index 0000000..3d11f63 --- /dev/null +++ b/ares-parser/src/main/java/com/github/ares/parser/sqlparser/sparksql/UpdateSqlParser.java @@ -0,0 +1,119 @@ +package com.github.ares.parser.sqlparser.sparksql; + +import com.github.ares.api.common.CriteriaClause; +import com.github.ares.common.exceptions.ParseException; +import com.github.ares.parser.antlr4.sparksql.SqlBaseParser; +import com.github.ares.parser.sqlparser.model.SQLHint; +import com.github.ares.parser.sqlparser.model.SQLUpdate; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +import static com.github.ares.parser.sqlparser.sparksql.CommonParser.SQL_SELECT_PREFIX; +import static com.github.ares.parser.sqlparser.sparksql.CommonParser.UNSUPPORTED_EXP_MSG_WITH_PARAM; +import static com.github.ares.parser.utils.PLParserUtil.getFullText; + +public class UpdateSqlParser { + private UpdateSqlParser() { + } + + /** + * Parse update SQL and return SQLUpdate object. + * + * @param sql update SQL + * @return SQLUpdate object + */ + public static SQLUpdate parseUpdate(String sql) { + SQLUpdate sqlUpdate = new SQLUpdate(); + try (InputStream in = new ByteArrayInputStream(sql.getBytes(StandardCharsets.UTF_8))) { + SqlBaseParser parser = CommonParser.parseSql(in); + SqlBaseParser.DmlStatementNoWithContext dmlStatementNoWithContext = parser.dmlStatementNoWith(); + + if (!(dmlStatementNoWithContext instanceof SqlBaseParser.UpdateTableContext)) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + + SqlBaseParser.UpdateTableContext updateTableContext = (SqlBaseParser.UpdateTableContext) dmlStatementNoWithContext; + + SqlBaseParser.MultipartIdentifierContext mappingTable = updateTableContext.multipartIdentifier(0); + sqlUpdate.setTable(mappingTable.getText()); + + if (updateTableContext.source != null || updateTableContext.sourceQuery != null) { + parseSourceQuery(updateTableContext, sqlUpdate, sql); + } else if (!updateTableContext.tableAlias().isEmpty()) { + sqlUpdate.setAlias(updateTableContext.tableAlias().get(0).getText()); + } + + List assignmentContexts = updateTableContext.setClause().assignmentList().assignment(); + for (SqlBaseParser.AssignmentContext assignmentContext : assignmentContexts) { + List identifierContexts = assignmentContext.multipartIdentifier().errorCapturingIdentifier(); + if (StringUtils.isBlank(sqlUpdate.getAlias()) && identifierContexts.size() == 2 + && identifierContexts.get(0).getText().equalsIgnoreCase(sqlUpdate.getAlias())) { + throw new ParseException("column owner must be same as table alias in update statement: " + sql); + } + String targetCol = assignmentContext.multipartIdentifier().errorCapturingIdentifier.getText(); + String sourceExpression = getFullText(assignmentContext.expression()); + sqlUpdate.getUpdateColumns().add(targetCol); + sqlUpdate.getUpdateValues().add(sourceExpression); + } + if (updateTableContext.whereClause() == null) { + throw new ParseException("update SQL must have WHERE clause: " + sql); + } + CriteriaClause criteriaClause = new CriteriaClause(); + SqlBaseParser.BooleanExpressionContext expressionContext = updateTableContext.whereClause().booleanExpression(); + CriteriaParser.parseWhereClause(expressionContext, criteriaClause, sqlUpdate.getAlias()); + sqlUpdate.setWhereClause(criteriaClause); + + List selectItems = new ArrayList<>(); + CommonParser.visitCriteriaClause(criteriaClause, selectItems); + + StringBuilder selectSql = new StringBuilder(); + selectSql.append(SQL_SELECT_PREFIX); + selectSql.append(String.join(", ", sqlUpdate.getUpdateValues())); + selectSql.append(", ").append(String.join(", ", selectItems)); + if (StringUtils.isNotBlank(sqlUpdate.getJoinTable())) { + selectSql.append(" FROM ").append(sqlUpdate.getJoinTable()).append(" ").append(sqlUpdate.getJoinAlias()); + } else if (StringUtils.isNotBlank(sqlUpdate.getJoinSql())) { + selectSql.append(" FROM (").append(sqlUpdate.getJoinSql()).append(") ").append(sqlUpdate.getJoinAlias()); + } + sqlUpdate.setSourceSql(selectSql.toString()); + } catch (ParseException e) { + throw e; + } catch (Exception e) { + throw new ParseException(e.getMessage(), e); + } + return sqlUpdate; + } + + private static void parseSourceQuery(SqlBaseParser.UpdateTableContext updateTableContext, SQLUpdate sqlUpdate, String sql) { + if (updateTableContext.tableAlias().size() != 2) { + throw new ParseException(String.format("Alias not defined for source table or target table: %s", sql)); + } + if (updateTableContext.source != null) { + sqlUpdate.setJoinTable(updateTableContext.source.getText()); + String sourceTableAlias = updateTableContext.tableAlias().get(1).getText(); + sqlUpdate.setJoinAlias(sourceTableAlias); + sqlUpdate.setAlias(updateTableContext.tableAlias().get(0).getText()); + } else if (updateTableContext.sourceQuery != null) { + if (!(updateTableContext.sourceQuery.queryTerm() instanceof SqlBaseParser.QueryTermDefaultContext)) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + SqlBaseParser.QueryPrimaryContext queryPrimaryContext = ((SqlBaseParser.QueryTermDefaultContext) updateTableContext.sourceQuery.queryTerm()).queryPrimary(); + if (!(queryPrimaryContext instanceof SqlBaseParser.QueryPrimaryDefaultContext)) { + throw new ParseException(String.format(UNSUPPORTED_EXP_MSG_WITH_PARAM, sql)); + } + String sourceTableAlias = updateTableContext.tableAlias().get(1).getText(); + sqlUpdate.setJoinAlias(sourceTableAlias); + sqlUpdate.setAlias(updateTableContext.tableAlias().get(0).getText()); + Pair, String> hintsWithSql = HintParser.parseSelectHints(sql, + (SqlBaseParser.QueryPrimaryDefaultContext) queryPrimaryContext); + sqlUpdate.setHints(hintsWithSql.getLeft()); + sqlUpdate.setJoinSql(hintsWithSql.getRight()); + } + } +} diff --git a/ares-parser/src/main/java/com/github/ares/parser/utils/PLParserUtil.java b/ares-parser/src/main/java/com/github/ares/parser/utils/PLParserUtil.java index abb03e5..63fc9a2 100644 --- a/ares-parser/src/main/java/com/github/ares/parser/utils/PLParserUtil.java +++ b/ares-parser/src/main/java/com/github/ares/parser/utils/PLParserUtil.java @@ -71,9 +71,6 @@ public static String getFullSQLWithParams(Object o, Map params, break; } } - // if (terminalNode.startsWith(":")) { - // throw new ParseException("param is not defined: " + terminalNode); - // } } if (resStr != null && resStr.length() > 2) { diff --git a/ares-parser/src/main/java/com/github/ares/parser/visitor/PlBaseVisitor.java b/ares-parser/src/main/java/com/github/ares/parser/visitor/PlBaseVisitor.java index 4801a1c..f1ceddb 100644 --- a/ares-parser/src/main/java/com/github/ares/parser/visitor/PlBaseVisitor.java +++ b/ares-parser/src/main/java/com/github/ares/parser/visitor/PlBaseVisitor.java @@ -11,6 +11,7 @@ import com.github.ares.parser.utils.PLParserUtil; import java.util.ArrayList; +import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -22,164 +23,152 @@ public void init(PlVisitorManager visitorManager) { this.visitorManager = visitorManager; } + /** + * visit PL script and return a list of logical operations + * + * @param sqlScriptContext PL SQL script context object + * @return a list of logical operations + */ public List visitBase(PlSqlParser.Sql_scriptContext sqlScriptContext) { - List setConfigs = visit4SetConfig(sqlScriptContext); - List unitStatementContexts = sqlScriptContext.unit_statement(); if (unitStatementContexts == null) { - return null; + return Collections.emptyList(); } List result = new ArrayList<>(); for (PlSqlParser.Unit_statementContext unitStatementContext : unitStatementContexts) { - PlSqlParser.Anonymous_bodyContext anonymousBodyContext = unitStatementContext.anonymous_body(); - if (anonymousBodyContext != null) { - Map declaredParams = new LinkedHashMap<>(); - LogicalOperation operation = visitorManager.getDeclareParamsVisitor() - .visitDeclareParams(unitStatementContext.anonymous_body().seq_of_declare_specs(), declaredParams); - List body = visitorManager.getBodyVisitor().visitBodyStatements(unitStatementContext.anonymous_body().body().seq_of_statements(), new LinkedHashMap<>(), - new LinkedHashMap<>(), declaredParams, result, null); - LogicalAnonymousBody anonymousBody = new LogicalAnonymousBody(); - anonymousBody.setDeclareParams((LogicalDeclareParams) operation); - anonymousBody.setAnonymousBody(body); - - if (unitStatementContext.anonymous_body().body().exception_handler() != null && - !unitStatementContext.anonymous_body().body().exception_handler().isEmpty()) { - LogicalExceptionHandler exHandler = visitorManager.getExceptionHandlerVisitor() - .visitExceptionHandler(unitStatementContext.anonymous_body().body().exception_handler().get(0), - new LinkedHashMap<>(), new LinkedHashMap<>(), declaredParams, false); - if (exHandler != null) { - anonymousBody.setExHandler(exHandler); - } - } - - result.add(anonymousBody); + if (visitPlContext(unitStatementContext, sqlScriptContext, result) + || visitSqlContext(unitStatementContext, result)) { continue; } + throw new ParseException(String.format("Unsupported syntax: %s", PLParserUtil.getFullText(unitStatementContext))); + } - PlSqlParser.Create_tableContext createTableContext = unitStatementContext.create_table(); - if (createTableContext != null) { - PlSqlParser.Create_withContext createWithContext = createTableContext.create_with(); - if (createWithContext != null) { - List operations = visitorManager.getCreateTableWithVisitor() - .visitCreateTableWith(createTableContext, createWithContext, setConfigs); - if (operations != null && !operations.isEmpty()) { - result.addAll(operations); - } - continue; - } else { - throw new UnsupportedOperationException("Unsupported create internal table yet"); - } - } + return result; + } - PlSqlParser.Create_procedure_bodyContext createProcedureBody = unitStatementContext.create_procedure_body(); - if (createProcedureBody != null) { - LogicalOperation operation = visitorManager.getCreateProcedureVisitor().visitCreateProcedure(createProcedureBody, result); - if (operation != null) { - result.add(operation); - } - continue; + private void visitAnonymousBody(PlSqlParser.Anonymous_bodyContext anonymousBodyContext, List result) { + Map declaredParams = new LinkedHashMap<>(); + LogicalOperation operation = visitorManager.getDeclareParamsVisitor() + .visitDeclareParams(anonymousBodyContext.seq_of_declare_specs(), declaredParams); + List body = visitorManager.getBodyVisitor().visitBodyStatements(anonymousBodyContext.body().seq_of_statements(), new LinkedHashMap<>(), + new LinkedHashMap<>(), declaredParams, result, null); + LogicalAnonymousBody anonymousBody = new LogicalAnonymousBody(); + anonymousBody.setDeclareParams((LogicalDeclareParams) operation); + anonymousBody.setAnonymousBody(body); + + if (anonymousBodyContext.body().exception_handler() != null && + !anonymousBodyContext.body().exception_handler().isEmpty()) { + LogicalExceptionHandler exHandler = visitorManager.getExceptionHandlerVisitor() + .visitExceptionHandler(anonymousBodyContext.body().exception_handler().get(0), + new LinkedHashMap<>(), new LinkedHashMap<>(), declaredParams, false); + if (exHandler != null) { + anonymousBody.setExHandler(exHandler); } + } - PlSqlParser.Create_function_bodyContext createFunctionBodyContext = unitStatementContext.create_function_body(); - if (createFunctionBodyContext != null) { - LogicalOperation operation = visitorManager.getCreateFunctionVisitor().visitCreateFunction(createFunctionBodyContext, result); - if (operation != null) { - result.add(operation); - } - continue; - } + result.add(anonymousBody); + } - PlSqlParser.Call_statementContext callStatementContext = unitStatementContext.call_statement(); - if (callStatementContext != null) { - LogicalOperation operation = visitorManager.getCallStatementVisitor(). - visitCallStatement(callStatementContext, new LinkedHashMap<>(), result, result, null); - if (operation != null) { - result.add(operation); - } - continue; + private void visitCreateAsSQL(PlSqlParser.Create_tableContext createTableContext, + PlSqlParser.Sql_scriptContext sqlScriptContext, List result) { + List setConfigs = visit4SetConfig(sqlScriptContext); + PlSqlParser.Create_withContext createWithContext = createTableContext.create_with(); + if (createWithContext != null) { + List operations = visitorManager.getCreateTableWithVisitor() + .visitCreateTableWith(createTableContext, createWithContext, setConfigs); + if (operations != null && !operations.isEmpty()) { + result.addAll(operations); } + } else { + throw new UnsupportedOperationException("Unsupported create internal table yet"); + } + } - if (unitStatementContext.select_block() != null) { - String selectSQL = PLParserUtil.cleanSQL(PLParserUtil.getFullText(unitStatementContext.select_block())); - LogicalOperation operation = visitorManager.getSelectSQLVisitor() - .visitSelectSQL(selectSQL, selectSQL, new LinkedHashMap<>()); - if (operation != null) { - result.add(operation); - } - continue; - } + private void visitCreateProcedure(PlSqlParser.Create_procedure_bodyContext createProcedureBody, List result) { + LogicalOperation operation = visitorManager.getCreateProcedureVisitor().visitCreateProcedure(createProcedureBody, result); + if (operation != null) { + result.add(operation); + } + } - if (unitStatementContext.insert_block() != null) { - String insertSQL = PLParserUtil.cleanSQL(PLParserUtil.getFullText(unitStatementContext.insert_block())); - LogicalOperation operation = visitorManager.getInsertSQLVisitor().visitInsertSQL(insertSQL, insertSQL); - if (operation != null) { - result.add(operation); - } - continue; - } + private void visitCreateFunction(PlSqlParser.Create_function_bodyContext createFunctionBodyContext, List result) { + LogicalOperation operation = visitorManager.getCreateFunctionVisitor().visitCreateFunction(createFunctionBodyContext, result); + if (operation != null) { + result.add(operation); + } + } - if (unitStatementContext.update_block() != null) { - String updateSQL = PLParserUtil.cleanSQL(PLParserUtil.getFullText(unitStatementContext.update_block())); - LogicalOperation operation = visitorManager.getUpdateSQLVisitor().visitUpdateSQL(updateSQL, updateSQL); - if (operation != null) { - result.add(operation); - } - continue; - } + private void visitCallStatement(PlSqlParser.Call_statementContext callStatementContext, List result) { + LogicalOperation operation = visitorManager.getCallStatementVisitor(). + visitCallStatement(callStatementContext, new LinkedHashMap<>(), result, null); + if (operation != null) { + result.add(operation); + } + } - if (unitStatementContext.delete_block() != null) { - String deleteSQL = PLParserUtil.cleanSQL(PLParserUtil.getFullText(unitStatementContext.delete_block())); - LogicalOperation operation = visitorManager.getDeleteSQLVisitor().visitDeleteSQL(deleteSQL, deleteSQL); - if (operation != null) { - result.add(operation); - } - continue; - } + private void visitSelectSQL(PlSqlParser.Select_blockContext selectBlockContext, List result) { + String selectSQL = PLParserUtil.cleanSQL(PLParserUtil.getFullText(selectBlockContext)); + LogicalOperation operation = visitorManager.getSelectSQLVisitor() + .visitSelectSQL(selectSQL, selectSQL, new LinkedHashMap<>()); + if (operation != null) { + result.add(operation); + } + } - if (unitStatementContext.merge_block() != null) { - String mergeSQL = PLParserUtil.cleanSQL(PLParserUtil.getFullText(unitStatementContext.merge_block())); - LogicalOperation operation = visitorManager.getMergeSQLVisitor().visitMergeSQL(mergeSQL, mergeSQL); - if (operation != null) { - result.add(operation); - } - continue; - } + private void visitInsertSQL(PlSqlParser.Insert_blockContext insertBlockContext, List result) { + String insertSQL = PLParserUtil.cleanSQL(PLParserUtil.getFullText(insertBlockContext)); + LogicalOperation operation = visitorManager.getInsertSQLVisitor().visitInsertSQL(insertSQL, insertSQL); + if (operation != null) { + result.add(operation); + } + } - PlSqlParser.Create_table_asContext createTableAsContext = unitStatementContext.create_table_as(); - if (createTableAsContext != null) { - String createSQL = PLParserUtil.cleanSQL(PLParserUtil.getFullText(createTableAsContext)); - String innerTableName = createTableAsContext.table_name().getText(); - LogicalOperation operation = visitorManager.getCreateAsSQLVisitor() - .visitCreateInnerTable(createSQL, createSQL, innerTableName); - if (operation != null) { - result.add(operation); - } - continue; - } + private void visitUpdateSQL(PlSqlParser.Update_blockContext updateBlockContext, List result) { + String updateSQL = PLParserUtil.cleanSQL(PLParserUtil.getFullText(updateBlockContext)); + LogicalOperation operation = visitorManager.getUpdateSQLVisitor().visitUpdateSQL(updateSQL, updateSQL); + if (operation != null) { + result.add(operation); + } + } - if (unitStatementContext.truncate_table_block() != null) { - String truncateSQL = PLParserUtil.cleanSQL(PLParserUtil.getFullText(unitStatementContext.truncate_table_block())); - LogicalOperation operation = visitorManager.getTruncateSQLVisitor().visitTruncateSQL(truncateSQL); - if (operation != null) { - result.add(operation); - } - continue; - } + private void visitDeleteSQL(PlSqlParser.Delete_blockContext deleteBlockContext, List result) { + String deleteSQL = PLParserUtil.cleanSQL(PLParserUtil.getFullText(deleteBlockContext)); + LogicalOperation operation = visitorManager.getDeleteSQLVisitor().visitDeleteSQL(deleteSQL, deleteSQL); + if (operation != null) { + result.add(operation); + } + } - if (unitStatementContext.set_bleck() != null) { - continue; - } + private void visitMergeSQL(PlSqlParser.Merge_blockContext mergeBlockContext, List result) { + String mergeSQL = PLParserUtil.cleanSQL(PLParserUtil.getFullText(mergeBlockContext)); + LogicalOperation operation = visitorManager.getMergeSQLVisitor().visitMergeSQL(mergeSQL, mergeSQL); + if (operation != null) { + result.add(operation); + } + } - throw new ParseException(String.format("Unsupported syntax: %s", PLParserUtil.getFullText(unitStatementContext))); + private void visitCreateTableAs(PlSqlParser.Create_table_asContext createTableAsContext, List result) { + String createSQL = PLParserUtil.cleanSQL(PLParserUtil.getFullText(createTableAsContext)); + String innerTableName = createTableAsContext.table_name().getText(); + LogicalOperation operation = visitorManager.getCreateAsSQLVisitor() + .visitCreateInnerTable(createSQL, createSQL, innerTableName); + if (operation != null) { + result.add(operation); } + } - return result; + private void visitTruncateTable(PlSqlParser.Truncate_table_blockContext truncateTableBlockContext, List result) { + String truncateSQL = PLParserUtil.cleanSQL(PLParserUtil.getFullText(truncateTableBlockContext)); + LogicalOperation operation = visitorManager.getTruncateSQLVisitor().visitTruncateSQL(truncateSQL); + if (operation != null) { + result.add(operation); + } } public List visit4SetConfig(PlSqlParser.Sql_scriptContext sqlScriptContext) { List unitStatementContexts = sqlScriptContext.unit_statement(); if (unitStatementContexts == null) { - return null; + return Collections.emptyList(); } List setConfigs = new ArrayList<>(); for (PlSqlParser.Unit_statementContext unitStatementContext : unitStatementContexts) { @@ -202,4 +191,55 @@ public List visit4SetConfig(PlSqlParser.Sql_scriptContext sqlS } return setConfigs; } + + private boolean visitPlContext(PlSqlParser.Unit_statementContext unitStatementContext, + PlSqlParser.Sql_scriptContext sqlScriptContext, List result) { + boolean resultFlag = false; + if (unitStatementContext.anonymous_body() != null) { + visitAnonymousBody(unitStatementContext.anonymous_body(), result); + resultFlag = true; + } else if (unitStatementContext.create_table() != null) { + visitCreateAsSQL(unitStatementContext.create_table(), sqlScriptContext, result); + resultFlag = true; + } else if (unitStatementContext.create_procedure_body() != null) { + visitCreateProcedure(unitStatementContext.create_procedure_body(), result); + resultFlag = true; + } else if (unitStatementContext.create_function_body() != null) { + visitCreateFunction(unitStatementContext.create_function_body(), result); + resultFlag = true; + } else if (unitStatementContext.call_statement() != null) { + visitCallStatement(unitStatementContext.call_statement(), result); + resultFlag = true; + } else if (unitStatementContext.set_bleck() != null) { + resultFlag = true; + } + return resultFlag; + } + + private boolean visitSqlContext(PlSqlParser.Unit_statementContext unitStatementContext, List result) { + boolean resultFlag = false; + if (unitStatementContext.select_block() != null) { + visitSelectSQL(unitStatementContext.select_block(), result); + resultFlag = true; + } else if (unitStatementContext.insert_block() != null) { + visitInsertSQL(unitStatementContext.insert_block(), result); + resultFlag = true; + } else if (unitStatementContext.update_block() != null) { + visitUpdateSQL(unitStatementContext.update_block(), result); + resultFlag = true; + } else if (unitStatementContext.delete_block() != null) { + visitDeleteSQL(unitStatementContext.delete_block(), result); + resultFlag = true; + } else if (unitStatementContext.merge_block() != null) { + visitMergeSQL(unitStatementContext.merge_block(), result); + resultFlag = true; + } else if (unitStatementContext.create_table_as() != null) { + visitCreateTableAs(unitStatementContext.create_table_as(), result); + resultFlag = true; + } else if (unitStatementContext.truncate_table_block() != null) { + visitTruncateTable(unitStatementContext.truncate_table_block(), result); + resultFlag = true; + } + return resultFlag; + } } diff --git a/ares-parser/src/main/java/com/github/ares/parser/visitor/PlBodyVisitor.java b/ares-parser/src/main/java/com/github/ares/parser/visitor/PlBodyVisitor.java index 829737b..2114189 100644 --- a/ares-parser/src/main/java/com/github/ares/parser/visitor/PlBodyVisitor.java +++ b/ares-parser/src/main/java/com/github/ares/parser/visitor/PlBodyVisitor.java @@ -9,6 +9,7 @@ import com.github.ares.parser.utils.PLParserUtil; import java.util.ArrayList; +import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -22,76 +23,34 @@ public void init(PlVisitorManager visitorManager) { this.visitorManager = visitorManager; } - public List visitBodyStatements(List statementContextList, Map inParams, - Map outParams, Map declaredParams, - List baseBody, List structs) { + /** + * Visits a list of statements in a PL/SQL block. + * + * @param statementContextList List of statement contexts. + * @param inParams input parameters + * @param outParams output parameters + * @param declaredParams declared parameters + * @param baseBody base body + * @param structs structs + * @return List of logical operations. + */ + public List visitBodyStatements( + List statementContextList, + Map inParams, + Map outParams, + Map declaredParams, + List baseBody, + List structs) { Map allParams = new LinkedHashMap<>(inParams); declaredParams.putAll(outParams); allParams.putAll(declaredParams); if (statementContextList == null) { - return null; + return Collections.emptyList(); } List result = new ArrayList<>(); for (PlSqlParser.StatementContext statementContext : statementContextList) { - PlSqlParser.Exit_statementContext exitStatementContext = statementContext.exit_statement(); - if (exitStatementContext != null) { - result.add(new LogicalExitLoop()); - continue; - } - PlSqlParser.Continue_statementContext continueStatementContext = statementContext.continue_statement(); - if (continueStatementContext != null) { - result.add(new LogicalContinueLoop()); - continue; - } - - PlSqlParser.Call_statementContext callStatementContext = statementContext.call_statement(); - if (callStatementContext != null) { - LogicalOperation operation = visitorManager.getCallStatementVisitor() - .visitCallStatement(callStatementContext, allParams, baseBody, result, structs); - if (operation != null) { - result.add(operation); - } - continue; - } - - PlSqlParser.Sql_statementContext sqlStatementContext = statementContext.sql_statement(); - if (sqlStatementContext != null) { - LogicalOperation operation = sqlStatementVisitor(sqlStatementContext, declaredParams, allParams, structs); - if (operation != null) { - result.add(operation); - } - continue; - } - - PlSqlParser.Assignment_statementContext assignmentStatement = statementContext.assignment_statement(); - if (assignmentStatement != null) { - LogicalOperation operation = visitorManager.getAssignmentVisitor() - .visitAssignment(assignmentStatement, declaredParams, allParams, structs); - if (operation != null) { - result.add(operation); - } - continue; - } - PlSqlParser.If_statementContext ifStatement = statementContext.if_statement(); - if (ifStatement != null) { - visitorManager.getIfStatementVisitor().ifElseVisitor( - this, ifStatement, baseBody, allParams, result, structs - ); - continue; - } - - PlSqlParser.Loop_statementContext loopStatementContext = statementContext.loop_statement(); - if (loopStatementContext != null) { - LogicalOperation operation = visitorManager.getLoopStatementVisitor().loopVisitor( - this, loopStatementContext, baseBody, allParams, structs); - if (operation != null) { - result.add(operation); - } - continue; - } - - PlSqlParser.Raise_statementContext raiseStatementContext = statementContext.raise_statement(); - if (raiseStatementContext != null) { + if (visitLogicalControlContext(statementContext, result) + || visitPlContext(statementContext, result, allParams, baseBody, declaredParams, structs)) { continue; } @@ -101,18 +60,83 @@ public List visitBodyStatements(List visitBodyStatements(PlSqlParser.Seq_of_statementsContext seq_of_statementsContext, Map inParams, - Map outParams, Map declaredParams, - List baseBody, List structs) { + public List visitBodyStatements( + PlSqlParser.Seq_of_statementsContext seq_of_statementsContext, + Map inParams, + Map outParams, + Map declaredParams, + List baseBody, + List structs) { return visitBodyStatements(seq_of_statementsContext.statement(), inParams, outParams, declaredParams, baseBody, structs); } - public LogicalOperation sqlStatementVisitor(PlSqlParser.Sql_statementContext sql_statementContext, - Map declaredParams, Map allParams, List structs) { - String originalSql = PLParserUtil.getFullText(sql_statementContext); + private boolean visitLogicalControlContext(PlSqlParser.StatementContext statementContext, List result) { + boolean resultFlag = false; + if (statementContext.exit_statement() != null) { + result.add(new LogicalExitLoop()); + resultFlag = true; + } else if (statementContext.continue_statement() != null) { + result.add(new LogicalContinueLoop()); + resultFlag = true; + } else if (statementContext.raise_statement() != null) { + resultFlag = true; + } + return resultFlag; + } + + private boolean visitPlContext( + PlSqlParser.StatementContext statementContext, + List result, + Map allParams, + List baseBody, + Map declaredParams, + List structs) { + boolean resultFlag = false; + if (statementContext.call_statement() != null) { + LogicalOperation operation = visitorManager.getCallStatementVisitor() + .visitCallStatement(statementContext.call_statement(), allParams, result, structs); + if (operation != null) { + result.add(operation); + } + resultFlag = true; + } else if (statementContext.sql_statement() != null) { + LogicalOperation operation = sqlStatementVisitor(statementContext.sql_statement(), declaredParams, allParams, structs); + if (operation != null) { + result.add(operation); + } + resultFlag = true; + } else if (statementContext.assignment_statement() != null) { + LogicalOperation operation = visitorManager.getAssignmentVisitor() + .visitAssignment(statementContext.assignment_statement(), declaredParams, allParams, structs); + if (operation != null) { + result.add(operation); + } + resultFlag = true; + } else if (statementContext.if_statement() != null) { + visitorManager.getIfStatementVisitor().ifElseVisitor( + this, statementContext.if_statement(), baseBody, allParams, result, structs + ); + resultFlag = true; + } else if (statementContext.loop_statement() != null) { + LogicalOperation operation = visitorManager.getLoopStatementVisitor().loopVisitor( + this, statementContext.loop_statement(), baseBody, allParams, structs); + if (operation != null) { + result.add(operation); + } + resultFlag = true; + } + return resultFlag; + } + + public LogicalOperation sqlStatementVisitor( + PlSqlParser.Sql_statementContext sqlStatementContext, + Map declaredParams, + Map allParams, + List structs) { + String originalSql = PLParserUtil.getFullText(sqlStatementContext); String sql; try { - sql = PLParserUtil.getFullSQLWithParams(sql_statementContext, allParams, structs); + sql = PLParserUtil.getFullSQLWithParams(sqlStatementContext, allParams, structs); } catch (Exception e) { throw new ParseException("Failed to parse SQL statement: " + originalSql, e); } @@ -133,19 +157,20 @@ public LogicalOperation sqlStatementVisitor(PlSqlParser.Sql_statementContext sql case "MERGE ": return visitorManager.getMergeSQLVisitor().visitMergeSQL(originalSql, sql); case "CREATE": - if (sql_statementContext.data_manipulation_language_statements() != null && - sql_statementContext.data_manipulation_language_statements().create_table_as2() != null) { - String innerTableName = sql_statementContext.data_manipulation_language_statements() + if (sqlStatementContext.data_manipulation_language_statements() != null && + sqlStatementContext.data_manipulation_language_statements().create_table_as2() != null) { + String innerTableName = sqlStatementContext.data_manipulation_language_statements() .create_table_as2().table_name().getText(); return visitorManager.getCreateAsSQLVisitor().visitCreateInnerTable(originalSql, sql, innerTableName); } + throw new UnsupportedOperationException("Unsupported SQL syntax: " + sql); case "TRUNCA": if ("TRUNCATE".equalsIgnoreCase(sql.substring(0, 8))) { return visitorManager.getTruncateSQLVisitor().visitTruncateSQL(sql); } + throw new UnsupportedOperationException("Unsupported SQL syntax: " + sql); default: throw new UnsupportedOperationException("Unsupported SQL syntax: " + sql); } - } } diff --git a/ares-parser/src/main/java/com/github/ares/parser/visitor/PlCallStatementVisitor.java b/ares-parser/src/main/java/com/github/ares/parser/visitor/PlCallStatementVisitor.java index 8b93ed2..89b1384 100644 --- a/ares-parser/src/main/java/com/github/ares/parser/visitor/PlCallStatementVisitor.java +++ b/ares-parser/src/main/java/com/github/ares/parser/visitor/PlCallStatementVisitor.java @@ -1,61 +1,34 @@ package com.github.ares.parser.visitor; -import com.github.ares.com.google.inject.Inject; import com.github.ares.common.engine.PlType; -import com.github.ares.common.exceptions.ParseException; import com.github.ares.parser.antlr4.plsql.PlSqlParser; import com.github.ares.parser.model.Argument; -import com.github.ares.parser.model.BaseSqlOption; import com.github.ares.parser.plan.LogicalCallFunction; import com.github.ares.parser.plan.LogicalCreateProcedure; -import com.github.ares.parser.plan.LogicalCreateSourceTable; -import com.github.ares.parser.plan.LogicalCreateTableAsSQL; import com.github.ares.parser.plan.LogicalExpression; -import com.github.ares.parser.plan.LogicalInsertSelectSQL; -import com.github.ares.parser.plan.LogicalMergeIntoSQL; import com.github.ares.parser.plan.LogicalOperation; -import com.github.ares.parser.plan.LogicalUpdateSelectSQL; import java.util.ArrayList; -import java.util.Arrays; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; - -import static com.github.ares.parser.enums.OperationType.CREATE_SOURCE_TABLE; -import static com.github.ares.parser.enums.OperationType.CREATE_TABLE_AS_SQL; -import static com.github.ares.parser.enums.OperationType.INSERT_SELECT_SQL; -import static com.github.ares.parser.enums.OperationType.MERGE_INTO_SQL; -import static com.github.ares.parser.enums.OperationType.UPDATE_SELECT_SQL; public class PlCallStatementVisitor { - private static final String INNER_FUNC_REPARTITION = "repartition"; - public static final String INNER_FUNC_CACHE = "cache"; - public static final String INNER_FUNC_SHOW = "show"; - private PlVisitorManager visitorManager; public void init(PlVisitorManager visitorManager) { this.visitorManager = visitorManager; } - public LogicalOperation visitCallStatement(PlSqlParser.Call_statementContext call_statementContext, - Map allParams, List baseBody, - List currentBody, List structs) { - String functionName = call_statementContext.routine_name().getText(); -// if (INNER_FUNC_REPARTITION.equalsIgnoreCase(functionName)) { -// return repartition(call_statementContext.function_argument(), currentBody); -// } else if (INNER_FUNC_CACHE.equalsIgnoreCase(functionName)) { -// return cacheFunc(currentBody); -// } else if (INNER_FUNC_SHOW.equalsIgnoreCase(functionName)) { -// return showFunc(call_statementContext.function_argument(), currentBody); -// } else { - return visitCallStatement(functionName, call_statementContext.function_argument(), allParams, baseBody, structs); -// } + public LogicalOperation visitCallStatement(PlSqlParser.Call_statementContext callStatementContext, + Map allParams, + List baseBody, + List structs) { + String functionName = callStatementContext.routine_name().getText(); + return visitCallStatement(functionName, callStatementContext.function_argument(), allParams, baseBody, structs); } - public LogicalOperation visitCallStatement(String functionName, PlSqlParser.Function_argumentContext function_argumentContext, + public LogicalOperation visitCallStatement(String functionName, PlSqlParser.Function_argumentContext functionArgumentContext, Map allParams, List baseOperations, List structs) { List outArgsIdx = null; for (LogicalOperation baseOperation : baseOperations) { @@ -68,23 +41,22 @@ public LogicalOperation visitCallStatement(String functionName, PlSqlParser.Func List outArgsIdxTmp = outArgsIdx; Map outParams = null; - if (function_argumentContext != null) { + if (functionArgumentContext != null) { outParams = new LinkedHashMap<>(); - List argumentContexts = function_argumentContext.argument(); + List argumentContexts = functionArgumentContext.argument(); int i = 0; for (PlSqlParser.ArgumentContext argumentContext : argumentContexts) { i++; - if (outArgsIdxTmp != null && outArgsIdxTmp.contains(i - 1)) { - if (allParams.containsKey(argumentContext.getText())) { - outParams.put(argumentContext.getText(), allParams.get(argumentContext.getText())); - } + if (outArgsIdxTmp != null && outArgsIdxTmp.contains(i - 1) + && allParams.containsKey(argumentContext.getText())) { + outParams.put(argumentContext.getText(), allParams.get(argumentContext.getText())); } } } List args = new ArrayList<>(); - if (function_argumentContext != null) { - List argumentContexts = function_argumentContext.argument(); + if (functionArgumentContext != null) { + List argumentContexts = functionArgumentContext.argument(); int i = 0; for (PlSqlParser.ArgumentContext argumentContext : argumentContexts) { i++; @@ -110,100 +82,4 @@ public LogicalOperation visitCallStatement(String functionName, PlSqlParser.Func } return callFunction; } - -// public LogicalOperation repartition(PlSqlParser.Function_argumentContext function_argumentContext, List baseOperations) { -// if (function_argumentContext == null || (function_argumentContext.argument().size() != 1 && function_argumentContext.argument().size() != 2)) { -// throw new ParseException("The params of function 'repartition' must be one or two"); -// } -// LogicalOperation baseOperation = baseOperations.get(baseOperations.size() - 1); -// if (baseOperation.getOperationType() == INSERT_SELECT_SQL) { -// repartition(function_argumentContext, (LogicalInsertSelectSQL) baseOperation); -// } else if (baseOperation.getOperationType() == UPDATE_SELECT_SQL) { -// repartition(function_argumentContext, (LogicalUpdateSelectSQL) baseOperation); -// } else if (baseOperation.getOperationType() == MERGE_INTO_SQL) { -// repartition(function_argumentContext, (LogicalMergeIntoSQL) baseOperation); -// } else if (baseOperation.getOperationType() == CREATE_TABLE_AS_SQL) { -// repartition(function_argumentContext, (LogicalCreateTableAsSQL) baseOperation); -// } else { -// throw new UnsupportedOperationException("The REPARTITION function only supports `insert...select...`, `update...select...`, " + -// "`merge into...`, `create table as...` SQL statements."); -// } -// return null; -// } - -// private static void repartition(PlSqlParser.Function_argumentContext function_argumentContext, BaseSqlOption baseSqlOption) { -// int partitionNums = Integer.parseInt(function_argumentContext.argument().get(0).getText()); -// baseSqlOption.setRepartitionNums(partitionNums); -// if (function_argumentContext.argument().size() == 2) { -// String columns = function_argumentContext.argument().get(1).getText(); -// if (!columns.startsWith("'") || !columns.endsWith("'")) { -// throw new ParseException("The params of function 'repartition' columns must be VARCHAR"); -// } -// columns = columns.substring(0, columns.length() - 1).substring(1); -// String[] cols = columns.split(","); -// List colList = Arrays.stream(cols).map(String::trim).collect(Collectors.toList()); -// baseSqlOption.setRepartitionColumns(colList); -// } -// } - -// public LogicalOperation cacheFunc(List baseOperations) { -// if (baseOperations == null || baseOperations.isEmpty()) { -// return null; -// } -// LogicalOperation baseOperation = null; -// for (int i = baseOperations.size() - 1; i >= 0; i--) { -// LogicalOperation baseOperation1 = baseOperations.get(i); -// if (baseOperation1.getOperationType() == CREATE_TABLE_AS_SQL || -// baseOperation1.getOperationType() == CREATE_SOURCE_TABLE) { -// baseOperation = baseOperation1; -// break; -// } -// } -// if (baseOperation != null) { -// if (baseOperation.getOperationType() == CREATE_TABLE_AS_SQL) { -// ((LogicalCreateTableAsSQL) baseOperation).setWithCache(true); -// } else if (baseOperation.getOperationType() == CREATE_SOURCE_TABLE) { -// ((LogicalCreateSourceTable) baseOperation).setWithCache(true); -// } -// } -// return null; -// } - -// public LogicalOperation showFunc(PlSqlParser.Function_argumentContext function_argumentContext, List baseOperations) { -// if (!function_argumentContext.argument().isEmpty() && function_argumentContext.argument().size() != 1) { -// throw new ParseException("The params of function 'show' must be 0 or 1"); -// } -// int showCounts = 100; -// if (function_argumentContext.argument().size() == 1) { -// try { -// showCounts = Integer.parseInt(function_argumentContext.argument().get(0).getText()); -// } catch (NumberFormatException e) { -// throw new ParseException("The params of function 'show' must be INT"); -// } -// } -// if (showCounts > 100) { -// showCounts = 100; -// } -// if (baseOperations == null || baseOperations.isEmpty()) { -// return null; -// } -// LogicalOperation baseOperation = null; -// for (int i = baseOperations.size() - 1; i >= 0; i--) { -// LogicalOperation baseOperation1 = baseOperations.get(i); -// if (baseOperation1.getOperationType() == CREATE_TABLE_AS_SQL || -// baseOperation1.getOperationType() == CREATE_SOURCE_TABLE) { -// baseOperation = baseOperation1; -// break; -// } -// } -// if (baseOperation != null) { -// if (baseOperation.getOperationType() == CREATE_TABLE_AS_SQL) { -// ((LogicalCreateTableAsSQL) baseOperation).setWithShow(showCounts); -// } else if (baseOperation.getOperationType() == CREATE_SOURCE_TABLE) { -// ((LogicalCreateSourceTable) baseOperation).setWithShow(showCounts); -// } -// } -// return null; -// } - } diff --git a/ares-parser/src/main/java/com/github/ares/parser/visitor/PlCreateTableWithVisitor.java b/ares-parser/src/main/java/com/github/ares/parser/visitor/PlCreateTableWithVisitor.java index de9d85d..44c44a2 100644 --- a/ares-parser/src/main/java/com/github/ares/parser/visitor/PlCreateTableWithVisitor.java +++ b/ares-parser/src/main/java/com/github/ares/parser/visitor/PlCreateTableWithVisitor.java @@ -94,8 +94,8 @@ public List visitCreateTableWith(PlSqlParser.Create_tableConte private Map visitCreateWithOptions(PlSqlParser.Create_withContext create_withContext, List setConfigs) { Map withOptions = new LinkedHashMap<>(); - PlSqlParser.Create_optionsContext create_optionsContext = create_withContext.create_options(); - visitCreateWithOptions(create_optionsContext, withOptions); + PlSqlParser.Create_optionsContext createOptionsContext = create_withContext.create_options(); + visitCreateWithOptions(createOptionsContext, withOptions); String datasource = withOptions.get(DATA_SOURCE.key()); if (StringUtils.isEmpty(withOptions.get(CONNECTOR.key())) && !StringUtils.isEmpty(datasource)) { Properties properties = new Properties(); diff --git a/ares-parser/src/main/java/com/github/ares/parser/visitor/PlFunctionBodyVisitor.java b/ares-parser/src/main/java/com/github/ares/parser/visitor/PlFunctionBodyVisitor.java index 40af41f..27ba5c1 100644 --- a/ares-parser/src/main/java/com/github/ares/parser/visitor/PlFunctionBodyVisitor.java +++ b/ares-parser/src/main/java/com/github/ares/parser/visitor/PlFunctionBodyVisitor.java @@ -1,20 +1,32 @@ package com.github.ares.parser.visitor; -import com.github.ares.common.engine.InternalFieldType; import com.github.ares.common.engine.PlType; +import com.github.ares.parser.antlr4.plsql.PlSqlParser; import com.github.ares.parser.plan.LogicalContinueLoop; import com.github.ares.parser.plan.LogicalExitLoop; import com.github.ares.parser.plan.LogicalOperation; -import com.github.ares.parser.antlr4.plsql.PlSqlParser; import com.github.ares.parser.utils.PLParserUtil; import java.util.ArrayList; +import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; public class PlFunctionBodyVisitor extends PlBodyVisitor { + /** + * Visits the body statements of a function. + * + * @param statementContextList the list of statement contexts + * @param inParams input parameters + * @param outParams output parameters + * @param declaredParams declared parameters + * @param baseBody the base body of the function + * @param structs the list of structs + * @return the list of logical operations + */ + @Override public List visitBodyStatements(List statementContextList, Map inParams, Map outParams, Map declaredParams, List baseBody, List structs) { @@ -22,77 +34,74 @@ public List visitBodyStatements(List result = new ArrayList<>(); for (PlSqlParser.StatementContext statementContext : statementContextList) { - PlSqlParser.Exit_statementContext exit_statementContext = statementContext.exit_statement(); - if (exit_statementContext != null) { - result.add(new LogicalExitLoop()); - continue; - } - PlSqlParser.Continue_statementContext continueStatementContext = statementContext.continue_statement(); - if (continueStatementContext != null) { - result.add(new LogicalContinueLoop()); + if (visitLogicalControlContext(statementContext, result, inParams, declaredParams) + || visitPlContext(statementContext, result, allParams, declaredParams, baseBody, structs)) { continue; } - PlSqlParser.Call_statementContext callStatementContext = statementContext.call_statement(); - if (callStatementContext != null) { - LogicalOperation operation = visitorManager.getCallStatementVisitor() - .visitCallStatement(callStatementContext, allParams, baseBody, result, structs); - if (operation != null) { - result.add(operation); - } - continue; - } + throw new UnsupportedOperationException(String.format("Unsupported syntax: '%s' in function body", + PLParserUtil.getFullText(statementContext))); + } + return result; + } - PlSqlParser.Assignment_statementContext assignmentStatement = statementContext.assignment_statement(); - if (assignmentStatement != null) { - LogicalOperation operation = visitorManager.getAssignmentVisitor() - .visitAssignment(assignmentStatement, declaredParams, allParams, structs); - if (operation != null) { - result.add(operation); - } - continue; + private boolean visitLogicalControlContext(PlSqlParser.StatementContext statementContext, List result, + Map inParams, Map declaredParams) { + boolean resultFlag = false; + if (statementContext.exit_statement() != null) { + result.add(new LogicalExitLoop()); + resultFlag = true; + } else if (statementContext.continue_statement() != null) { + result.add(new LogicalContinueLoop()); + resultFlag = true; + } else if (statementContext.return_statement() != null) { + LogicalOperation operation = visitorManager.getReturnStatementVisitor() + .visitReturnStatement(statementContext.return_statement(), inParams, declaredParams); + if (operation != null) { + result.add(operation); } + resultFlag = true; + } else if (statementContext.raise_statement() != null) { + resultFlag = true; + } - PlSqlParser.If_statementContext ifStatement = statementContext.if_statement(); - if (ifStatement != null) { - visitorManager.getIfStatementVisitor().ifElseVisitor( - this, ifStatement, baseBody, allParams, result, structs - ); - continue; - } + return resultFlag; + } - PlSqlParser.Loop_statementContext loopStatementContext = statementContext.loop_statement(); - if (loopStatementContext != null) { - LogicalOperation operation = visitorManager.getLoopStatementVisitor().loopVisitor( - this, loopStatementContext, baseBody, allParams, structs, false); - if (operation != null) { - result.add(operation); - } - continue; + private boolean visitPlContext(PlSqlParser.StatementContext statementContext, List result, Map allParams, + Map declaredParams, List baseBody, List structs) { + boolean resultFlag = false; + if (statementContext.call_statement() != null) { + LogicalOperation operation = visitorManager.getCallStatementVisitor() + .visitCallStatement(statementContext.call_statement(), allParams, result, structs); + if (operation != null) { + result.add(operation); } - - PlSqlParser.Return_statementContext return_statementContext = statementContext.return_statement(); - if (return_statementContext != null) { - LogicalOperation operation = visitorManager.getReturnStatementVisitor() - .visitReturnStatement(return_statementContext, inParams, declaredParams); - if (operation != null) { - result.add(operation); - } - continue; + resultFlag = true; + } else if (statementContext.assignment_statement() != null) { + LogicalOperation operation = visitorManager.getAssignmentVisitor() + .visitAssignment(statementContext.assignment_statement(), declaredParams, allParams, structs); + if (operation != null) { + result.add(operation); } - - PlSqlParser.Raise_statementContext raiseStatementContext = statementContext.raise_statement(); - if (raiseStatementContext != null) { - continue; + resultFlag = true; + } else if (statementContext.if_statement() != null) { + visitorManager.getIfStatementVisitor().ifElseVisitor( + this, statementContext.if_statement(), baseBody, allParams, result, structs + ); + resultFlag = true; + } else if (statementContext.loop_statement() != null) { + LogicalOperation operation = visitorManager.getLoopStatementVisitor().loopVisitor( + this, statementContext.loop_statement(), baseBody, allParams, structs, false); + if (operation != null) { + result.add(operation); } - - throw new UnsupportedOperationException(String.format("Unsupported syntax: '%s' in function body", - PLParserUtil.getFullText(statementContext))); + resultFlag = true; } - return result; + return resultFlag; } } diff --git a/ares-parser/src/test/java/com/github/ares/parser/test/PlParserTest.java b/ares-parser/src/test/java/com/github/ares/parser/test/PlParserTest.java index bb45426..fd5201e 100644 --- a/ares-parser/src/test/java/com/github/ares/parser/test/PlParserTest.java +++ b/ares-parser/src/test/java/com/github/ares/parser/test/PlParserTest.java @@ -1,29 +1,113 @@ package com.github.ares.parser.test; +import com.github.ares.api.common.EngineType; +import com.github.ares.api.common.EngineTypeVersion; +import com.github.ares.api.common.ExecutionEngineType; import com.github.ares.com.google.inject.Guice; import com.github.ares.com.google.inject.Injector; import com.github.ares.common.utils.InjectorFactory; import com.github.ares.parser.PlParser; import com.github.ares.parser.config.ParserServiceModule; +import com.github.ares.parser.datasource.PropertiesDataSourcePatcher; +import com.github.ares.parser.datasource.SourceConfigPatcherFactory; +import com.github.ares.parser.plan.LogicalCreateSinkTable; +import com.github.ares.parser.plan.LogicalCreateSourceTable; import com.github.ares.parser.plan.LogicalProject; +import com.github.ares.parser.utils.Constants; import org.junit.Assert; +import org.junit.Before; import org.junit.Ignore; import org.junit.Test; -import java.io.IOException; -import java.io.InputStream; +import java.util.Properties; @Ignore public class PlParserTest { - @Test - public void test00() throws IOException { + + private PlParser plTransformation; + + @Before + public void init() { + ExecutionEngineType.init(EngineType.SPARK, EngineTypeVersion.SPARK3); Injector injector = Guice.createInjector(new ParserServiceModule()); InjectorFactory.init(injector); - PlParser plTransformation = injector.getInstance(PlParser.class); + plTransformation = injector.getInstance(PlParser.class); plTransformation.init(); - InputStream in = this.getClass().getClassLoader().getResourceAsStream("mysql.sql"); - LogicalProject baseBody = plTransformation.parseToBaseBody(in); - Assert.assertFalse(baseBody.getLogicalOperations().isEmpty()); - in.close(); + SourceConfigPatcherFactory.register(Constants.DEFAULT_DATASOURCE_PATCHER, + new PropertiesDataSourcePatcher(new Properties())); + } + + @Test + public void parseCreateTable() { + Injector injector = Guice.createInjector(new ParserServiceModule()); + InjectorFactory.init(injector); + String pl = "CREATE TABLE test1\n" + + "WITH (\n" + + " 'connector'='jdbc',\n" + + " 'url'='jdbc:mysql://127.0.0.1:3306/mytest?useSSL=false',\n" + + " 'driver'='com.mysql.cj.jdbc.Driver',\n" + + " 'user'='root',\n" + + " 'password'='123456',\n" + + " -- 'query'='select * from t_user',\n" + + " 'table_name'='t_user',\n" + + " 'type' = 'source,sink'\n" + + ");"; + LogicalProject logicalProject = plTransformation.parseToBaseBody(pl); + Assert.assertEquals(2, logicalProject.getLogicalOperations().size()); + LogicalCreateSourceTable logicalCreateSourceTable = (LogicalCreateSourceTable) logicalProject.getLogicalOperations().get(0); + Assert.assertEquals("jdbc", logicalCreateSourceTable.getConnector()); + Assert.assertEquals("test1", logicalCreateSourceTable.getTableName()); + Assert.assertEquals("jdbc:mysql://127.0.0.1:3306/mytest?useSSL=false", logicalCreateSourceTable.getOptions().get("url")); + Assert.assertEquals("com.mysql.cj.jdbc.Driver", logicalCreateSourceTable.getOptions().get("driver")); + Assert.assertEquals("root", logicalCreateSourceTable.getOptions().get("user")); + Assert.assertEquals("123456", logicalCreateSourceTable.getOptions().get("password")); + Assert.assertEquals("t_user", logicalCreateSourceTable.getOptions().get("table_name")); + + LogicalCreateSinkTable logicalCreateSinkTable = (LogicalCreateSinkTable) logicalProject.getLogicalOperations().get(1); + Assert.assertEquals("jdbc", logicalCreateSinkTable.getConnector()); + Assert.assertEquals("test1", logicalCreateSinkTable.getTableName()); + Assert.assertEquals("jdbc:mysql://127.0.0.1:3306/mytest?useSSL=false", logicalCreateSinkTable.getOptions().get("url")); + Assert.assertEquals("com.mysql.cj.jdbc.Driver", logicalCreateSinkTable.getOptions().get("driver")); + Assert.assertEquals("root", logicalCreateSinkTable.getOptions().get("user")); + Assert.assertEquals("123456", logicalCreateSinkTable.getOptions().get("password")); + Assert.assertEquals("t_user", logicalCreateSinkTable.getOptions().get("table_name")); + } + + @Test + public void parseCreateTableWithDs() { + Injector injector = Guice.createInjector(new ParserServiceModule()); + InjectorFactory.init(injector); + String pl = "SET datasource.mytest.connector=jdbc;\n" + + "SET datasource.mytest.url=jdbc:mysql://127.0.0.1:3306/mytest?useSSL=false;\n" + + "SET datasource.mytest.driver=com.mysql.cj.jdbc.Driver;\n" + + "SET datasource.mytest.user=root;\n" + + "SET datasource.mytest.password=123456;\n" + + "\n" + + "CREATE TABLE test1\n" + + "WITH (\n" + + " 'datasource' = 'mytest',\n" + + " -- 'query'='select * from t_user',\n" + + " 'table_name'='t_user',\n" + + " 'type' = 'source,sink'\n" + + ");"; + LogicalProject logicalProject = plTransformation.parseToBaseBody(pl); + Assert.assertEquals(2, logicalProject.getLogicalOperations().size()); + LogicalCreateSourceTable logicalCreateSourceTable = (LogicalCreateSourceTable) logicalProject.getLogicalOperations().get(0); + Assert.assertEquals("jdbc", logicalCreateSourceTable.getConnector()); + Assert.assertEquals("test1", logicalCreateSourceTable.getTableName()); + Assert.assertEquals("jdbc:mysql://127.0.0.1:3306/mytest?useSSL=false", logicalCreateSourceTable.getOptions().get("url")); + Assert.assertEquals("com.mysql.cj.jdbc.Driver", logicalCreateSourceTable.getOptions().get("driver")); + Assert.assertEquals("root", logicalCreateSourceTable.getOptions().get("user")); + Assert.assertEquals("123456", logicalCreateSourceTable.getOptions().get("password")); + Assert.assertEquals("t_user", logicalCreateSourceTable.getOptions().get("table_name")); + + LogicalCreateSinkTable logicalCreateSinkTable = (LogicalCreateSinkTable) logicalProject.getLogicalOperations().get(1); + Assert.assertEquals("jdbc", logicalCreateSinkTable.getConnector()); + Assert.assertEquals("test1", logicalCreateSinkTable.getTableName()); + Assert.assertEquals("jdbc:mysql://127.0.0.1:3306/mytest?useSSL=false", logicalCreateSinkTable.getOptions().get("url")); + Assert.assertEquals("com.mysql.cj.jdbc.Driver", logicalCreateSinkTable.getOptions().get("driver")); + Assert.assertEquals("root", logicalCreateSinkTable.getOptions().get("user")); + Assert.assertEquals("123456", logicalCreateSinkTable.getOptions().get("password")); + Assert.assertEquals("t_user", logicalCreateSinkTable.getOptions().get("table_name")); } } diff --git a/ares-starter/ares-spark-starter-common/src/main/java/com/github/ares/spark/starter/SparkStarter.java b/ares-starter/ares-spark-starter-common/src/main/java/com/github/ares/spark/starter/SparkStarter.java index 00e0349..90f57df 100644 --- a/ares-starter/ares-spark-starter-common/src/main/java/com/github/ares/spark/starter/SparkStarter.java +++ b/ares-starter/ares-spark-starter-common/src/main/java/com/github/ares/spark/starter/SparkStarter.java @@ -124,13 +124,13 @@ public List buildCommands() throws Exception { /** * parse spark configurations from Ares config file */ - private void setSparkConf() throws FileNotFoundException { + private void setSparkConf() { commandArgs.getVariables().stream() .filter(Objects::nonNull) .map(variable -> variable.split("=", 2)) .filter(pair -> pair.length == 2) .forEach(pair -> System.setProperty(pair[0], pair[1])); - this.sparkConf = getSparkConf(/* TODO commandArgs.getConfigFile()*/null); + this.sparkConf = new LinkedHashMap<>(); String driverJavaOpts = this.sparkConf.getOrDefault("spark.driver.extraJavaOptions", ""); String executorJavaOpts = this.sparkConf.getOrDefault("spark.executor.extraJavaOptions", ""); @@ -146,29 +146,6 @@ private void setSparkConf() throws FileNotFoundException { } } - /** - * Get spark configurations from Ares job config file. - */ - static Map getSparkConf(String configFile) throws FileNotFoundException { -// File file = new File(configFile); -// if (!file.exists()) { -// throw new FileNotFoundException("config file '" + file + "' does not exists!"); -// } -// Config appConfig = -// ConfigFactory.parseFile(file) -// .resolve(ConfigResolveOptions.defaults().setAllowUnresolved(true)) -// .resolveWith( -// ConfigFactory.systemProperties(), -// ConfigResolveOptions.defaults().setAllowUnresolved(true)); -// -// return appConfig.getConfig("env").entrySet().stream() -// .collect( -// Collectors.toMap( -// Map.Entry::getKey, e -> e.getValue().unwrapped().toString())); - // TODO - return new LinkedHashMap<>(); - } - /** * append spark configurations to StringBuilder */ diff --git a/ares-starter/ares-spark2-starter/src/main/java/com/github/ares/spark/starter/SparkStarter.java b/ares-starter/ares-spark2-starter/src/main/java/com/github/ares/spark/starter/SparkStarter.java index 2ed4547..02ab59b 100644 --- a/ares-starter/ares-spark2-starter/src/main/java/com/github/ares/spark/starter/SparkStarter.java +++ b/ares-starter/ares-spark2-starter/src/main/java/com/github/ares/spark/starter/SparkStarter.java @@ -152,7 +152,7 @@ public List buildCommands() throws IOException { /** * parse spark configurations from Ares config file */ - private void setSparkConf() throws FileNotFoundException { + private void setSparkConf() { commandArgs.getVariables().stream() .filter(Objects::nonNull) .map(variable -> variable.split("=", 2))