From 8f6b177986ab89676fcca2b1fdedad580896b8c3 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Sun, 8 Sep 2024 19:05:34 -0700 Subject: [PATCH] [BugFix] Fixing KNNMethodContext resolution when documents are ingested and a flaky test (#2072) Signed-off-by: Navneet Verma --- .../index/mapper/KNNVectorFieldMapper.java | 2 + .../mapper/KNNVectorFieldMapperTests.java | 59 +++++++++++++++++++ .../knn/integ/ModeAndCompressionIT.java | 58 ++++++++++++++++-- 3 files changed, 114 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 265876310..f149fa1d2 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -782,6 +782,8 @@ public ParametrizedFieldMapper.Builder getMergeBuilder() { .vectorDataType(vectorDataType) .versionCreated(indexCreatedVersion) .dimension(fieldType().getKnnMappingConfig().getDimension()) + .compressionLevel(fieldType().getKnnMappingConfig().getCompressionLevel()) + .mode(fieldType().getKnnMappingConfig().getMode()) .build(); } diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index 496641339..84cbf05dc 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -880,6 +880,65 @@ public void testTypeParser_parse_fromLegacy() throws IOException { assertNull(builder.knnMethodContext.get()); } + public void testKNNVectorFieldMapperMerge_whenModeAndCompressionIsPresent_thenSuccess() throws IOException { + String fieldName = "test-field-name"; + String indexName = "test-index-name"; + + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); + ModelDao modelDao = mock(ModelDao.class); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); + + int dimension = 133; + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, dimension) + .field(MODE_PARAMETER, Mode.ON_DISK.getName()) + .field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x32.getName()) + .endObject(); + + KNNVectorFieldMapper.Builder builder = (KNNVectorFieldMapper.Builder) typeParser.parse( + fieldName, + xContentBuilderToMap(xContentBuilder), + buildParserContext(indexName, settings) + ); + Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); + KNNVectorFieldMapper knnVectorFieldMapper1 = builder.build(builderContext); + + // merge with itself - should be successful + KNNVectorFieldMapper knnVectorFieldMapperMerge1 = (KNNVectorFieldMapper) knnVectorFieldMapper1.merge(knnVectorFieldMapper1); + assertEquals( + knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getKnnMethodContext().get(), + knnVectorFieldMapperMerge1.fieldType().getKnnMappingConfig().getKnnMethodContext().get() + ); + + assertEquals( + knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getCompressionLevel(), + knnVectorFieldMapperMerge1.fieldType().getKnnMappingConfig().getCompressionLevel() + ); + assertEquals( + knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getMode(), + knnVectorFieldMapperMerge1.fieldType().getKnnMappingConfig().getMode() + ); + + // merge with another mapper of the same field with same context + KNNVectorFieldMapper knnVectorFieldMapper2 = builder.build(builderContext); + KNNVectorFieldMapper knnVectorFieldMapperMerge2 = (KNNVectorFieldMapper) knnVectorFieldMapper1.merge(knnVectorFieldMapper2); + assertEquals( + knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getKnnMethodContext().get(), + knnVectorFieldMapperMerge2.fieldType().getKnnMappingConfig().getKnnMethodContext().get() + ); + + assertEquals( + knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getCompressionLevel(), + knnVectorFieldMapperMerge2.fieldType().getKnnMappingConfig().getCompressionLevel() + ); + assertEquals( + knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getMode(), + knnVectorFieldMapperMerge2.fieldType().getKnnMappingConfig().getMode() + ); + } + public void testKNNVectorFieldMapper_merge_fromKnnMethodContext() throws IOException { String fieldName = "test-field-name"; String indexName = "test-index-name"; diff --git a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java index a2bf46aa4..8b4c856c4 100644 --- a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java +++ b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java @@ -7,16 +7,18 @@ import lombok.SneakyThrows; import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.Assert; import org.junit.Ignore; +import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.knn.KNNRestTestCase; -import org.opensearch.knn.KNNResult; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.mapper.CompressionLevel; import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.index.mapper.ModeBasedResolver; @@ -50,7 +52,7 @@ public class ModeAndCompressionIT extends KNNRestTestCase { private static final int DIMENSION = 16; private static final int NUM_DOCS = 20; - private static final int K = 2; + private static final int K = NUM_DOCS; private final static float[] TEST_VECTOR = new float[] { 1.0f, 2.0f, @@ -210,7 +212,7 @@ public void testDeletedDocsWithSegmentMerge_whenValid_ThenSucceed() { .endObject(); String mapping = builder.toString(); validateIndexWithDeletedDocs(indexName, mapping); - validateSearch(indexName, METHOD_PARAMETER_EF_SEARCH, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH); + validateGreenIndex(indexName); } @SneakyThrows @@ -352,6 +354,19 @@ private void validateIndexWithDeletedDocs(String indexName, String mapping) { refreshIndex(indexName); } + @SneakyThrows + private void validateGreenIndex(String indexName) { + Request request = new Request("GET", "/_cat/indices/" + indexName + "?format=csv"); + Response response = client().performRequest(request); + assertOK(response); + assertEquals( + "The status of index " + indexName + " is not green", + "green", + new String(response.getEntity().getContent().readAllBytes()).split("\n")[0].split(" ")[0] + ); + + } + @SneakyThrows private void setupTrainingIndex() { createBasicKnnIndex(TRAINING_INDEX_NAME, TRAINING_FIELD_NAME, DIMENSION); @@ -388,9 +403,41 @@ private void validateSearch(String indexName, String methodParameterName, int me ); assertOK(response); String responseBody = EntityUtils.toString(response.getEntity()); - List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + List knnResults = parseSearchResponseScore(responseBody, FIELD_NAME); assertEquals(K, knnResults.size()); + // Do exact search and gather right scores for the documents + Response exactSearchResponse = searchKNNIndex( + indexName, + XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("script_score") + .startObject("query") + .field("match_all") + .startObject() + .endObject() + .endObject() + .startObject("script") + .field("source", "knn_score") + .field("lang", "knn") + .startObject("params") + .field("field", FIELD_NAME) + .field("query_value", TEST_VECTOR) + .field("space_type", SpaceType.L2.getValue()) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(), + K + ); + assertOK(exactSearchResponse); + String exactSearchResponseBody = EntityUtils.toString(exactSearchResponse.getEntity()); + List exactSearchKnnResults = parseSearchResponseScore(exactSearchResponseBody, FIELD_NAME); + assertEquals(NUM_DOCS, exactSearchKnnResults.size()); + Assert.assertEquals(exactSearchKnnResults, knnResults); + // Search with rescore response = searchKNNIndex( indexName, @@ -415,7 +462,8 @@ private void validateSearch(String indexName, String methodParameterName, int me ); assertOK(response); responseBody = EntityUtils.toString(response.getEntity()); - knnResults = parseSearchResponse(responseBody, FIELD_NAME); + knnResults = parseSearchResponseScore(responseBody, FIELD_NAME); assertEquals(K, knnResults.size()); + Assert.assertEquals(exactSearchKnnResults, knnResults); } }