Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DuckDB: Improve SQL management to avoid injection (#44) #45

Merged
merged 1 commit into from
Jan 4, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,22 +48,21 @@ public class DuckDBEmbeddingStore implements EmbeddingStore<TextSegment> {

private static final String SEARCH_QUERY_TEMPLATE =
"""
select *, (list_cosine_similarity(embedding,%s)+1.0)/2.0 as score
select id, embedding, text, metadata, (list_cosine_similarity(embedding,%s)+1.0)/2.0 as score
from %s
where score >= %s %s
order by score DESC
limit(%d)
limit %d
""";

private static final String INSERT_QUERY_TEMPLATE =
"""
insert into %s (id, embedding, text, metadata) values %s
insert into %s (id, embedding, text, metadata) values (?,?,?,?)
""";

private static final String DELETE_BY_IDS_QUERY_TEMPLATE =
"""
delete from %s where id in (%s)
""";
private static final String DELETE_BY_IDS_QUERY_TEMPLATE = """
delete from %s where id in ?
""";

private static final String DELETE_QUERY_TEMPLATE = """
delete from %s where %s
Expand Down Expand Up @@ -171,12 +170,12 @@ public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedde
@Override
public void removeAll(Collection<String> ids) {
ensureNotEmpty(ids, "ids");
var idsParam = ids.stream().map(id -> "'" + id + "'").collect(Collectors.joining(","));
String sql = format(DELETE_BY_IDS_QUERY_TEMPLATE, tableName, idsParam);
String sql = format(DELETE_BY_IDS_QUERY_TEMPLATE, tableName);
try (var connection = duckDBConnection.duplicate();
var statement = connection.createStatement()) {
log.debug(sql);
statement.execute(sql);
var statement = connection.prepareStatement(sql)) {
var idsParam = connection.createArrayOf("UUID", ids.toArray());
statement.setObject(1, idsParam);
statement.execute();
} catch (SQLException e) {
throw new DuckDBSQLException("Unable to remove embeddings by ids", e);
}
Expand All @@ -188,9 +187,9 @@ public void removeAll(Filter filter) {
var whereClause = jsonFilterMapper.map(filter);
String sql = format(DELETE_QUERY_TEMPLATE, tableName, whereClause);
try (var connection = duckDBConnection.duplicate();
var statement = connection.createStatement()) {
var statement = connection.prepareStatement(sql)) {
log.debug(sql);
statement.execute(sql);
statement.execute();
} catch (SQLException e) {
throw new DuckDBSQLException("Unable to remove embeddings with filter", e);
}
Expand All @@ -209,15 +208,17 @@ public void removeAll() {

@Override
public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
var param = embeddingToParam(request.queryEmbedding());
var filterClause = request.filter() != null ? "and " + jsonFilterMapper.map(request.filter()) : "";
var query =
format(SEARCH_QUERY_TEMPLATE, param, tableName, request.minScore(), filterClause, request.maxResults());

try (var connection = duckDBConnection.duplicate();
var statement = connection.createStatement()) {
var statement = connection.prepareStatement(query)) {
var matches = new ArrayList<EmbeddingMatch<TextSegment>>();
var param = embeddingToParam(request.queryEmbedding());
var filterClause = request.filter() != null ? "and " + jsonFilterMapper.map(request.filter()) : "";
var query = format(
SEARCH_QUERY_TEMPLATE, param, tableName, request.minScore(), filterClause, request.maxResults());

log.debug(query);
var resultSet = statement.executeQuery(query);
var resultSet = statement.executeQuery();
while (resultSet.next()) {

var id = resultSet.getString("id");
Expand Down Expand Up @@ -261,27 +262,25 @@ public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegmen
"embeddings size is not equal to embedded size");

try (var connection = duckDBConnection.duplicate();
var statement = connection.createStatement()) {
var values = new ArrayList<String>(ids.size());
var statement = connection.prepareStatement(format(INSERT_QUERY_TEMPLATE, tableName))) {
for (int i = 0; i < ids.size(); i++) {
var text = "NULL";
String textParam = null;
if (embedded != null && embedded.get(i) != null) {
text = "'" + embedded.get(i).text() + "'";
textParam = embedded.get(i).text();
}
var metadata = embedded != null && embedded.get(i) != null
? embedded.get(i).metadata().toMap()
: null;
values.add(format(
"('%s',%s,%s,'%s')",
ids.get(i),
embeddingToParam(embeddings.get(i)),
text,
jsonMetadataSerializer.writeValueAsString(metadata)));
}

var sql = format(INSERT_QUERY_TEMPLATE, tableName, String.join(",", values));
log.debug(sql);
statement.execute(sql);
statement.setString(1, ids.get(i));
var embeddingsParam = connection.createArrayOf(
"float", embeddings.get(i).vectorAsList().toArray());
statement.setObject(2, embeddingsParam);
statement.setString(3, textParam);
statement.setString(4, jsonMetadataSerializer.writeValueAsString(metadata));
statement.addBatch();
}
statement.executeBatch();
} catch (SQLException | JsonProcessingException e) {
throw new DuckDBSQLException("Unable to add embeddings in DuckDB", e);
}
Expand Down
Loading