Skip to content

Commit

Permalink
Validate arguments for Unsafe coders
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcono1234 committed Feb 18, 2024
1 parent b206cc9 commit 9cd0e20
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 18 deletions.
5 changes: 4 additions & 1 deletion src/main/java/com/ning/compress/lzf/ChunkEncoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,11 @@ protected ChunkEncoder(int totalLength)
*/
protected ChunkEncoder(int totalLength, BufferRecycler bufferRecycler)
{
if (totalLength <= 0) {
throw new IllegalArgumentException("Invalid total length: " + totalLength);
}
// Need room for at most a single full chunk
int largestChunkLen = Math.min(totalLength, LZFChunk.MAX_CHUNK_LEN);
int largestChunkLen = Math.min(totalLength, LZFChunk.MAX_CHUNK_LEN);
int suggestedHashLen = calcHashLen(largestChunkLen);
_recycler = bufferRecycler;
_hashTable = bufferRecycler.allocEncodingHash(suggestedHashLen);
Expand Down
16 changes: 15 additions & 1 deletion src/main/java/com/ning/compress/lzf/impl/UnsafeChunkDecoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ public final int decodeChunk(final InputStream is, final byte[] inputBuffer, fin
public final void decodeChunk(byte[] in, int inPos, byte[] out, int outPos, int outEnd)
throws LZFException
{
// Sanity checks; otherwise if any of the arguments are invalid `Unsafe` might corrupt memory
checkArrayIndices(in, inPos, in.length);
checkArrayIndices(out, outPos, outEnd);

// We need to take care of end condition, leave last 32 bytes out
final int outputEnd8 = outEnd - 8;
final int outputEnd32 = outEnd - 32;
Expand Down Expand Up @@ -175,7 +179,17 @@ public int skipOrDecodeChunk(final InputStream is, final byte[] inputBuffer,
// Internal methods
///////////////////////////////////////////////////////////////////////
*/


/**
* @param start start index, inclusive
* @param end end index, exclusive
*/
private final void checkArrayIndices(byte[] array, int start, int end) {
if (start < 0 || end < start || end > array.length) {
throw new ArrayIndexOutOfBoundsException();
}
}

private final int copyOverlappingShort(final byte[] out, int outPos, final int offset, int len)
{
out[outPos] = out[outPos++ + offset];
Expand Down
39 changes: 26 additions & 13 deletions src/main/java/com/ning/compress/lzf/impl/UnsafeChunkEncoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public abstract class UnsafeChunkEncoder
{
// // Our Nitro Booster, mr. Unsafe!

protected static final Unsafe unsafe;
static final Unsafe unsafe;
static {
try {
Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
Expand All @@ -33,23 +33,26 @@ public abstract class UnsafeChunkEncoder
}
}

protected static final long BYTE_ARRAY_OFFSET = unsafe.arrayBaseOffset(byte[].class);
// All members here (fields, constructors, methods) are at most package-private; users are
// not supposed to subclass this class

protected static final long BYTE_ARRAY_OFFSET_PLUS2 = BYTE_ARRAY_OFFSET + 2;

public UnsafeChunkEncoder(int totalLength) {
static final long BYTE_ARRAY_OFFSET = unsafe.arrayBaseOffset(byte[].class);

static final long BYTE_ARRAY_OFFSET_PLUS2 = BYTE_ARRAY_OFFSET + 2;

UnsafeChunkEncoder(int totalLength) {
super(totalLength);
}

public UnsafeChunkEncoder(int totalLength, boolean bogus) {
UnsafeChunkEncoder(int totalLength, boolean bogus) {
super(totalLength, bogus);
}

public UnsafeChunkEncoder(int totalLength, BufferRecycler bufferRecycler) {
UnsafeChunkEncoder(int totalLength, BufferRecycler bufferRecycler) {
super(totalLength, bufferRecycler);
}

public UnsafeChunkEncoder(int totalLength, BufferRecycler bufferRecycler, boolean bogus) {
UnsafeChunkEncoder(int totalLength, BufferRecycler bufferRecycler, boolean bogus) {
super(totalLength, bufferRecycler, bogus);
}

Expand All @@ -59,7 +62,17 @@ public UnsafeChunkEncoder(int totalLength, BufferRecycler bufferRecycler, boolea
///////////////////////////////////////////////////////////////////////
*/

protected final static int _copyPartialLiterals(byte[] in, int inPos, byte[] out, int outPos,
/**
* @param start start index, inclusive
* @param end end index, exclusive
*/
final static void _checkArrayIndices(byte[] array, int start, int end) {
if (start < 0 || end < start || end > array.length) {
throw new ArrayIndexOutOfBoundsException();
}
}

final static int _copyPartialLiterals(byte[] in, int inPos, byte[] out, int outPos,
int literals)
{
out[outPos++] = (byte) (literals-1);
Expand Down Expand Up @@ -94,7 +107,7 @@ protected final static int _copyPartialLiterals(byte[] in, int inPos, byte[] out
return outPos+literals;
}

protected final static int _copyLongLiterals(byte[] in, int inPos, byte[] out, int outPos,
final static int _copyLongLiterals(byte[] in, int inPos, byte[] out, int outPos,
int literals)
{
inPos -= literals;
Expand Down Expand Up @@ -129,7 +142,7 @@ protected final static int _copyLongLiterals(byte[] in, int inPos, byte[] out, i
return outPos;
}

protected final static int _copyFullLiterals(byte[] in, int inPos, byte[] out, int outPos)
final static int _copyFullLiterals(byte[] in, int inPos, byte[] out, int outPos)
{
// literals == 32
out[outPos++] = (byte) 31;
Expand All @@ -151,7 +164,7 @@ protected final static int _copyFullLiterals(byte[] in, int inPos, byte[] out, i
return (outPos + 32);
}

protected final static int _handleTail(byte[] in, int inPos, int inEnd, byte[] out, int outPos,
final static int _handleTail(byte[] in, int inPos, int inEnd, byte[] out, int outPos,
int literals)
{
while (inPos < inEnd) {
Expand All @@ -172,7 +185,7 @@ protected final static int _handleTail(byte[] in, int inPos, int inEnd, byte[] o
return outPos;
}

protected final static int _findTailMatchLength(final byte[] in, int ptr1, int ptr2, final int maxPtr1)
final static int _findTailMatchLength(final byte[] in, int ptr1, int ptr2, final int maxPtr1)
{
final int start1 = ptr1;
while (ptr1 < maxPtr1 && in[ptr1] == in[ptr2]) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ public UnsafeChunkEncoderBE(int totalLength, BufferRecycler bufferRecycler, bool
@Override
protected int tryCompress(byte[] in, int inPos, int inEnd, byte[] out, int outPos)
{
// Sanity checks; otherwise if any of the arguments are invalid `Unsafe` might corrupt memory
_checkArrayIndices(in, inPos, inEnd);
_checkArrayIndices(out, outPos, out.length);
// TODO: Validate that `out.length - outPos` is large enough?

final int[] hashTable = _hashTable;
int literals = 0;
inEnd -= TAIL_LENGTH;
Expand Down Expand Up @@ -83,7 +88,7 @@ protected int tryCompress(byte[] in, int inPos, int inEnd, byte[] out, int outPo
++inPos;
}
// try offlining the tail
return _handleTail(in, inPos, inEnd+4, out, outPos, literals);
return _handleTail(in, inPos, inEnd+TAIL_LENGTH, out, outPos, literals);
}

private final static int _getInt(final byte[] in, final int inPos) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
* Implementation to use on Little Endian architectures.
*/
@SuppressWarnings("restriction")
public class UnsafeChunkEncoderLE
public final class UnsafeChunkEncoderLE
extends UnsafeChunkEncoder
{
public UnsafeChunkEncoderLE(int totalLength) {
Expand All @@ -29,6 +29,11 @@ public UnsafeChunkEncoderLE(int totalLength, BufferRecycler bufferRecycler, bool
@Override
protected int tryCompress(byte[] in, int inPos, int inEnd, byte[] out, int outPos)
{
// Sanity checks; otherwise if any of the arguments are invalid `Unsafe` might corrupt memory
_checkArrayIndices(in, inPos, inEnd);
_checkArrayIndices(out, outPos, out.length);
// TODO: Validate that `out.length - outPos` is large enough?

final int[] hashTable = _hashTable;
int literals = 0;
inEnd -= TAIL_LENGTH;
Expand Down Expand Up @@ -85,7 +90,7 @@ protected int tryCompress(byte[] in, int inPos, int inEnd, byte[] out, int outPo
++inPos;
}
// try offlining the tail
return _handleTail(in, inPos, inEnd+4, out, outPos, literals);
return _handleTail(in, inPos, inEnd+TAIL_LENGTH, out, outPos, literals);
}

private final static int _getInt(final byte[] in, final int inPos) {
Expand Down
25 changes: 25 additions & 0 deletions src/test/java/com/ning/compress/lzf/LZFEncoderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import org.testng.annotations.Test;

import com.ning.compress.BaseForTests;
import com.ning.compress.lzf.impl.UnsafeChunkEncoder;
import com.ning.compress.lzf.impl.UnsafeChunkEncoderBE;
import com.ning.compress.lzf.impl.UnsafeChunkEncoderLE;
import com.ning.compress.lzf.util.ChunkEncoderFactory;

public class LZFEncoderTest extends BaseForTests
Expand Down Expand Up @@ -169,4 +172,26 @@ private void _testConditionalCompression(ChunkEncoder enc, final byte[] input) t
chunk = enc.encodeChunkIfCompresses(input, 0, input.length, 0.60);
Assert.assertNull(chunk);
}

@Test
public void testUnsafeValidation() {
_testUnsafeValidation(new UnsafeChunkEncoderBE(10));
_testUnsafeValidation(new UnsafeChunkEncoderLE(10));

}

private void _testUnsafeValidation(UnsafeChunkEncoder encoder) {
byte[] array = new byte[10];
int goodStart = 2;
int goodEnd = 5;

Assert.assertThrows(NullPointerException.class, () -> encoder.tryCompress(null, goodStart, goodEnd, array, goodStart));
Assert.assertThrows(NullPointerException.class, () -> encoder.tryCompress(array, goodStart, goodEnd, null, goodStart));
Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> encoder.tryCompress(array, -1, goodEnd, array, goodStart));
Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> encoder.tryCompress(array, 12, goodEnd, array, goodStart));
Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> encoder.tryCompress(array, goodStart, 1, array, goodStart));
Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> encoder.tryCompress(array, goodStart, 12, array, goodStart));
Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> encoder.tryCompress(array, goodStart, goodEnd, array, -1));
Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> encoder.tryCompress(array, goodStart, goodEnd, array, 12));
}
}
17 changes: 17 additions & 0 deletions src/test/java/com/ning/compress/lzf/TestLZFDecoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import org.testng.annotations.Test;

import com.ning.compress.BaseForTests;
import com.ning.compress.lzf.impl.UnsafeChunkDecoder;
import com.ning.compress.lzf.util.ChunkDecoderFactory;

public class TestLZFDecoder extends BaseForTests
Expand All @@ -28,6 +29,22 @@ public void testChunks() throws IOException {
_testChunks(ChunkDecoderFactory.optimalInstance());
}

@Test
public void testUnsafeValidation() {
UnsafeChunkDecoder decoder = new UnsafeChunkDecoder();

byte[] array = new byte[10];
int goodStart = 2;
int goodEnd = 5;
Assert.assertThrows(NullPointerException.class, () -> decoder.decodeChunk(null, goodStart, array, goodStart, goodEnd));
Assert.assertThrows(NullPointerException.class, () -> decoder.decodeChunk(array, goodStart, null, goodStart, goodEnd));
Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> decoder.decodeChunk(array, -1, array, goodStart, goodEnd));
Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> decoder.decodeChunk(array, 12, array, goodStart, goodEnd));
Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> decoder.decodeChunk(array, goodStart, array, -1, goodEnd));
Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> decoder.decodeChunk(array, goodStart, array, goodStart, 1));
Assert.assertThrows(ArrayIndexOutOfBoundsException.class, () -> decoder.decodeChunk(array, goodStart, array, goodStart, 12));
}

/*
///////////////////////////////////////////////////////////////////////
// Second-level test methods
Expand Down

0 comments on commit 9cd0e20

Please sign in to comment.