Skip to content

Commit

Permalink
[BugFix] Fixing KNNMethodContext resolution when documents are ingest…
Browse files Browse the repository at this point in the history
…ed and a flaky test (opensearch-project#2072)

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v authored Sep 9, 2024
1 parent 0492fb3 commit 8f6b177
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,8 @@ public ParametrizedFieldMapper.Builder getMergeBuilder() {
.vectorDataType(vectorDataType)
.versionCreated(indexCreatedVersion)
.dimension(fieldType().getKnnMappingConfig().getDimension())
.compressionLevel(fieldType().getKnnMappingConfig().getCompressionLevel())
.mode(fieldType().getKnnMappingConfig().getMode())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,65 @@ public void testTypeParser_parse_fromLegacy() throws IOException {
assertNull(builder.knnMethodContext.get());
}

public void testKNNVectorFieldMapperMerge_whenModeAndCompressionIsPresent_thenSuccess() throws IOException {
String fieldName = "test-field-name";
String indexName = "test-index-name";

Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build();
ModelDao modelDao = mock(ModelDao.class);
KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao);

int dimension = 133;
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
.field(DIMENSION_FIELD_NAME, dimension)
.field(MODE_PARAMETER, Mode.ON_DISK.getName())
.field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x32.getName())
.endObject();

KNNVectorFieldMapper.Builder builder = (KNNVectorFieldMapper.Builder) typeParser.parse(
fieldName,
xContentBuilderToMap(xContentBuilder),
buildParserContext(indexName, settings)
);
Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath());
KNNVectorFieldMapper knnVectorFieldMapper1 = builder.build(builderContext);

// merge with itself - should be successful
KNNVectorFieldMapper knnVectorFieldMapperMerge1 = (KNNVectorFieldMapper) knnVectorFieldMapper1.merge(knnVectorFieldMapper1);
assertEquals(
knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getKnnMethodContext().get(),
knnVectorFieldMapperMerge1.fieldType().getKnnMappingConfig().getKnnMethodContext().get()
);

assertEquals(
knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getCompressionLevel(),
knnVectorFieldMapperMerge1.fieldType().getKnnMappingConfig().getCompressionLevel()
);
assertEquals(
knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getMode(),
knnVectorFieldMapperMerge1.fieldType().getKnnMappingConfig().getMode()
);

// merge with another mapper of the same field with same context
KNNVectorFieldMapper knnVectorFieldMapper2 = builder.build(builderContext);
KNNVectorFieldMapper knnVectorFieldMapperMerge2 = (KNNVectorFieldMapper) knnVectorFieldMapper1.merge(knnVectorFieldMapper2);
assertEquals(
knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getKnnMethodContext().get(),
knnVectorFieldMapperMerge2.fieldType().getKnnMappingConfig().getKnnMethodContext().get()
);

assertEquals(
knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getCompressionLevel(),
knnVectorFieldMapperMerge2.fieldType().getKnnMappingConfig().getCompressionLevel()
);
assertEquals(
knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getMode(),
knnVectorFieldMapperMerge2.fieldType().getKnnMappingConfig().getMode()
);
}

public void testKNNVectorFieldMapper_merge_fromKnnMethodContext() throws IOException {
String fieldName = "test-field-name";
String indexName = "test-index-name";
Expand Down
58 changes: 53 additions & 5 deletions src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,18 @@

import lombok.SneakyThrows;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.junit.Assert;
import org.junit.Ignore;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.mapper.Mode;
import org.opensearch.knn.index.mapper.ModeBasedResolver;
Expand Down Expand Up @@ -50,7 +52,7 @@ public class ModeAndCompressionIT extends KNNRestTestCase {

private static final int DIMENSION = 16;
private static final int NUM_DOCS = 20;
private static final int K = 2;
private static final int K = NUM_DOCS;
private final static float[] TEST_VECTOR = new float[] {
1.0f,
2.0f,
Expand Down Expand Up @@ -210,7 +212,7 @@ public void testDeletedDocsWithSegmentMerge_whenValid_ThenSucceed() {
.endObject();
String mapping = builder.toString();
validateIndexWithDeletedDocs(indexName, mapping);
validateSearch(indexName, METHOD_PARAMETER_EF_SEARCH, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH);
validateGreenIndex(indexName);
}

@SneakyThrows
Expand Down Expand Up @@ -352,6 +354,19 @@ private void validateIndexWithDeletedDocs(String indexName, String mapping) {
refreshIndex(indexName);
}

@SneakyThrows
private void validateGreenIndex(String indexName) {
Request request = new Request("GET", "/_cat/indices/" + indexName + "?format=csv");
Response response = client().performRequest(request);
assertOK(response);
assertEquals(
"The status of index " + indexName + " is not green",
"green",
new String(response.getEntity().getContent().readAllBytes()).split("\n")[0].split(" ")[0]
);

}

@SneakyThrows
private void setupTrainingIndex() {
createBasicKnnIndex(TRAINING_INDEX_NAME, TRAINING_FIELD_NAME, DIMENSION);
Expand Down Expand Up @@ -388,9 +403,41 @@ private void validateSearch(String indexName, String methodParameterName, int me
);
assertOK(response);
String responseBody = EntityUtils.toString(response.getEntity());
List<KNNResult> knnResults = parseSearchResponse(responseBody, FIELD_NAME);
List<Float> knnResults = parseSearchResponseScore(responseBody, FIELD_NAME);
assertEquals(K, knnResults.size());

// Do exact search and gather right scores for the documents
Response exactSearchResponse = searchKNNIndex(
indexName,
XContentFactory.jsonBuilder()
.startObject()
.startObject("query")
.startObject("script_score")
.startObject("query")
.field("match_all")
.startObject()
.endObject()
.endObject()
.startObject("script")
.field("source", "knn_score")
.field("lang", "knn")
.startObject("params")
.field("field", FIELD_NAME)
.field("query_value", TEST_VECTOR)
.field("space_type", SpaceType.L2.getValue())
.endObject()
.endObject()
.endObject()
.endObject()
.endObject(),
K
);
assertOK(exactSearchResponse);
String exactSearchResponseBody = EntityUtils.toString(exactSearchResponse.getEntity());
List<Float> exactSearchKnnResults = parseSearchResponseScore(exactSearchResponseBody, FIELD_NAME);
assertEquals(NUM_DOCS, exactSearchKnnResults.size());
Assert.assertEquals(exactSearchKnnResults, knnResults);

// Search with rescore
response = searchKNNIndex(
indexName,
Expand All @@ -415,7 +462,8 @@ private void validateSearch(String indexName, String methodParameterName, int me
);
assertOK(response);
responseBody = EntityUtils.toString(response.getEntity());
knnResults = parseSearchResponse(responseBody, FIELD_NAME);
knnResults = parseSearchResponseScore(responseBody, FIELD_NAME);
assertEquals(K, knnResults.size());
Assert.assertEquals(exactSearchKnnResults, knnResults);
}
}

0 comments on commit 8f6b177

Please sign in to comment.