Skip to content

Commit

Permalink
Move common code to CompressionUtils
Browse files Browse the repository at this point in the history
  • Loading branch information
yash-puligundla authored and cmnbroad committed Feb 27, 2024
1 parent be396bb commit 6ffdb71
Show file tree
Hide file tree
Showing 8 changed files with 246 additions and 233 deletions.
179 changes: 179 additions & 0 deletions src/main/java/htsjdk/samtools/cram/compression/CompressionUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package htsjdk.samtools.cram.compression;

import htsjdk.samtools.cram.CRAMException;
import htsjdk.samtools.cram.compression.rans.Constants;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class CompressionUtils {
/**
 * Writes {@code i} to {@code cp} as a variable-length integer: the value is split into
 * 7-bit groups, most significant group first, and every byte except the last has its
 * high bit (0x80) set as a continuation flag.
 *
 * The value is treated as an unsigned 32-bit quantity. Non-negative values produce 1 to 5
 * bytes; an {@code int} with the sign bit set is written as its 32-bit pattern in 5 bytes
 * (previously such a value was silently truncated to its low 7 bits), and is recovered
 * unchanged by {@code readUint7}.
 *
 * @param i  the value to write
 * @param cp the destination buffer; bytes are written at its current position, which advances
 */
public static void writeUint7(final int i, final ByteBuffer cp) {
    // Work out how many bits are needed, rounded up to a multiple of 7. The unsigned shift
    // guarantees the loop runs to completion even when the sign bit is set.
    int shift = 0;
    int remainingBits = i;
    do {
        shift += 7;
        remainingBits >>>= 7;
    } while (remainingBits != 0);

    // Emit the 7-bit groups from most to least significant; shift is at most 28 here.
    do {
        shift -= 7;
        final int continuationFlag = (shift > 0) ? 0x80 : 0;
        cp.put((byte) (((i >>> shift) & 0x7f) | continuationFlag));
    } while (shift > 0);
}

/**
 * Reads a variable-length integer from {@code cp}: each byte contributes its low 7 bits,
 * most significant group first, and a set high bit (0x80) means another byte follows.
 *
 * @param cp the source buffer; bytes are consumed from its current position
 * @return the accumulated value (bits shifted beyond 32 are discarded)
 */
public static int readUint7(final ByteBuffer cp) {
    int value = 0;
    while (true) {
        final byte next = cp.get();
        value = (value << 7) | (next & 0x7f);
        // A clear continuation bit marks the final byte of the value.
        if ((next & 0x80) == 0) {
            return value;
        }
    }
}

/**
 * Bit-packs the remaining bytes of {@code inBuffer} and writes the pack header to
 * {@code outBuffer}.
 *
 * Each input byte is translated through {@code packMappingTable} and the translated values
 * are packed low-bits-first into bytes: 1 bit per value when {@code numSymbols <= 2},
 * 2 bits when {@code numSymbols <= 4}, and 4 bits otherwise. When {@code numSymbols <= 1}
 * no packed data is produced. NOTE(review): 4-bit packing is also applied when
 * {@code numSymbols > 16}; callers are presumably expected to rule that case out.
 *
 * The header written to {@code outBuffer} (at its current position) is: {@code numSymbols}
 * as one byte, then every symbol whose {@code frequencyTable} entry is positive in ascending
 * order, then the packed data length as a uint7.
 *
 * The input is read from {@code inBuffer.position()} up to its limit without moving the
 * position (previously the bytes were read from absolute index 0, which packed the wrong
 * window whenever the position was not 0).
 *
 * @param inBuffer         data to pack; not consumed
 * @param outBuffer        receives the pack header
 * @param frequencyTable   per-symbol counts; entries greater than 0 mark symbols in use
 * @param packMappingTable maps an unsigned input byte to the value that is packed
 * @param numSymbols       number of distinct symbols, used to select the bit width
 * @return the packed data, with position 0 and limit equal to the packed length
 */
public static ByteBuffer encodePack(
        final ByteBuffer inBuffer,
        final ByteBuffer outBuffer,
        final int[] frequencyTable,
        final int[] packMappingTable,
        final int numSymbols){
    final ByteBuffer encodedBuffer;
    if (numSymbols <= 1) {
        // a single symbol needs no data: the decoder can rebuild it from the mapping table
        encodedBuffer = CompressionUtils.allocateByteBuffer(0);
    } else if (numSymbols <= 2) {
        encodedBuffer = packValues(inBuffer, packMappingTable, 1);
    } else if (numSymbols <= 4) {
        encodedBuffer = packValues(inBuffer, packMappingTable, 2);
    } else {
        encodedBuffer = packValues(inBuffer, packMappingTable, 4);
    }

    // write numSymbols
    outBuffer.put((byte) numSymbols);

    // write the symbols in use; the decoder uses this list to map packed values back to symbols
    for (int symbol = 0; symbol < Constants.NUMBER_OF_SYMBOLS; symbol++) {
        if (frequencyTable[symbol] > 0) {
            outBuffer.put((byte) symbol);
        }
    }

    // write the length of the packed data
    CompressionUtils.writeUint7(encodedBuffer.limit(), outBuffer);

    // position is still 0 because the packed buffer is only ever accessed by absolute index
    return encodedBuffer;
}

// Packs the mapped value of each remaining input byte into bitsPerValue bits (1, 2 or 4), low bits first.
private static ByteBuffer packValues(
        final ByteBuffer inBuffer,
        final int[] packMappingTable,
        final int bitsPerValue) {
    final int valuesPerByte = 8 / bitsPerValue;
    final int inStart = inBuffer.position();
    final int inSize = inBuffer.remaining();
    // integer ceiling of inSize / valuesPerByte
    final int packedSize = inSize / valuesPerByte + ((inSize % valuesPerByte == 0) ? 0 : 1);
    final ByteBuffer packed = CompressionUtils.allocateByteBuffer(packedSize);
    for (int i = 0; i < inSize; i++) {
        final int outIndex = i / valuesPerByte;
        final int shift = (i % valuesPerByte) * bitsPerValue;
        final int mapped = packMappingTable[inBuffer.get(inStart + i) & 0xFF];
        // a freshly allocated buffer is zero-filled, so each value is simply added into place
        packed.put(outIndex, (byte) (packed.get(outIndex) + (mapped << shift)));
    }
    return packed;
}

/**
 * Reverses bit-packing: expands packed values from {@code inBuffer} into a new buffer of
 * {@code uncompressedPackOutputLength} bytes, translating each value through
 * {@code packMappingTable}.
 *
 * The bit width is chosen from {@code numSymbols}: 1 bit when {@code <= 2}, 2 bits when
 * {@code <= 4}, 4 bits when {@code <= 16}. When {@code numSymbols <= 1} no input is read and
 * the output is filled with {@code packMappingTable[0]}. When {@code numSymbols > 16} nothing
 * is decoded and the returned buffer is left zero-filled.
 *
 * Packed bytes are read from {@code inBuffer} by absolute index starting at 0, so its
 * position is neither used nor changed. Values are taken from each byte low bits first.
 *
 * @param inBuffer                     the packed data
 * @param packMappingTable             maps a packed value back to the original symbol
 * @param numSymbols                   number of distinct symbols, used to select the bit width
 * @param uncompressedPackOutputLength number of symbols to produce
 * @return a new buffer holding the unpacked symbols, with position 0
 */
public static ByteBuffer decodePack(
        final ByteBuffer inBuffer,
        final byte[] packMappingTable,
        final int numSymbols,
        final int uncompressedPackOutputLength) {
    final ByteBuffer unpacked = CompressionUtils.allocateByteBuffer(uncompressedPackOutputLength);

    if (numSymbols <= 1) {
        // constant stream: every output byte is the single mapped symbol
        for (int outIndex = 0; outIndex < uncompressedPackOutputLength; outIndex++) {
            unpacked.put(outIndex, packMappingTable[0]);
        }
        return unpacked;
    }
    if (numSymbols > 16) {
        // no packing scheme applies; the buffer is returned as allocated
        return unpacked;
    }

    // 1, 2 or 4 bits per value depending on the number of symbols
    final int bitsPerValue = (numSymbols <= 2) ? 1 : ((numSymbols <= 4) ? 2 : 4);
    final int valuesPerByte = 8 / bitsPerValue;
    final int valueMask = (1 << bitsPerValue) - 1;

    int nextInIndex = 0;
    int current = 0;
    for (int outIndex = 0; outIndex < uncompressedPackOutputLength; outIndex++) {
        if (outIndex % valuesPerByte == 0) {
            // fetch the next packed byte once the previous one is exhausted
            current = inBuffer.get(nextInIndex++);
        }
        unpacked.put(outIndex, packMappingTable[current & valueMask]);
        current >>= bitsPerValue;
    }
    return unpacked;
}



/**
 * Allocates a LITTLE_ENDIAN output buffer for compressing {@code inSize} input bytes.
 *
 * The capacity is {@code inSize + 257 * 257 * 3 + 9}. The sum is computed in {@code long} so
 * that a large {@code inSize} cannot wrap around to a negative {@code int} capacity.
 *
 * @param inSize the size of the input, in bytes
 * @return a new LITTLE_ENDIAN buffer of the computed capacity
 * @throws CRAMException if the required capacity does not fit in an {@code int}
 */
public static ByteBuffer allocateOutputBuffer(final int inSize) {
    // This calculation is identical to the one in samtools rANS_static.c
    // Presumably the frequency table (always big enough for order 1) = 257*257,
    // then * 3 for each entry (byte->symbol, 2 bytes -> scaled frequency),
    // + 9 for the header (order byte, and 2 int lengths for compressed/uncompressed lengths).
    final long compressedSize = (long) inSize + 257 * 257 * 3 + 9;
    if (compressedSize > Integer.MAX_VALUE) {
        throw new CRAMException("Failed to allocate sufficient buffer size for RANS coder.");
    }
    // A freshly allocated buffer always has remaining() == capacity, so no post-allocation
    // size check is needed.
    return ByteBuffer.allocate((int) compressedSize).order(ByteOrder.LITTLE_ENDIAN);
}

/**
 * Allocates a new heap buffer whose byte order is LITTLE_ENDIAN.
 *
 * @param bufferSize the capacity of the new buffer, in bytes
 * @return a new LITTLE_ENDIAN ByteBuffer of size = bufferSize
 */
public static ByteBuffer allocateByteBuffer(final int bufferSize){
    final ByteBuffer buffer = ByteBuffer.allocate(bufferSize);
    buffer.order(ByteOrder.LITTLE_ENDIAN);
    return buffer;
}

/**
 * Wraps a byte array in a buffer whose byte order is LITTLE_ENDIAN.
 *
 * @param inputBytes the array that will back the buffer
 * @return a LITTLE_ENDIAN ByteBuffer that is created by wrapping {@code inputBytes}
 */
public static ByteBuffer wrap(final byte[] inputBytes){
    final ByteBuffer wrapped = ByteBuffer.wrap(inputBytes);
    wrapped.order(ByteOrder.LITTLE_ENDIAN);
    return wrapped;
}

/**
 * Slices a buffer and sets the byte order of the slice to LITTLE_ENDIAN.
 *
 * @param inputBuffer the buffer to slice
 * @return a LITTLE_ENDIAN ByteBuffer that is created by {@code inputBuffer.slice()}
 */
public static ByteBuffer slice(final ByteBuffer inputBuffer){
    final ByteBuffer sliced = inputBuffer.slice();
    sliced.order(ByteOrder.LITTLE_ENDIAN);
    return sliced;
}
}
56 changes: 0 additions & 56 deletions src/main/java/htsjdk/samtools/cram/compression/rans/Utils.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package htsjdk.samtools.cram.compression.rans;

