diff --git a/src/main/java/com/ning/compress/lzf/ChunkEncoder.java b/src/main/java/com/ning/compress/lzf/ChunkEncoder.java index 6c5df0a..df3f51c 100644 --- a/src/main/java/com/ning/compress/lzf/ChunkEncoder.java +++ b/src/main/java/com/ning/compress/lzf/ChunkEncoder.java @@ -93,8 +93,11 @@ protected ChunkEncoder(int totalLength) */ protected ChunkEncoder(int totalLength, BufferRecycler bufferRecycler) { + if (totalLength <= 0) { + throw new IllegalArgumentException("Invalid total length: " + totalLength); + } // Need room for at most a single full chunk - int largestChunkLen = Math.min(totalLength, LZFChunk.MAX_CHUNK_LEN); + int largestChunkLen = Math.min(totalLength, LZFChunk.MAX_CHUNK_LEN); int suggestedHashLen = calcHashLen(largestChunkLen); _recycler = bufferRecycler; _hashTable = bufferRecycler.allocEncodingHash(suggestedHashLen); diff --git a/src/main/java/com/ning/compress/lzf/impl/UnsafeChunkDecoder.java b/src/main/java/com/ning/compress/lzf/impl/UnsafeChunkDecoder.java index 1853d37..c4f04e1 100644 --- a/src/main/java/com/ning/compress/lzf/impl/UnsafeChunkDecoder.java +++ b/src/main/java/com/ning/compress/lzf/impl/UnsafeChunkDecoder.java @@ -71,6 +71,10 @@ public final int decodeChunk(final InputStream is, final byte[] inputBuffer, fin public final void decodeChunk(byte[] in, int inPos, byte[] out, int outPos, int outEnd) throws LZFException { + // Sanity checks; otherwise if any of the arguments are invalid `Unsafe` might corrupt memory + checkArrayIndices(in, inPos, in.length); + checkArrayIndices(out, outPos, outEnd); + // We need to take care of end condition, leave last 32 bytes out final int outputEnd8 = outEnd - 8; final int outputEnd32 = outEnd - 32; @@ -175,7 +179,17 @@ public int skipOrDecodeChunk(final InputStream is, final byte[] inputBuffer, // Internal methods /////////////////////////////////////////////////////////////////////// */ - + + /** + * @param start start index, inclusive + * @param end end index, exclusive + */ + private final void checkArrayIndices(byte[] array, int start, int end) { + if (start < 0 || end < start || end > array.length) { + throw new ArrayIndexOutOfBoundsException(); + } + } + private final int copyOverlappingShort(final byte[] out, int outPos, final int offset, int len) { out[outPos] = out[outPos++ + offset]; diff --git a/src/main/java/com/ning/compress/lzf/impl/UnsafeChunkEncoder.java b/src/main/java/com/ning/compress/lzf/impl/UnsafeChunkEncoder.java index 6dab20e..a64d497 100644 --- a/src/main/java/com/ning/compress/lzf/impl/UnsafeChunkEncoder.java +++ b/src/main/java/com/ning/compress/lzf/impl/UnsafeChunkEncoder.java @@ -21,7 +21,7 @@ public abstract class UnsafeChunkEncoder { // // Our Nitro Booster, mr. Unsafe! - protected static final Unsafe unsafe; + static final Unsafe unsafe; static { try { Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); @@ -33,23 +33,26 @@ public abstract class UnsafeChunkEncoder } } - protected static final long BYTE_ARRAY_OFFSET = unsafe.arrayBaseOffset(byte[].class); + // All members here (fields, constructors, methods) are at most package-private; users are + // not supposed to subclass this class - protected static final long BYTE_ARRAY_OFFSET_PLUS2 = BYTE_ARRAY_OFFSET + 2; - - public UnsafeChunkEncoder(int totalLength) { + static final long BYTE_ARRAY_OFFSET = unsafe.arrayBaseOffset(byte[].class); + + static final long BYTE_ARRAY_OFFSET_PLUS2 = BYTE_ARRAY_OFFSET + 2; + + UnsafeChunkEncoder(int totalLength) { super(totalLength); } - public UnsafeChunkEncoder(int totalLength, boolean bogus) { + UnsafeChunkEncoder(int totalLength, boolean bogus) { super(totalLength, bogus); } - public UnsafeChunkEncoder(int totalLength, BufferRecycler bufferRecycler) { + UnsafeChunkEncoder(int totalLength, BufferRecycler bufferRecycler) { super(totalLength, bufferRecycler); } - public UnsafeChunkEncoder(int totalLength, BufferRecycler bufferRecycler, boolean bogus) { + UnsafeChunkEncoder(int totalLength, BufferRecycler bufferRecycler, boolean bogus) { super(totalLength, bufferRecycler, bogus); } @@ -59,7 +62,17 @@ public UnsafeChunkEncoder(int totalLength, BufferRecycler bufferRecycler, boolea /////////////////////////////////////////////////////////////////////// */ - protected final static int _copyPartialLiterals(byte[] in, int inPos, byte[] out, int outPos, + /** + * @param start start index, inclusive + * @param end end index, exclusive + */ + final static void _checkArrayIndices(byte[] array, int start, int end) { + if (start < 0 || end < start || end > array.length) { + throw new ArrayIndexOutOfBoundsException(); + } + } + + final static int _copyPartialLiterals(byte[] in, int inPos, byte[] out, int outPos, int literals) { out[outPos++] = (byte) (literals-1); @@ -94,7 +107,7 @@ protected final static int _copyPartialLiterals(byte[] in, int inPos, byte[] out return outPos+literals; } - protected final static int _copyLongLiterals(byte[] in, int inPos, byte[] out, int outPos, + final static int _copyLongLiterals(byte[] in, int inPos, byte[] out, int outPos, int literals) { inPos -= literals; @@ -129,7 +142,7 @@ protected final static int _copyLongLiterals(byte[] in, int inPos, byte[] out, i return outPos; } - protected final static int _copyFullLiterals(byte[] in, int inPos, byte[] out, int outPos) + final static int _copyFullLiterals(byte[] in, int inPos, byte[] out, int outPos) { // literals == 32 out[outPos++] = (byte) 31; @@ -151,7 +164,7 @@ protected final static int _copyFullLiterals(byte[] in, int inPos, byte[] out, i return (outPos + 32); } - protected final static int _handleTail(byte[] in, int inPos, int inEnd, byte[] out, int outPos, + final static int _handleTail(byte[] in, int inPos, int inEnd, byte[] out, int outPos, int literals) { while (inPos < inEnd) { @@ -172,7 +185,7 @@ protected final static int _handleTail(byte[] in, int inPos, int inEnd, byte[] o return outPos; } - protected final static int _findTailMatchLength(final byte[] in, int ptr1, int ptr2, final int maxPtr1) + final static int _findTailMatchLength(final byte[] in, int ptr1, int ptr2, final int maxPtr1) { final int start1 = ptr1; while (ptr1 < maxPtr1 && in[ptr1] == in[ptr2]) { diff --git a/src/main/java/com/ning/compress/lzf/impl/UnsafeChunkEncoderBE.java b/src/main/java/com/ning/compress/lzf/impl/UnsafeChunkEncoderBE.java index 19114a0..1c6751a 100644 --- a/src/main/java/com/ning/compress/lzf/impl/UnsafeChunkEncoderBE.java +++ b/src/main/java/com/ning/compress/lzf/impl/UnsafeChunkEncoderBE.java @@ -29,6 +29,11 @@ public UnsafeChunkEncoderBE(int totalLength, BufferRecycler bufferRecycler, bool @Override protected int tryCompress(byte[] in, int inPos, int inEnd, byte[] out, int outPos) { + // Sanity checks; otherwise if any of the arguments are invalid `Unsafe` might corrupt memory + _checkArrayIndices(in, inPos, inEnd); + _checkArrayIndices(out, outPos, out.length); + // TODO: Validate that `out.length - outPos` is large enough? + final int[] hashTable = _hashTable; int literals = 0; inEnd -= TAIL_LENGTH; @@ -83,7 +88,7 @@ protected int tryCompress(byte[] in, int inPos, int inEnd, byte[] out, int outPo ++inPos; } // try offlining the tail - return _handleTail(in, inPos, inEnd+4, out, outPos, literals); + return _handleTail(in, inPos, inEnd+TAIL_LENGTH, out, outPos, literals); } private final static int _getInt(final byte[] in, final int inPos) { diff --git a/src/main/java/com/ning/compress/lzf/impl/UnsafeChunkEncoderLE.java b/src/main/java/com/ning/compress/lzf/impl/UnsafeChunkEncoderLE.java index 867428f..f87d98b 100644 --- a/src/main/java/com/ning/compress/lzf/impl/UnsafeChunkEncoderLE.java +++ b/src/main/java/com/ning/compress/lzf/impl/UnsafeChunkEncoderLE.java @@ -7,7 +7,7 @@ * Implementation to use on Little Endian architectures. */ @SuppressWarnings("restriction") -public class UnsafeChunkEncoderLE +public final class UnsafeChunkEncoderLE extends UnsafeChunkEncoder { public UnsafeChunkEncoderLE(int totalLength) { @@ -29,6 +29,11 @@ public UnsafeChunkEncoderLE(int totalLength, BufferRecycler bufferRecycler, bool @Override protected int tryCompress(byte[] in, int inPos, int inEnd, byte[] out, int outPos) { + // Sanity checks; otherwise if any of the arguments are invalid `Unsafe` might corrupt memory + _checkArrayIndices(in, inPos, inEnd); + _checkArrayIndices(out, outPos, out.length); + // TODO: Validate that `out.length - outPos` is large enough? + final int[] hashTable = _hashTable; int literals = 0; inEnd -= TAIL_LENGTH; @@ -85,7 +90,7 @@ protected int tryCompress(byte[] in, int inPos, int inEnd, byte[] out, int outPo ++inPos; } // try offlining the tail - return _handleTail(in, inPos, inEnd+4, out, outPos, literals); + return _handleTail(in, inPos, inEnd+TAIL_LENGTH, out, outPos, literals); } private final static int _getInt(final byte[] in, final int inPos) { diff --git a/src/test/java/com/ning/compress/lzf/LZFEncoderTest.java b/src/test/java/com/ning/compress/lzf/LZFEncoderTest.java index 1a5cef2..b514ff6 100644 --- a/src/test/java/com/ning/compress/lzf/LZFEncoderTest.java +++ b/src/test/java/com/ning/compress/lzf/LZFEncoderTest.java @@ -7,6 +7,9 @@ import org.testng.annotations.Test; import com.ning.compress.BaseForTests; +import com.ning.compress.lzf.impl.UnsafeChunkEncoder; +import com.ning.compress.lzf.impl.UnsafeChunkEncoderBE; +import com.ning.compress.lzf.impl.UnsafeChunkEncoderLE; import com.ning.compress.lzf.util.ChunkEncoderFactory; public class LZFEncoderTest extends BaseForTests @@ -169,4 +172,26 @@ private void _testConditionalCompression(ChunkEncoder enc, final byte[] input) t chunk = enc.encodeChunkIfCompresses(input, 0, input.length, 0.60); Assert.assertNull(chunk); } + + @Test + public void testUnsafeValidation() { + _testUnsafeValidation(new UnsafeChunkEncoderBE(10)); + _testUnsafeValidation(new UnsafeChunkEncoderLE(10)); + + } + + private void _testUnsafeValidation(UnsafeChunkEncoder encoder) { + byte[] array = new byte[10]; + int goodStart = 2; + int goodEnd = 5; + + Assert.assertThrows(NullPointerException.class, () -> encoder.tryCompress(null, goodStart, goodEnd, array, goodStart)); + Assert.assertThrows(NullPointerException.class, () -> encoder.tryCompress(array, goodStart, goodEnd, null, goodStart)); + Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> encoder.tryCompress(array, -1, goodEnd, array, goodStart)); + Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> encoder.tryCompress(array, 12, goodEnd, array, goodStart)); + Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> encoder.tryCompress(array, goodStart, 1, array, goodStart)); + Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> encoder.tryCompress(array, goodStart, 12, array, goodStart)); + Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> encoder.tryCompress(array, goodStart, goodEnd, array, -1)); + Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> encoder.tryCompress(array, goodStart, goodEnd, array, 12)); + } } diff --git a/src/test/java/com/ning/compress/lzf/TestLZFDecoder.java b/src/test/java/com/ning/compress/lzf/TestLZFDecoder.java index b2718c4..ff27673 100644 --- a/src/test/java/com/ning/compress/lzf/TestLZFDecoder.java +++ b/src/test/java/com/ning/compress/lzf/TestLZFDecoder.java @@ -6,6 +6,7 @@ import org.testng.annotations.Test; import com.ning.compress.BaseForTests; +import com.ning.compress.lzf.impl.UnsafeChunkDecoder; import com.ning.compress.lzf.util.ChunkDecoderFactory; public class TestLZFDecoder extends BaseForTests @@ -28,6 +29,22 @@ public void testChunks() throws IOException { _testChunks(ChunkDecoderFactory.optimalInstance()); } + @Test + public void testUnsafeValidation() { + UnsafeChunkDecoder decoder = new UnsafeChunkDecoder(); + + byte[] array = new byte[10]; + int goodStart = 2; + int goodEnd = 5; + Assert.assertThrows(NullPointerException.class, () -> decoder.decodeChunk(null, goodStart, array, goodStart, goodEnd)); + Assert.assertThrows(NullPointerException.class, () -> decoder.decodeChunk(array, goodStart, null, goodStart, goodEnd)); + Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> decoder.decodeChunk(array, -1, array, goodStart, goodEnd)); + Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> decoder.decodeChunk(array, 12, array, goodStart, goodEnd)); + Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> decoder.decodeChunk(array, goodStart, array, -1, goodEnd)); + Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> decoder.decodeChunk(array, goodStart, array, goodStart, 1)); + Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> decoder.decodeChunk(array, goodStart, array, goodStart, 12)); + } + /* /////////////////////////////////////////////////////////////////////// // Second-level test methods