From 0492fb350fa3e398871b4a3b477a7eea921818b1 Mon Sep 17 00:00:00 2001 From: Tejas Shah Date: Sun, 8 Sep 2024 18:11:51 -0700 Subject: [PATCH] Bounds the transferLimit in OffheapVectorTransfer (#2070) The array list buffer was unnecessarily allocated a large memory irrespective of the vectors to transfer. This change considers total vectors Signed-off-by: Tejas Shah --- .../DefaultIndexBuildStrategy.java | 10 +- .../MemOptimizedNativeIndexBuildStrategy.java | 12 ++- .../transfer/OffHeapBinaryVectorTransfer.java | 4 +- .../transfer/OffHeapByteVectorTransfer.java | 4 +- .../transfer/OffHeapFloatVectorTransfer.java | 4 +- .../codec/transfer/OffHeapVectorTransfer.java | 17 ++- .../OffHeapVectorTransferFactory.java | 15 ++- .../DefaultIndexBuildStrategyTests.java | 11 +- ...ptimizedNativeIndexBuildStrategyTests.java | 12 +-- .../OffHeapVectorTransferFactoryTests.java | 26 +++-- .../transfer/OffHeapVectorTransferTests.java | 101 ++++++++++-------- 11 files changed, 126 insertions(+), 90 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java index e68121a7d..476c95b8d 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java @@ -8,7 +8,6 @@ import lombok.AccessLevel; import lombok.NoArgsConstructor; import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; @@ -56,8 +55,13 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOExcept iterateVectorValuesOnce(knnVectorValues); IndexBuildSetup indexBuildSetup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, indexInfo); - int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector()); - try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) { + try ( + final OffHeapVectorTransfer vectorTransfer = getVectorTransfer( + indexInfo.getVectorDataType(), + indexBuildSetup.getBytesPerVector(), + indexInfo.getTotalLiveDocs() + ) + ) { final List transferredDocIds = new ArrayList<>(indexInfo.getTotalLiveDocs()); while (knnVectorValues.docId() != NO_MORE_DOCS) { diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java index af3f4777f..b7e337081 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java @@ -7,7 +7,6 @@ import lombok.AccessLevel; import lombok.NoArgsConstructor; -import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; import org.opensearch.knn.index.engine.KNNEngine; @@ -70,10 +69,15 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOExcept ) ); - int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector()); - try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) { + try ( + final OffHeapVectorTransfer vectorTransfer = getVectorTransfer( + indexInfo.getVectorDataType(), + indexBuildSetup.getBytesPerVector(), + indexInfo.getTotalLiveDocs() + ) + ) { - final List transferredDocIds = new ArrayList<>(transferLimit); + final List transferredDocIds = new ArrayList<>(vectorTransfer.getTransferLimit()); while (knnVectorValues.docId() != NO_MORE_DOCS) { Object vector = QuantizationIndexUtils.processAndReturnVector(knnVectorValues, indexBuildSetup); diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java index ffa12a231..964007fc0 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java @@ -17,8 +17,8 @@ */ public final class OffHeapBinaryVectorTransfer extends OffHeapVectorTransfer { - public OffHeapBinaryVectorTransfer(int transferLimit) { - super(transferLimit); + public OffHeapBinaryVectorTransfer(int bytesPerVector, int totalVectorsToTransfer) { + super(bytesPerVector, totalVectorsToTransfer); } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java index 83ebf2fa3..16e333478 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java @@ -17,8 +17,8 @@ */ public final class OffHeapByteVectorTransfer extends OffHeapVectorTransfer { - public OffHeapByteVectorTransfer(int transferLimit) { - super(transferLimit); + public OffHeapByteVectorTransfer(int bytesPerVector, int totalVectorsToTransfer) { + super(bytesPerVector, totalVectorsToTransfer); } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java index 0eb28d791..767f57271 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java @@ -15,8 +15,8 @@ */ public final class OffHeapFloatVectorTransfer extends OffHeapVectorTransfer { - public OffHeapFloatVectorTransfer(int transferLimit) { - super(transferLimit); + public OffHeapFloatVectorTransfer(int bytesPerVector, int totalVectorsToTransfer) { + super(bytesPerVector, totalVectorsToTransfer); } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransfer.java index 43c27c8da..8a248e06c 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransfer.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransfer.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.codec.transfer; import lombok.Getter; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import java.io.Closeable; @@ -27,16 +28,22 @@ public abstract class OffHeapVectorTransfer implements Closeable { @Getter private long vectorAddress; + @Getter protected final int transferLimit; - private final List vectorsToTransfer; + private List vectorsToTransfer; - public OffHeapVectorTransfer(final int transferLimit) { - this.transferLimit = transferLimit; - this.vectorsToTransfer = new ArrayList<>(transferLimit); + public OffHeapVectorTransfer(int bytesPerVector, int totalVectorsToTransfer) { + this.transferLimit = computeTransferLimit(bytesPerVector, totalVectorsToTransfer); + this.vectorsToTransfer = new ArrayList<>(this.transferLimit); this.vectorAddress = 0; } + private int computeTransferLimit(int bytesPerVector, int totalVectorsToTransfer) { + int limit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / bytesPerVector); + return Math.min(limit, totalVectorsToTransfer); + } + /** * Transfer vectors to off-heap * @param vector float[] or byte[] @@ -90,7 +97,7 @@ public void close() { */ public void reset() { vectorAddress = 0; - vectorsToTransfer.clear(); + vectorsToTransfer = null; } protected abstract void deallocate(); diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactory.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactory.java index 446b6ae80..3bc55f7fa 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactory.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactory.java @@ -18,18 +18,23 @@ public final class OffHeapVectorTransferFactory { /** * Gets the right vector transfer object based on vector data type * @param vectorDataType {@link VectorDataType} - * @param transferLimit max number of vectors that can be transferred to off heap in one transfer + * @param bytesPerVector Bytes used per vector + * @param totalVectorsToTransfer total number of vectors that will be transferred off heap * @return Correct implementation of {@link OffHeapVectorTransfer} * @param float[] or byte[] */ - public static OffHeapVectorTransfer getVectorTransfer(final VectorDataType vectorDataType, final int transferLimit) { + public static OffHeapVectorTransfer getVectorTransfer( + final VectorDataType vectorDataType, + int bytesPerVector, + int totalVectorsToTransfer + ) { switch (vectorDataType) { case FLOAT: - return (OffHeapVectorTransfer) new OffHeapFloatVectorTransfer(transferLimit); + return (OffHeapVectorTransfer) new OffHeapFloatVectorTransfer(bytesPerVector, totalVectorsToTransfer); case BINARY: - return (OffHeapVectorTransfer) new OffHeapBinaryVectorTransfer(transferLimit); + return (OffHeapVectorTransfer) new OffHeapBinaryVectorTransfer(bytesPerVector, totalVectorsToTransfer); case BYTE: - return (OffHeapVectorTransfer) new OffHeapByteVectorTransfer(transferLimit); + return (OffHeapVectorTransfer) new OffHeapByteVectorTransfer(bytesPerVector, totalVectorsToTransfer); default: throw new IllegalArgumentException("Unsupported vector data type: " + vectorDataType); } diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java index 96af0db19..9c2e5a4b7 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java @@ -57,14 +57,11 @@ public void testBuildAndWrite() { final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); try ( - MockedStatic mockedKNNSettings = mockStatic(KNNSettings.class); MockedStatic mockedJNIService = mockStatic(JNIService.class); MockedStatic mockedOffHeapVectorTransferFactory = mockStatic(OffHeapVectorTransferFactory.class) ) { - - mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); OffHeapVectorTransfer offHeapVectorTransfer = mock(OffHeapVectorTransfer.class); - mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 2)) + mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 8, 3)) .thenReturn(offHeapVectorTransfer); when(offHeapVectorTransfer.getVectorAddress()).thenReturn(200L); @@ -131,7 +128,7 @@ public void testBuildAndWrite_withQuantization() { mockedJNIService.when(() -> JNIService.initIndex(3, 2, Map.of("index", "param"), KNNEngine.FAISS)).thenReturn(100L); OffHeapVectorTransfer offHeapVectorTransfer = mock(OffHeapVectorTransfer.class); - mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 2)) + mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 8, 3)) .thenReturn(offHeapVectorTransfer); QuantizationService quantizationService = mock(QuantizationService.class); @@ -237,14 +234,12 @@ public void testBuildAndWriteWithModel() { docs ); try ( - MockedStatic mockedKNNSettings = mockStatic(KNNSettings.class); MockedStatic mockedJNIService = mockStatic(JNIService.class); MockedStatic mockedOffHeapVectorTransferFactory = mockStatic(OffHeapVectorTransferFactory.class) ) { - mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); OffHeapVectorTransfer offHeapVectorTransfer = mock(OffHeapVectorTransfer.class); - mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 2)) + mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 8, 3)) .thenReturn(offHeapVectorTransfer); when(offHeapVectorTransfer.getVectorAddress()).thenReturn(200L); diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java index 22f9b2dfd..62c3b7a71 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java @@ -9,8 +9,6 @@ import org.mockito.ArgumentCaptor; import org.mockito.MockedStatic; import org.mockito.Mockito; -import org.opensearch.core.common.unit.ByteSizeValue; -import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; @@ -49,7 +47,6 @@ public void testBuildAndWrite() { final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); try ( - MockedStatic mockedKNNSettings = Mockito.mockStatic(KNNSettings.class); MockedStatic mockedJNIService = Mockito.mockStatic(JNIService.class); MockedStatic mockedOffHeapVectorTransferFactory = Mockito.mockStatic( OffHeapVectorTransferFactory.class @@ -57,13 +54,13 @@ public void testBuildAndWrite() { ) { // Limits transfer to 2 vectors - mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); mockedJNIService.when(() -> JNIService.initIndex(3, 2, Map.of("index", "param"), KNNEngine.FAISS)).thenReturn(100L); OffHeapVectorTransfer offHeapVectorTransfer = mock(OffHeapVectorTransfer.class); - mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 2)) + mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 8, 3)) .thenReturn(offHeapVectorTransfer); + when(offHeapVectorTransfer.getTransferLimit()).thenReturn(2); when(offHeapVectorTransfer.transfer(vectorTransferCapture.capture(), eq(false))).thenReturn(false) .thenReturn(true) .thenReturn(false); @@ -145,7 +142,6 @@ public void testBuildAndWrite_withQuantization() { final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); try ( - MockedStatic mockedKNNSettings = Mockito.mockStatic(KNNSettings.class); MockedStatic mockedJNIService = Mockito.mockStatic(JNIService.class); MockedStatic mockedOffHeapVectorTransferFactory = Mockito.mockStatic( OffHeapVectorTransferFactory.class @@ -154,11 +150,11 @@ public void testBuildAndWrite_withQuantization() { ) { // Limits transfer to 2 vectors - mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); mockedJNIService.when(() -> JNIService.initIndex(3, 2, Map.of("index", "param"), KNNEngine.FAISS)).thenReturn(100L); OffHeapVectorTransfer offHeapVectorTransfer = mock(OffHeapVectorTransfer.class); - mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 2)) + when(offHeapVectorTransfer.getTransferLimit()).thenReturn(2); + mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 8, 3)) .thenReturn(offHeapVectorTransfer); QuantizationService quantizationService = mock(QuantizationService.class); diff --git a/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactoryTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactoryTests.java index 39415d811..09984ba46 100644 --- a/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactoryTests.java @@ -5,22 +5,30 @@ package org.opensearch.knn.index.codec.transfer; +import org.mockito.MockedStatic; +import org.opensearch.core.common.unit.ByteSizeValue; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; import org.opensearch.test.OpenSearchTestCase; +import static org.mockito.Mockito.mockStatic; + public class OffHeapVectorTransferFactoryTests extends OpenSearchTestCase { public void testOffHeapVectorTransferFactory() { - var floatVectorTransfer = OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 10); - assertEquals(OffHeapFloatVectorTransfer.class, floatVectorTransfer.getClass()); - assertNotSame(floatVectorTransfer, OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 10)); + try (MockedStatic mockedKNNSettings = mockStatic(KNNSettings.class)) { + mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); + var floatVectorTransfer = OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 10, 10); + assertEquals(OffHeapFloatVectorTransfer.class, floatVectorTransfer.getClass()); + assertNotSame(floatVectorTransfer, OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 10, 10)); - var byteVectorTransfer = OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BYTE, 10); - assertEquals(OffHeapByteVectorTransfer.class, byteVectorTransfer.getClass()); - assertNotSame(byteVectorTransfer, OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BYTE, 10)); + var byteVectorTransfer = OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BYTE, 10, 10); + assertEquals(OffHeapByteVectorTransfer.class, byteVectorTransfer.getClass()); + assertNotSame(byteVectorTransfer, OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BYTE, 10, 10)); - var binaryVectorTransfer = OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BINARY, 10); - assertEquals(OffHeapBinaryVectorTransfer.class, binaryVectorTransfer.getClass()); - assertNotSame(binaryVectorTransfer, OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BINARY, 10)); + var binaryVectorTransfer = OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BINARY, 10, 10); + assertEquals(OffHeapBinaryVectorTransfer.class, binaryVectorTransfer.getClass()); + assertNotSame(binaryVectorTransfer, OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BINARY, 10, 10)); + } } } diff --git a/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java index f1650db8f..fb2ef274e 100644 --- a/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java @@ -6,10 +6,15 @@ package org.opensearch.knn.index.codec.transfer; import lombok.SneakyThrows; +import org.mockito.MockedStatic; +import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.KNNSettings; import java.util.List; +import static org.mockito.Mockito.mockStatic; + public class OffHeapVectorTransferTests extends KNNTestCase { @SneakyThrows @@ -22,21 +27,27 @@ public void testFloatTransfer() { new float[] { 0.3f, 0.4f } ); - OffHeapFloatVectorTransfer vectorTransfer = new OffHeapFloatVectorTransfer(2); - long vectorAddress = 0; - assertFalse(vectorTransfer.transfer(vectors.get(0), false)); - assertEquals(0, vectorTransfer.getVectorAddress()); - assertTrue(vectorTransfer.transfer(vectors.get(1), false)); - vectorAddress = vectorTransfer.getVectorAddress(); - assertFalse(vectorTransfer.transfer(vectors.get(2), false)); - assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); - assertTrue(vectorTransfer.transfer(vectors.get(3), false)); - assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); - assertFalse(vectorTransfer.transfer(vectors.get(4), false)); - assertTrue(vectorTransfer.flush(false)); - vectorTransfer.reset(); - assertEquals(0, vectorTransfer.getVectorAddress()); - vectorTransfer.close(); + try (MockedStatic mockedKNNSettings = mockStatic(KNNSettings.class)) { + mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); + + OffHeapFloatVectorTransfer vectorTransfer = new OffHeapFloatVectorTransfer(8, 5); + long vectorAddress = 0; + assertFalse(vectorTransfer.transfer(vectors.get(0), false)); + assertEquals(0, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(1), false)); + vectorAddress = vectorTransfer.getVectorAddress(); + assertFalse(vectorTransfer.transfer(vectors.get(2), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(3), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertFalse(vectorTransfer.transfer(vectors.get(4), false)); + assertTrue(vectorTransfer.flush(false)); + vectorTransfer.reset(); + assertEquals(0, vectorTransfer.getVectorAddress()); + vectorTransfer.close(); + + } + } @SneakyThrows @@ -49,20 +60,23 @@ public void testByteTransfer() { new byte[] { 8, 9 } ); - OffHeapByteVectorTransfer vectorTransfer = new OffHeapByteVectorTransfer(2); - long vectorAddress = 0; - assertFalse(vectorTransfer.transfer(vectors.get(0), false)); - assertEquals(0, vectorTransfer.getVectorAddress()); - assertTrue(vectorTransfer.transfer(vectors.get(1), false)); - vectorAddress = vectorTransfer.getVectorAddress(); - assertFalse(vectorTransfer.transfer(vectors.get(2), false)); - assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); - assertTrue(vectorTransfer.transfer(vectors.get(3), false)); - assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); - assertFalse(vectorTransfer.transfer(vectors.get(4), false)); - assertTrue(vectorTransfer.flush(false)); - vectorTransfer.close(); - assertEquals(0, vectorTransfer.getVectorAddress()); + try (MockedStatic mockedKNNSettings = mockStatic(KNNSettings.class)) { + mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(4)); + OffHeapByteVectorTransfer vectorTransfer = new OffHeapByteVectorTransfer(2, 5); + long vectorAddress = 0; + assertFalse(vectorTransfer.transfer(vectors.get(0), false)); + assertEquals(0, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(1), false)); + vectorAddress = vectorTransfer.getVectorAddress(); + assertFalse(vectorTransfer.transfer(vectors.get(2), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(3), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertFalse(vectorTransfer.transfer(vectors.get(4), false)); + assertTrue(vectorTransfer.flush(false)); + vectorTransfer.close(); + assertEquals(0, vectorTransfer.getVectorAddress()); + } } @SneakyThrows @@ -75,18 +89,21 @@ public void testBinaryTransfer() { new byte[] { 8, 9 } ); - OffHeapBinaryVectorTransfer vectorTransfer = new OffHeapBinaryVectorTransfer(2); - long vectorAddress = 0; - assertFalse(vectorTransfer.transfer(vectors.get(0), false)); - assertEquals(0, vectorTransfer.getVectorAddress()); - assertTrue(vectorTransfer.transfer(vectors.get(1), false)); - vectorAddress = vectorTransfer.getVectorAddress(); - assertFalse(vectorTransfer.transfer(vectors.get(2), false)); - assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); - assertTrue(vectorTransfer.transfer(vectors.get(3), false)); - assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); - assertFalse(vectorTransfer.transfer(vectors.get(4), false)); - assertTrue(vectorTransfer.flush(false)); - vectorTransfer.close(); + try (MockedStatic mockedKNNSettings = mockStatic(KNNSettings.class)) { + mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(4)); + OffHeapBinaryVectorTransfer vectorTransfer = new OffHeapBinaryVectorTransfer(2, 5); + long vectorAddress = 0; + assertFalse(vectorTransfer.transfer(vectors.get(0), false)); + assertEquals(0, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(1), false)); + vectorAddress = vectorTransfer.getVectorAddress(); + assertFalse(vectorTransfer.transfer(vectors.get(2), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(3), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertFalse(vectorTransfer.transfer(vectors.get(4), false)); + assertTrue(vectorTransfer.flush(false)); + vectorTransfer.close(); + } } }