import htsjdk.samtools.cram.CRAMException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

final public class Utils {

Expand Down Expand Up @@ -54,32 +52,6 @@ public static long RANSDecodeRenormalizeNx16(final long r, final ByteBuffer byte
return ret;
}

public static void writeUint7(final int i, final ByteBuffer cp) {
int s = 0;
int X = i;
do {
s += 7;
X >>= 7;
} while (X > 0);
do {
s -= 7;
//writeByte
final int s_ = (s > 0) ? 1 : 0;
cp.put((byte) (((i >> s) & 0x7f) + (s_ << 7)));
} while (s > 0);
}

public static int readUint7(final ByteBuffer cp) {
int i = 0;
int c;
do {
//read byte
c = cp.get();
i = (i << 7) | (c & 0x7f);
} while ((c & 0x80) != 0);
return i;
}

public static void normaliseFrequenciesOrder0(final int[] F, final int bits) {
// Returns an array of normalised Frequencies,
// such that the frequencies add up to 1<<bits.
Expand Down Expand Up @@ -194,32 +166,4 @@ public static void normaliseFrequenciesOrder1Shift(final int[][] F, final int sh
}
}

public static ByteBuffer allocateOutputBuffer(final int inSize) {
// This calculation is identical to the one in samtools rANS_static.c
// Presumably the frequency table (always big enough for order 1) = 257*257,
// then * 3 for each entry (byte->symbol, 2 bytes -> scaled frequency),
// + 9 for the header (order byte, and 2 int lengths for compressed/uncompressed lengths).
final int compressedSize = (int) (inSize + 257 * 257 * 3 + 9);
final ByteBuffer outputBuffer = ByteBuffer.allocate(compressedSize).order(ByteOrder.LITTLE_ENDIAN);
if (outputBuffer.remaining() < compressedSize) {
throw new CRAMException("Failed to allocate sufficient buffer size for RANS coder.");
}
return outputBuffer;
}

// returns a new LITTLE_ENDIAN ByteBuffer of size = bufferSize
public static ByteBuffer allocateByteBuffer(final int bufferSize){
return ByteBuffer.allocate(bufferSize).order(ByteOrder.LITTLE_ENDIAN);
}

// returns a LITTLE_ENDIAN ByteBuffer that is created by wrapping a byte[]
public static ByteBuffer wrap(final byte[] inputBytes){
return ByteBuffer.wrap(inputBytes).order(ByteOrder.LITTLE_ENDIAN);
}

// returns a LITTLE_ENDIAN ByteBuffer that is created by inputBuffer.slice()
public static ByteBuffer slice(final ByteBuffer inputBuffer){
return inputBuffer.slice().order(ByteOrder.LITTLE_ENDIAN);
}

}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package htsjdk.samtools.cram.compression.rans.rans4x8;

