From a90b294923b9613985444c48432e4b44fe0e0c1e Mon Sep 17 00:00:00 2001 From: Florian Bernard Date: Fri, 3 Jan 2025 13:56:32 +0100 Subject: [PATCH] DuckDB: Improve SQL management to avoid injection (#44) --- .../duckdb/DuckDBEmbeddingStore.java | 67 +++++++++---------- 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/embedding-stores/langchain4j-community-duckdb/src/main/java/dev/langchain4j/community/store/embedding/duckdb/DuckDBEmbeddingStore.java b/embedding-stores/langchain4j-community-duckdb/src/main/java/dev/langchain4j/community/store/embedding/duckdb/DuckDBEmbeddingStore.java index 88ba209..a735939 100644 --- a/embedding-stores/langchain4j-community-duckdb/src/main/java/dev/langchain4j/community/store/embedding/duckdb/DuckDBEmbeddingStore.java +++ b/embedding-stores/langchain4j-community-duckdb/src/main/java/dev/langchain4j/community/store/embedding/duckdb/DuckDBEmbeddingStore.java @@ -48,22 +48,21 @@ public class DuckDBEmbeddingStore implements EmbeddingStore { private static final String SEARCH_QUERY_TEMPLATE = """ - select *, (list_cosine_similarity(embedding,%s)+1.0)/2.0 as score + select id, embedding, text, metadata, (list_cosine_similarity(embedding,%s)+1.0)/2.0 as score from %s where score >= %s %s order by score DESC - limit(%d) + limit %d """; private static final String INSERT_QUERY_TEMPLATE = """ - insert into %s (id, embedding, text, metadata) values %s + insert into %s (id, embedding, text, metadata) values (?,?,?,?) """; - private static final String DELETE_BY_IDS_QUERY_TEMPLATE = - """ - delete from %s where id in (%s) - """; + private static final String DELETE_BY_IDS_QUERY_TEMPLATE = """ + delete from %s where id in ? + """; private static final String DELETE_QUERY_TEMPLATE = """ delete from %s where %s @@ -171,12 +170,12 @@ public List addAll(List embeddings, List embedde @Override public void removeAll(Collection ids) { ensureNotEmpty(ids, "ids"); - var idsParam = ids.stream().map(id -> "'" + id + "'").collect(Collectors.joining(",")); - String sql = format(DELETE_BY_IDS_QUERY_TEMPLATE, tableName, idsParam); + String sql = format(DELETE_BY_IDS_QUERY_TEMPLATE, tableName); try (var connection = duckDBConnection.duplicate(); - var statement = connection.createStatement()) { - log.debug(sql); - statement.execute(sql); + var statement = connection.prepareStatement(sql)) { + var idsParam = connection.createArrayOf("UUID", ids.toArray()); + statement.setObject(1, idsParam); + statement.execute(); } catch (SQLException e) { throw new DuckDBSQLException("Unable to remove embeddings by ids", e); } @@ -188,9 +187,9 @@ public void removeAll(Filter filter) { var whereClause = jsonFilterMapper.map(filter); String sql = format(DELETE_QUERY_TEMPLATE, tableName, whereClause); try (var connection = duckDBConnection.duplicate(); - var statement = connection.createStatement()) { + var statement = connection.prepareStatement(sql)) { log.debug(sql); - statement.execute(sql); + statement.execute(); } catch (SQLException e) { throw new DuckDBSQLException("Unable to remove embeddings with filter", e); } @@ -209,15 +208,17 @@ public void removeAll() { @Override public EmbeddingSearchResult search(EmbeddingSearchRequest request) { + var param = embeddingToParam(request.queryEmbedding()); + var filterClause = request.filter() != null ? "and " + jsonFilterMapper.map(request.filter()) : ""; + var query = + format(SEARCH_QUERY_TEMPLATE, param, tableName, request.minScore(), filterClause, request.maxResults()); + try (var connection = duckDBConnection.duplicate(); - var statement = connection.createStatement()) { + var statement = connection.prepareStatement(query)) { var matches = new ArrayList>(); - var param = embeddingToParam(request.queryEmbedding()); - var filterClause = request.filter() != null ? "and " + jsonFilterMapper.map(request.filter()) : ""; - var query = format( - SEARCH_QUERY_TEMPLATE, param, tableName, request.minScore(), filterClause, request.maxResults()); + log.debug(query); - var resultSet = statement.executeQuery(query); + var resultSet = statement.executeQuery(); while (resultSet.next()) { var id = resultSet.getString("id"); @@ -261,27 +262,25 @@ public void addAll(List ids, List embeddings, List(ids.size()); + var statement = connection.prepareStatement(format(INSERT_QUERY_TEMPLATE, tableName))) { for (int i = 0; i < ids.size(); i++) { - var text = "NULL"; + String textParam = null; if (embedded != null && embedded.get(i) != null) { - text = "'" + embedded.get(i).text() + "'"; + textParam = embedded.get(i).text(); } var metadata = embedded != null && embedded.get(i) != null ? embedded.get(i).metadata().toMap() : null; - values.add(format( - "('%s',%s,%s,'%s')", - ids.get(i), - embeddingToParam(embeddings.get(i)), - text, - jsonMetadataSerializer.writeValueAsString(metadata))); - } - var sql = format(INSERT_QUERY_TEMPLATE, tableName, String.join(",", values)); - log.debug(sql); - statement.execute(sql); + statement.setString(1, ids.get(i)); + var embeddingsParam = connection.createArrayOf( + "float", embeddings.get(i).vectorAsList().toArray()); + statement.setObject(2, embeddingsParam); + statement.setString(3, textParam); + statement.setString(4, jsonMetadataSerializer.writeValueAsString(metadata)); + statement.addBatch(); + } + statement.executeBatch(); } catch (SQLException | JsonProcessingException e) { throw new DuckDBSQLException("Unable to add embeddings in DuckDB", e); }