From a3479c5811c0719cdde63436c3b61f81d38d8040 Mon Sep 17 00:00:00 2001 From: Mirko Zizzari Date: Fri, 12 Jul 2024 11:31:18 +0200 Subject: [PATCH] issue #957: fixes knn wrapper query generation --- .../searcher/parser/impl/KnnQueryParser.java | 36 +++++---- .../parser/impl/KnnQueryParserTest.java | 81 +++++++++++++++++++ 2 files changed, 103 insertions(+), 14 deletions(-) create mode 100644 core/app/datasource/src/test/java/io/openk9/datasource/searcher/parser/impl/KnnQueryParserTest.java diff --git a/core/app/datasource/src/main/java/io/openk9/datasource/searcher/parser/impl/KnnQueryParser.java b/core/app/datasource/src/main/java/io/openk9/datasource/searcher/parser/impl/KnnQueryParser.java index 5809b9de8..92c352443 100644 --- a/core/app/datasource/src/main/java/io/openk9/datasource/searcher/parser/impl/KnnQueryParser.java +++ b/core/app/datasource/src/main/java/io/openk9/datasource/searcher/parser/impl/KnnQueryParser.java @@ -25,6 +25,7 @@ import jakarta.json.Json; import org.opensearch.client.json.jackson.JacksonJsonpMapper; import org.opensearch.client.opensearch._types.query_dsl.KnnQuery; +import org.opensearch.client.opensearch._types.query_dsl.Query; import org.opensearch.index.query.QueryBuilders; import java.io.ByteArrayOutputStream; @@ -55,7 +56,7 @@ public Uni apply(ParserContext parserContext) { var tenantId = currentTenant.getTenant(); - var knnQueryUnis = new ArrayList>(); + var knnQueryUnis = new ArrayList>(); for (ParserSearchToken parserSearchToken : tokenTypeGroup) { @@ -74,6 +75,7 @@ public Uni apply(ParserContext parserContext) { .field("vector") .vector(toVector(embeddedText)) .build() + .toQuery() ); knnQueryUnis.add(knnQuery); @@ -86,26 +88,32 @@ public Uni apply(ParserContext parserContext) { .andCollectFailures() .invoke(knnQueries -> { - for (KnnQuery knnQuery : knnQueries) { + for (Query knnQuery : knnQueries) { - try (var os = new ByteArrayOutputStream()) { + addsKnnQuery(parserContext, knnQuery); + } - var generator = Json.createGenerator(os); + }) + .replaceWithVoid(); + } - knnQuery.serialize(generator, new JacksonJsonpMapper()); + protected static void addsKnnQuery(ParserContext parserContext, Query knnQuery) { + try (var os = new ByteArrayOutputStream()) { - var wrapperQueryBuilder = QueryBuilders.wrapperQuery(os.toByteArray()); + var generator = Json.createGenerator(os); - parserContext.getMutableQuery().must(wrapperQueryBuilder); + knnQuery.serialize(generator, new JacksonJsonpMapper()); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } + generator.close(); - }) - .replaceWithVoid(); + var wrapperQueryBuilder = QueryBuilders.wrapperQuery(os.toByteArray()); + + parserContext.getMutableQuery().must(wrapperQueryBuilder); + + } + catch (IOException e) { + throw new RuntimeException(e); + } } private static float[] toVector(EmbeddingService.EmbeddedText embeddedText) { diff --git a/core/app/datasource/src/test/java/io/openk9/datasource/searcher/parser/impl/KnnQueryParserTest.java b/core/app/datasource/src/test/java/io/openk9/datasource/searcher/parser/impl/KnnQueryParserTest.java new file mode 100644 index 000000000..89fe8aa56 --- /dev/null +++ b/core/app/datasource/src/test/java/io/openk9/datasource/searcher/parser/impl/KnnQueryParserTest.java @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2020-present SMC Treviso s.r.l. All rights reserved. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package io.openk9.datasource.searcher.parser.impl; + +import com.jayway.jsonpath.JsonPath; +import io.openk9.datasource.searcher.parser.ParserContext; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.opensearch.client.opensearch._types.query_dsl.KnnQuery; +import org.opensearch.index.query.BoolQueryBuilder; + +import java.util.Base64; +import java.util.List; + +class KnnQueryParserTest { + + private static final int VECTOR_SIZE = 1800; + private static final int K_NEIGHBORS = 2; + + @Test + void addsKnnQuery() { + + var parserContext = new ParserContext(); + + var boolQueryBuilder = new BoolQueryBuilder(); + + parserContext.setMutableQuery(boolQueryBuilder); + + var knnQueryBuilder = new KnnQuery.Builder() + .field("vector") + .k(K_NEIGHBORS) + .vector(randomVector(VECTOR_SIZE)) + .build() + .toQuery(); + + KnnQueryParser.addsKnnQuery(parserContext, knnQueryBuilder); + + var query = parserContext.getMutableQuery().toString(); + + String wrappedQuery = JsonPath.parse(query).read("$.bool.must[0].wrapper.query"); + + var wrappedKnnQuery = new String(Base64.getDecoder().decode(wrappedQuery)); + + var documentContext = JsonPath.parse(wrappedKnnQuery); + + int k = documentContext.read("$.knn.vector.k"); + List vector = documentContext.read("$.knn.vector.vector"); + + Assertions.assertEquals(K_NEIGHBORS, k); + Assertions.assertEquals(VECTOR_SIZE, vector.size()); + + } + + private float[] randomVector(int size) { + var vector = new float[size]; + + for (int i = 0; i < size; i++) { + + vector[i] = (float) Math.random(); + + } + + return vector; + } + +} \ No newline at end of file