import htsjdk.samtools.cram.CRAMException;
import htsjdk.samtools.cram.compression.CompressionUtils;
import htsjdk.samtools.cram.compression.rans.ArithmeticDecoder;
import htsjdk.samtools.cram.compression.rans.Constants;
import htsjdk.samtools.cram.compression.rans.RANSDecode;
Expand All @@ -15,7 +16,7 @@
public class RANS4x8Decode extends RANSDecode {

private static final int RAW_BYTE_LENGTH = 4;
private static final ByteBuffer EMPTY_BUFFER = Utils.allocateByteBuffer(0);
private static final ByteBuffer EMPTY_BUFFER = CompressionUtils.allocateByteBuffer(0);

// This method assumes that inBuffer is already rewound.
// It uncompresses the data in the inBuffer, leaving it consumed.
Expand All @@ -39,7 +40,7 @@ public ByteBuffer uncompress(final ByteBuffer inBuffer) {

// uncompressed bytes length
final int outSize = inBuffer.getInt();
final ByteBuffer outBuffer = Utils.allocateByteBuffer(outSize);
final ByteBuffer outBuffer = CompressionUtils.allocateByteBuffer(outSize);
initializeRANSDecoder();
switch (order) {
case ZERO:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package htsjdk.samtools.cram.compression.rans.rans4x8;

import htsjdk.samtools.cram.CRAMException;
import htsjdk.samtools.cram.compression.CompressionUtils;
import htsjdk.samtools.cram.compression.rans.Constants;
import htsjdk.samtools.cram.compression.rans.RANSEncode;
import htsjdk.samtools.cram.compression.rans.RANSEncodingSymbol;
Expand All @@ -15,7 +16,7 @@ public class RANS4x8Encode extends RANSEncode<RANS4x8Params> {
// streams smaller than this value don't have sufficient symbol context for ORDER-1 encoding,
// so always use ORDER-0
private static final int MINIMUM_ORDER_1_SIZE = 4;
private static final ByteBuffer EMPTY_BUFFER = Utils.allocateByteBuffer(0);
private static final ByteBuffer EMPTY_BUFFER = CompressionUtils.allocateByteBuffer(0);

// This method assumes that inBuffer is already rewound.
// It compresses the data in the inBuffer, leaving it consumed.
Expand Down Expand Up @@ -44,7 +45,7 @@ public ByteBuffer compress(final ByteBuffer inBuffer, final RANS4x8Params params

private ByteBuffer compressOrder0Way4(final ByteBuffer inBuffer) {
final int inputSize = inBuffer.remaining();
final ByteBuffer outBuffer = Utils.allocateOutputBuffer(inputSize);
final ByteBuffer outBuffer = CompressionUtils.allocateOutputBuffer(inputSize);

// move the output buffer ahead to the start of the frequency table (we'll come back and
// write the output stream prefix at the end of this method)
Expand All @@ -55,7 +56,7 @@ private ByteBuffer compressOrder0Way4(final ByteBuffer inBuffer) {

// using the normalised frequencies, set the RANSEncodingSymbols
buildSymsOrder0(normalizedFreq);
final ByteBuffer cp = Utils.slice(outBuffer);
final ByteBuffer cp = CompressionUtils.slice(outBuffer);

// write Frequency table
final int frequencyTableSize = writeFrequenciesOrder0(cp, normalizedFreq);
Expand All @@ -65,7 +66,7 @@ private ByteBuffer compressOrder0Way4(final ByteBuffer inBuffer) {
final RANSEncodingSymbol[] syms = getEncodingSymbols()[0];
final int in_size = inBuffer.remaining();
long rans0, rans1, rans2, rans3;
final ByteBuffer ptr = Utils.slice(cp);
final ByteBuffer ptr = CompressionUtils.slice(cp);
rans0 = Constants.RANS_4x8_LOWER_BOUND;
rans1 = Constants.RANS_4x8_LOWER_BOUND;
rans2 = Constants.RANS_4x8_LOWER_BOUND;
Expand Down Expand Up @@ -112,7 +113,7 @@ private ByteBuffer compressOrder0Way4(final ByteBuffer inBuffer) {

private ByteBuffer compressOrder1Way4(final ByteBuffer inBuffer) {
final int inSize = inBuffer.remaining();
final ByteBuffer outBuffer = Utils.allocateOutputBuffer(inSize);
final ByteBuffer outBuffer = CompressionUtils.allocateOutputBuffer(inSize);

// move to start of frequency
outBuffer.position(Constants.RANS_4x8_PREFIX_BYTE_LENGTH);
Expand All @@ -123,7 +124,7 @@ private ByteBuffer compressOrder1Way4(final ByteBuffer inBuffer) {
// using the normalised frequencies, set the RANSEncodingSymbols
buildSymsOrder1(normalizedFreq);

final ByteBuffer cp = Utils.slice(outBuffer);
final ByteBuffer cp = CompressionUtils.slice(outBuffer);
final int frequencyTableSize = writeFrequenciesOrder1(cp, normalizedFreq);
inBuffer.rewind();
final int in_size = inBuffer.remaining();
Expand Down Expand Up @@ -156,7 +157,7 @@ private ByteBuffer compressOrder1Way4(final ByteBuffer inBuffer) {
byte l3 = inBuffer.get(in_size - 1);

// Slicing is needed for buffer reversing later
final ByteBuffer ptr = Utils.slice(cp);
final ByteBuffer ptr = CompressionUtils.slice(cp);
final RANSEncodingSymbol[][] syms = getEncodingSymbols();
for (i3 = in_size - 2; i3 > 4 * isz4 - 2 && i3 >= 0; i3--) {
final byte c3 = inBuffer.get(i3);
Expand Down
Loading

0 comments on commit 6ffdb71

Please sign in to comment.