Skip to content

Commit

Permalink
issue #957: fixes knn wrapper query generation
Browse files Browse the repository at this point in the history
  • Loading branch information
mrk-vi committed Jul 12, 2024
1 parent 2b28822 commit a3479c5
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import jakarta.json.Json;
import org.opensearch.client.json.jackson.JacksonJsonpMapper;
import org.opensearch.client.opensearch._types.query_dsl.KnnQuery;
import org.opensearch.client.opensearch._types.query_dsl.Query;
import org.opensearch.index.query.QueryBuilders;

import java.io.ByteArrayOutputStream;
Expand Down Expand Up @@ -55,7 +56,7 @@ public Uni<Void> apply(ParserContext parserContext) {

var tenantId = currentTenant.getTenant();

var knnQueryUnis = new ArrayList<Uni<KnnQuery>>();
var knnQueryUnis = new ArrayList<Uni<Query>>();

for (ParserSearchToken parserSearchToken : tokenTypeGroup) {

Expand All @@ -74,6 +75,7 @@ public Uni<Void> apply(ParserContext parserContext) {
.field("vector")
.vector(toVector(embeddedText))
.build()
.toQuery()
);

knnQueryUnis.add(knnQuery);
Expand All @@ -86,26 +88,32 @@ public Uni<Void> apply(ParserContext parserContext) {
.andCollectFailures()
.invoke(knnQueries -> {

for (KnnQuery knnQuery : knnQueries) {
for (Query knnQuery : knnQueries) {

try (var os = new ByteArrayOutputStream()) {
addsKnnQuery(parserContext, knnQuery);
}

var generator = Json.createGenerator(os);
})
.replaceWithVoid();
}

knnQuery.serialize(generator, new JacksonJsonpMapper());
protected static void addsKnnQuery(ParserContext parserContext, Query knnQuery) {
try (var os = new ByteArrayOutputStream()) {

var wrapperQueryBuilder = QueryBuilders.wrapperQuery(os.toByteArray());
var generator = Json.createGenerator(os);

parserContext.getMutableQuery().must(wrapperQueryBuilder);
knnQuery.serialize(generator, new JacksonJsonpMapper());

}
catch (IOException e) {
throw new RuntimeException(e);
}
}
generator.close();

})
.replaceWithVoid();
var wrapperQueryBuilder = QueryBuilders.wrapperQuery(os.toByteArray());

parserContext.getMutableQuery().must(wrapperQueryBuilder);

}
catch (IOException e) {
throw new RuntimeException(e);
}
}

private static float[] toVector(EmbeddingService.EmbeddedText embeddedText) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright (c) 2020-present SMC Treviso s.r.l. All rights reserved.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package io.openk9.datasource.searcher.parser.impl;

import com.jayway.jsonpath.JsonPath;
import io.openk9.datasource.searcher.parser.ParserContext;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.opensearch.client.opensearch._types.query_dsl.KnnQuery;
import org.opensearch.index.query.BoolQueryBuilder;

import java.util.Base64;
import java.util.List;

class KnnQueryParserTest {

private static final int VECTOR_SIZE = 1800;
private static final int K_NEIGHBORS = 2;

@Test
void addsKnnQuery() {

var parserContext = new ParserContext();

var boolQueryBuilder = new BoolQueryBuilder();

parserContext.setMutableQuery(boolQueryBuilder);

var knnQueryBuilder = new KnnQuery.Builder()
.field("vector")
.k(K_NEIGHBORS)
.vector(randomVector(VECTOR_SIZE))
.build()
.toQuery();

KnnQueryParser.addsKnnQuery(parserContext, knnQueryBuilder);

var query = parserContext.getMutableQuery().toString();

String wrappedQuery = JsonPath.parse(query).read("$.bool.must[0].wrapper.query");

var wrappedKnnQuery = new String(Base64.getDecoder().decode(wrappedQuery));

var documentContext = JsonPath.parse(wrappedKnnQuery);

int k = documentContext.read("$.knn.vector.k");
List<Float> vector = documentContext.read("$.knn.vector.vector");

Assertions.assertEquals(K_NEIGHBORS, k);
Assertions.assertEquals(VECTOR_SIZE, vector.size());

}

private float[] randomVector(int size) {
var vector = new float[size];

for (int i = 0; i < size; i++) {

vector[i] = (float) Math.random();

}

return vector;
}

}

0 comments on commit a3479c5

Please sign in to comment.