From b80ca40ed19ef96ce3774fcf1006042a640e20e7 Mon Sep 17 00:00:00 2001 From: yrizhkov Date: Wed, 20 Dec 2023 09:38:23 +0200 Subject: [PATCH] FMWK-298 Support byte array parameter in PreparedStatement --- .../jdbc/AerospikePreparedStatement.java | 28 +++------- .../aerospike/jdbc/AerospikeStatement.java | 13 +++-- .../aerospike/jdbc/model/AerospikeQuery.java | 4 +- .../jdbc/model/AerospikeSqlVisitor.java | 26 +++++++-- .../jdbc/predicate/QueryPredicateBase.java | 25 +++++++-- .../jdbc/util/PreparedStatement.java | 4 +- .../java/com/aerospike/jdbc/JdbcBaseTest.java | 8 +-- .../aerospike/jdbc/PreparedQueriesTest.java | 55 +++++++++++++++---- 8 files changed, 110 insertions(+), 53 deletions(-) diff --git a/src/main/java/com/aerospike/jdbc/AerospikePreparedStatement.java b/src/main/java/com/aerospike/jdbc/AerospikePreparedStatement.java index 0044937..72bd347 100644 --- a/src/main/java/com/aerospike/jdbc/AerospikePreparedStatement.java +++ b/src/main/java/com/aerospike/jdbc/AerospikePreparedStatement.java @@ -27,7 +27,6 @@ import static com.aerospike.jdbc.util.PreparedStatement.parseParameters; import static java.lang.String.format; -import static java.util.Objects.isNull; public class AerospikePreparedStatement extends AerospikeStatement implements PreparedStatement { @@ -50,9 +49,8 @@ private Object[] buildSqlParameters(String sql) { @Override public ResultSet executeQuery() throws SQLException { - String preparedQueryString = prepareQueryString(); - logger.info(() -> "executeQuery: " + preparedQueryString); - AerospikeQuery query = parseQuery(preparedQueryString); + logger.info(() -> format("executeQuery: %s, params: %s", sqlStatement, Arrays.toString(sqlParameters))); + AerospikeQuery query = parseQuery(sqlStatement, Arrays.asList(sqlParameters)); runQuery(query); return resultSet; } @@ -114,7 +112,7 @@ public void setBigDecimal(int parameterIndex, BigDecimal x) throws SQLException @Override public void setString(int parameterIndex, String x) throws SQLException { - setObject(parameterIndex, format("\"%s\"", x)); + setObject(parameterIndex, x); } @Override @@ -180,25 +178,15 @@ public void setObject(int parameterIndex, Object x) throws SQLException { @Override public boolean execute() throws SQLException { - String preparedQueryString = prepareQueryString(); - logger.info(() -> "execute: " + preparedQueryString); - AerospikeQuery query = parseQuery(preparedQueryString); + logger.info(() -> format("execute: %s, params: %s", sqlStatement, Arrays.toString(sqlParameters))); + AerospikeQuery query = parseQuery(sqlStatement, Arrays.asList(sqlParameters)); runQuery(query); return query.getQueryType() == QueryType.SELECT; } - private String prepareQueryString() { - String preparedQueryString = sqlStatement; - for (Object value : sqlParameters) { - String replacement = isNull(value) ? "?" : value.toString(); - preparedQueryString = preparedQueryString.replaceFirst("\\?", replacement); - } - return preparedQueryString; - } - @Override public void addBatch() throws SQLException { - addBatch(prepareQueryString()); + throw new SQLFeatureNotSupportedException(BATCH_NOT_SUPPORTED_MESSAGE); } @Override @@ -228,7 +216,7 @@ public void setArray(int parameterIndex, Array x) throws SQLException { @Override public ResultSetMetaData getMetaData() throws SQLException { - AerospikeQuery query = parseQuery(prepareQueryString()); + AerospikeQuery query = parseQuery(sqlStatement, Arrays.asList(sqlParameters)); List columns = ((AerospikeDatabaseMetadata) connection.getMetaData()) .getSchemaBuilder() .getSchema(query.getCatalogTable()); @@ -262,7 +250,7 @@ public void setURL(int parameterIndex, URL url) throws SQLException { @Override public ParameterMetaData getParameterMetaData() throws SQLException { - AerospikeQuery query = parseQuery(prepareQueryString()); + AerospikeQuery query = parseQuery(sqlStatement, Arrays.asList(sqlParameters)); List columns = ((AerospikeDatabaseMetadata) connection.getMetaData()) .getSchemaBuilder() .getSchema(query.getCatalogTable()); diff --git a/src/main/java/com/aerospike/jdbc/AerospikeStatement.java b/src/main/java/com/aerospike/jdbc/AerospikeStatement.java index d0bcaa0..c8ef91e 100644 --- a/src/main/java/com/aerospike/jdbc/AerospikeStatement.java +++ b/src/main/java/com/aerospike/jdbc/AerospikeStatement.java @@ -14,6 +14,7 @@ import java.sql.SQLFeatureNotSupportedException; import java.sql.SQLWarning; import java.sql.Statement; +import java.util.Collection; import java.util.logging.Logger; import static java.lang.String.format; @@ -24,9 +25,9 @@ public class AerospikeStatement implements Statement, SimpleWrapper { - private static final Logger logger = Logger.getLogger(AerospikeStatement.class.getName()); + protected static final String BATCH_NOT_SUPPORTED_MESSAGE = "Batch update is not supported"; - private static final String BATCH_NOT_SUPPORTED_MESSAGE = "Batch update is not supported"; + private static final Logger logger = Logger.getLogger(AerospikeStatement.class.getName()); private static final String AUTO_GENERATED_KEYS_NOT_SUPPORTED_MESSAGE = "Auto-generated keys are not supported"; protected final IAerospikeClient client; @@ -48,7 +49,7 @@ public AerospikeStatement(IAerospikeClient client, AerospikeConnection connectio @Override public ResultSet executeQuery(String sql) throws SQLException { logger.info(() -> "executeQuery: " + sql); - AerospikeQuery query = parseQuery(sql); + AerospikeQuery query = parseQuery(sql, null); runQuery(query); return resultSet; } @@ -59,11 +60,11 @@ protected void runQuery(AerospikeQuery query) { updateCount = result.getRight(); } - protected AerospikeQuery parseQuery(String sql) throws SQLException { + protected AerospikeQuery parseQuery(String sql, Collection sqlParameters) throws SQLException { sql = sql.replace("\n", " "); AerospikeQuery query; try { - query = AerospikeQuery.parse(sql); + query = AerospikeQuery.parse(sql, sqlParameters); } catch (Exception e) { query = AuxStatementParser.parse(sql); } @@ -142,7 +143,7 @@ public void setCursorName(String name) throws SQLException { @Override public boolean execute(String sql) throws SQLException { logger.info(() -> "execute: " + sql); - AerospikeQuery query = parseQuery(sql); + AerospikeQuery query = parseQuery(sql, null); runQuery(query); return query.getQueryType() == QueryType.SELECT; } diff --git a/src/main/java/com/aerospike/jdbc/model/AerospikeQuery.java b/src/main/java/com/aerospike/jdbc/model/AerospikeQuery.java index 05da48b..bb6e5fc 100644 --- a/src/main/java/com/aerospike/jdbc/model/AerospikeQuery.java +++ b/src/main/java/com/aerospike/jdbc/model/AerospikeQuery.java @@ -45,10 +45,10 @@ public AerospikeQuery() { this.queryType = QueryType.UNKNOWN; } - public static AerospikeQuery parse(String sql) throws SqlParseException { + public static AerospikeQuery parse(String sql, Collection sqlParameters) throws SqlParseException { SqlParser parser = SqlParser.create(sql, sqlParserConfig); SqlNode parsed = parser.parseQuery(); - return parsed.accept(new AerospikeSqlVisitor()); + return parsed.accept(new AerospikeSqlVisitor(sqlParameters)); } public String getCatalog() { diff --git a/src/main/java/com/aerospike/jdbc/model/AerospikeSqlVisitor.java b/src/main/java/com/aerospike/jdbc/model/AerospikeSqlVisitor.java index 9e1b511..93732db 100644 --- a/src/main/java/com/aerospike/jdbc/model/AerospikeSqlVisitor.java +++ b/src/main/java/com/aerospike/jdbc/model/AerospikeSqlVisitor.java @@ -10,18 +10,30 @@ import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.util.SqlVisitor; +import javax.annotation.Nullable; import java.math.BigDecimal; +import java.util.Collection; +import java.util.Iterator; import java.util.stream.Collectors; import static com.aerospike.jdbc.util.Constants.UNSUPPORTED_QUERY_TYPE_MESSAGE; +import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; public class AerospikeSqlVisitor implements SqlVisitor { + private static final String QUERY_PLACEHOLDER = "?"; + private final AerospikeQuery query; + private final Iterator sqlParametersIterator; public AerospikeSqlVisitor() { + this(null); + } + + public AerospikeSqlVisitor(@Nullable Collection sqlParameters) { query = new AerospikeQuery(); + sqlParametersIterator = sqlParameters != null ? sqlParameters.iterator() : null; } @Override @@ -45,13 +57,13 @@ public AerospikeQuery visit(SqlCall sqlCall) { SqlUpdate sql = (SqlUpdate) sqlCall; query.setQueryType(QueryType.UPDATE); query.setTable(requireNonNull(sql.getTargetTable()).toString()); + query.setValues(sql.getSourceExpressionList().stream() + .map(this::parseValue).collect(Collectors.toList())); if (sql.getCondition() != null) { query.setPredicate(parseWhere((SqlBasicCall) sql.getCondition())); } query.setColumns(sql.getTargetColumnList().stream() .map(SqlNode::toString).collect(Collectors.toList())); - query.setValues(sql.getSourceExpressionList().stream() - .map(this::parseValue).collect(Collectors.toList())); } else if (sqlCall instanceof SqlInsert) { SqlInsert sql = (SqlInsert) sqlCall; query.setQueryType(QueryType.INSERT); @@ -97,8 +109,10 @@ public AerospikeQuery visit(SqlCall sqlCall) { } else { throw new UnsupportedOperationException(UNSUPPORTED_QUERY_TYPE_MESSAGE); } + } catch (UnsupportedOperationException e) { + throw e; } catch (Exception e) { - throw new UnsupportedOperationException(UNSUPPORTED_QUERY_TYPE_MESSAGE); + throw new UnsupportedOperationException(UNSUPPORTED_QUERY_TYPE_MESSAGE, e); } return query; } @@ -142,7 +156,7 @@ private QueryPredicate parseWhere(SqlBasicCall where) { } } else if (where.getOperator() instanceof SqlLikeOperator) { String binName = where.getOperandList().get(0).toString(); - String expression = unwrapString(where.getOperandList().get(1).toString()); + String expression = parseValue(where.getOperandList().get(1)).toString(); return new QueryPredicateLike(binName, expression); } else if (where.getOperator() instanceof SqlBetweenOperator) { return new QueryPredicateRange( @@ -168,6 +182,10 @@ private Object parseValue(SqlNode sqlNode) { } } else if (sqlNode instanceof SqlIdentifier) { return unwrapString(sqlNode.toString()); + } else if (sqlNode instanceof SqlDynamicParam + && unwrapString(sqlNode.toString()).equals(QUERY_PLACEHOLDER)) { + checkState(sqlParametersIterator != null, "SQL parameters is null"); + return sqlParametersIterator.next(); } throw new UnsupportedOperationException(UNSUPPORTED_QUERY_TYPE_MESSAGE); } diff --git a/src/main/java/com/aerospike/jdbc/predicate/QueryPredicateBase.java b/src/main/java/com/aerospike/jdbc/predicate/QueryPredicateBase.java index 25ac5e2..dfa325e 100644 --- a/src/main/java/com/aerospike/jdbc/predicate/QueryPredicateBase.java +++ b/src/main/java/com/aerospike/jdbc/predicate/QueryPredicateBase.java @@ -20,28 +20,45 @@ protected QueryPredicateBase( } protected static Exp.Type getValueType(Object value) { - if (value instanceof String) { + if (value == null) { + return Exp.Type.NIL; + } else if (value instanceof String) { return Exp.Type.STRING; - } else if (value instanceof Long) { + } else if (value instanceof Long || value instanceof Integer + || value instanceof Short || value instanceof Byte) { return Exp.Type.INT; - } else if (value instanceof Double) { + } else if (value instanceof Double || value instanceof Float) { return Exp.Type.FLOAT; } else if (value instanceof Boolean) { return Exp.Type.BOOL; + } else if (value instanceof byte[]) { + return Exp.Type.BLOB; } else { return Exp.Type.STRING; } } protected Exp getValueExp(Object value) { - if (value instanceof String) { + if (value == null) { + return Exp.nil(); + } else if (value instanceof String) { return Exp.val((String) value); } else if (value instanceof Long) { return Exp.val((long) value); + } else if (value instanceof Integer) { + return Exp.val((int) value); + } else if (value instanceof Short) { + return Exp.val((short) value); + } else if (value instanceof Byte) { + return Exp.val((byte) value); } else if (value instanceof Double) { return Exp.val((double) value); + } else if (value instanceof Float) { + return Exp.val((float) value); } else if (value instanceof Boolean) { return Exp.val((boolean) value); + } else if (value instanceof byte[]) { + return Exp.val((byte[]) value); } else { return Exp.val(value.toString()); } diff --git a/src/main/java/com/aerospike/jdbc/util/PreparedStatement.java b/src/main/java/com/aerospike/jdbc/util/PreparedStatement.java index 4b8d807..fa7c59a 100644 --- a/src/main/java/com/aerospike/jdbc/util/PreparedStatement.java +++ b/src/main/java/com/aerospike/jdbc/util/PreparedStatement.java @@ -47,7 +47,7 @@ public static Iterable splitQueries(String sql) { } } - if (currentQuery.length() > 0 && currentQuery.toString().trim().length() > 0) { + if (currentQuery.length() > 0 && !currentQuery.toString().trim().isEmpty()) { appendNotEmpty(queries, currentQuery.toString()); } @@ -55,7 +55,7 @@ public static Iterable splitQueries(String sql) { } private static void appendNotEmpty(Collection queries, String query) { - if (query.trim().length() > 0) { + if (!query.trim().isEmpty()) { queries.add(query); } } diff --git a/src/test/java/com/aerospike/jdbc/JdbcBaseTest.java b/src/test/java/com/aerospike/jdbc/JdbcBaseTest.java index 6fa3412..2379a2a 100644 --- a/src/test/java/com/aerospike/jdbc/JdbcBaseTest.java +++ b/src/test/java/com/aerospike/jdbc/JdbcBaseTest.java @@ -1,7 +1,7 @@ package com.aerospike.jdbc; -import org.testng.annotations.AfterSuite; -import org.testng.annotations.BeforeSuite; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; import java.sql.Connection; import java.sql.DriverManager; @@ -25,7 +25,7 @@ public abstract class JdbcBaseTest { protected static Connection connection; - @BeforeSuite + @BeforeClass public static void connectionInit() throws Exception { logger.info("connectionInit"); Class.forName("com.aerospike.jdbc.AerospikeDriver").newInstance(); @@ -34,7 +34,7 @@ public static void connectionInit() throws Exception { connection.setNetworkTimeout(Executors.newSingleThreadExecutor(), 5000); } - @AfterSuite + @AfterClass public static void connectionClose() throws SQLException { logger.info("connectionClose"); connection.close(); diff --git a/src/test/java/com/aerospike/jdbc/PreparedQueriesTest.java b/src/test/java/com/aerospike/jdbc/PreparedQueriesTest.java index fa3738e..77ce9dc 100644 --- a/src/test/java/com/aerospike/jdbc/PreparedQueriesTest.java +++ b/src/test/java/com/aerospike/jdbc/PreparedQueriesTest.java @@ -18,18 +18,23 @@ public class PreparedQueriesTest extends JdbcBaseTest { + private final byte[] blobValue = new byte[]{1, 2, 3, 4}; + @BeforeMethod public void setUp() throws SQLException { Objects.requireNonNull(connection, "connection is null"); PreparedStatement statement = null; int count; String query = format( - "insert into %s (%s, bin1, int1, str1, bool1) values (\"key1\", 11100, 1, \"bar\", true)", + "insert into %s (%s, bin1, int1, str1, bool1, val1, val2) values " + + "(\"key1\", 11100, 1, \"bar\", true, ?, ?)", tableName, PRIMARY_KEY_COLUMN_NAME ); try { statement = connection.prepareStatement(query); + statement.setBytes(1, blobValue); + statement.setNull(2, 0); count = statement.executeUpdate(); } finally { closeQuietly(statement); @@ -64,6 +69,7 @@ public void testSelectQuery() throws SQLException { while (resultSet.next()) { assertAllByColumnLabel(resultSet); assertAllByColumnIndex(resultSet); + assertBlobValue(resultSet); total++; } @@ -78,14 +84,17 @@ public void testSelectQuery() throws SQLException { public void testSelectByPrimaryKeyQuery() throws SQLException { PreparedStatement statement = null; ResultSet resultSet = null; - String query = format("select *, bin1 from %s where %s='key1'", tableName, PRIMARY_KEY_COLUMN_NAME); + String query = format("select *, bin1 from %s where %s=?", tableName, PRIMARY_KEY_COLUMN_NAME); int total = 0; try { statement = connection.prepareStatement(query); + statement.setString(1, "key1"); + resultSet = statement.executeQuery(); while (resultSet.next()) { assertAllByColumnLabel(resultSet); assertAllByColumnIndex(resultSet); + assertBlobValue(resultSet); total++; } @@ -100,9 +109,13 @@ public void testSelectByPrimaryKeyQuery() throws SQLException { public void testInsertQuery() throws SQLException { PreparedStatement statement = null; int count; - String query = format("insert into %s (bin1, int1) values (11101, 3), (11102, 4)", tableName); + String query = format("insert into %s (bin1, int1) values (?, ?), (?, ?)", tableName); try { statement = connection.prepareStatement(query); + statement.setInt(1, 11101); + statement.setInt(2, 3); + statement.setInt(3, 11102); + statement.setInt(4, 4); count = statement.executeUpdate(); } finally { closeQuietly(statement); @@ -114,19 +127,23 @@ public void testInsertQuery() throws SQLException { public void testUpdateQuery() throws SQLException { PreparedStatement statement = null; int count; - String query = format("update %s set int1=100 where bin1>10000", tableName); + String query = format("update %s set int1=? where bin1>?", tableName); try { statement = connection.prepareStatement(query); - count = statement.executeUpdate(query); + statement.setInt(1, 100); + statement.setInt(2, 10000); + count = statement.executeUpdate(); } finally { closeQuietly(statement); } assertEquals(count, 1); - query = format("update %s set int1=100 where bin1>20000", tableName); + query = format("update %s set int1=? where bin1>?", tableName); try { statement = connection.prepareStatement(query); - count = statement.executeUpdate(query); + statement.setInt(1, 100); + statement.setInt(2, 20000); + count = statement.executeUpdate(); } finally { closeQuietly(statement); } @@ -154,9 +171,11 @@ public void testSelectCountQuery() throws SQLException { public void testSelectEqualsQuery() throws SQLException { PreparedStatement statement = null; ResultSet resultSet = null; - String query = format("select %s from %s where int1 = 1", PRIMARY_KEY_COLUMN_NAME, tableName); + String query = format("select %s from %s where int1 = ?", PRIMARY_KEY_COLUMN_NAME, tableName); try { statement = connection.prepareStatement(query); + statement.setInt(1, 1); + resultSet = statement.executeQuery(); assertTrue(resultSet.next()); @@ -172,9 +191,11 @@ public void testSelectEqualsQuery() throws SQLException { public void testSelectNotEqualsQuery() throws SQLException { PreparedStatement statement = null; ResultSet resultSet = null; - String query = format("select %s, int1 from %s where int1 <> 2", PRIMARY_KEY_COLUMN_NAME, tableName); + String query = format("select %s, int1 from %s where int1 <> ?", PRIMARY_KEY_COLUMN_NAME, tableName); try { statement = connection.prepareStatement(query); + statement.setInt(1, 2); + resultSet = statement.executeQuery(); assertTrue(resultSet.next()); @@ -243,9 +264,12 @@ public void testSelectAndQuery() throws SQLException { public void testSelectInQuery() throws SQLException { PreparedStatement statement = null; ResultSet resultSet = null; - String query = format("select * from %s where int1 in (1, 2) and str1 is not null", tableName); + String query = format("select * from %s where int1 in (?, ?) and str1 is not null", tableName); try { statement = connection.prepareStatement(query); + statement.setInt(1, 1); + statement.setInt(2, 2); + resultSet = statement.executeQuery(); assertTrue(resultSet.next()); @@ -261,9 +285,13 @@ public void testSelectInQuery() throws SQLException { public void testSelectBetweenQuery() throws SQLException { PreparedStatement statement = null; ResultSet resultSet = null; - String query = format("select * from %s where int1 between 1 and 3", tableName); + String query = format("select * from %s where int1 between ? and ? and val1=?", tableName); try { statement = connection.prepareStatement(query); + statement.setInt(1, 1); + statement.setInt(2, 3); + statement.setBytes(3, blobValue); + resultSet = statement.executeQuery(); assertTrue(resultSet.next()); @@ -274,4 +302,9 @@ public void testSelectBetweenQuery() throws SQLException { closeQuietly(resultSet); } } + + private void assertBlobValue(ResultSet resultSet) throws SQLException { + assertEquals(resultSet.getBytes("val1"), blobValue); + assertEquals(resultSet.getBytes(6), blobValue); + } }