Skip to content

Commit

Permalink
Simplified branchless version
Browse files Browse the repository at this point in the history
  • Loading branch information
franz1981 committed Jan 5, 2025
1 parent 4d9e4cd commit f485dc2
Showing 1 changed file with 17 additions and 78 deletions.
95 changes: 17 additions & 78 deletions 2024/10/14/src/main/java/me/lemire/MyBenchmark.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,20 @@ public class MyBenchmark {
}

private static final byte[] silly_table3;
private static final short[] replacementsAndCompressionTables = new short[256 * 2];
private static final short[] replacementsAndLengthsTables = new short[256 * 2];

static {
for (int i = 0; i < 256; i++) {
replacementsAndCompressionTables[i * 2] = (short) ((i & 0xFF) << 8);
replacementsAndCompressionTables[(i * 2) + 1] = (short) 0xFF00;
replacementsAndLengthsTables[i * 2] = (short) (i & 0xFF);
replacementsAndLengthsTables[(i * 2) + 1] = 1;
}
// we need to override the default values
replacementsAndCompressionTables['\\' * 2] = (short) (0x005c | (('\\' & 0xFF) << 8));
replacementsAndCompressionTables[('\\' * 2) + 1] = (short) 0xFFFF;
replacementsAndCompressionTables['\n' * 2] = (short) (0x005c | (('n' & 0xFF) << 8));
replacementsAndCompressionTables[('\n' * 2) + 1] = (short) 0xFFFF;
replacementsAndCompressionTables['\t' * 2] = (short) (0x005c | (('t' & 0xFF) << 8));
replacementsAndCompressionTables[('\t' * 2) + 1] = (short) 0xFFFF;
replacementsAndLengthsTables['\\' * 2] = (short) (0x005c | (('\\' & 0xFF) << 8));
replacementsAndLengthsTables[('\\' * 2) + 1] = 2;
replacementsAndLengthsTables['\n' * 2] = (short) (0x005c | (('n' & 0xFF) << 8));
replacementsAndLengthsTables[('\n' * 2) + 1] = 2;
replacementsAndLengthsTables['\t' * 2] = (short) (0x005c | (('t' & 0xFF) << 8));
replacementsAndLengthsTables[('\t' * 2) + 1] = 2;

silly_table3 = new byte[256];
silly_table3['\\'] = '\\';
Expand Down Expand Up @@ -253,82 +253,21 @@ private static byte[] latinCharsExcluding(byte[] specialChars) {
return nonSpecialLatinChars;
}

private static final VarHandle LONG_WRITER = MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.LITTLE_ENDIAN);
private static final VarHandle SHORT_WRITER = MethodHandles.byteArrayViewVarHandle(short[].class, ByteOrder.LITTLE_ENDIAN);

/**
 * Runs the escaper once on a small fixed sample that contains ordinary
 * bytes as well as backslash, line-feed and tab bytes. The returned length
 * is not used.
 */
public static void main(String[] args) {
    final byte[] sample = {'a', '\\', 'b', '\n', 'c', '\t', 'd', 'e', '\\', '\n', 'a'};
    // Two output bytes per input byte plus 16 bytes of extra room.
    final byte[] escaped = new byte[(sample.length * 2) + 16];
    replaceBackslashRawCompressedTable3(sample, escaped);
}

/**
 * Writes an escaped copy of {@code original} into {@code newArray}, working on
 * eight input bytes per iteration and on the remaining (fewer than eight) bytes
 * in a final step.
 *
 * Each group of four input bytes is handed to {@code writeReplacementPart}, which
 * stores one 8-byte value at the given output offset and returns the combined
 * compression mask. Every set bit in that mask is one output bit, so
 * {@code bitCount / 8} is the number of output bytes the group produced.
 *
 * NOTE(review): because every store is a full 8-byte write, {@code newArray}
 * needs spare room past the returned length; the caller in {@code main}
 * allocates {@code 2 * length + 16} bytes - confirm other callers do likewise.
 *
 * @param original bytes to escape
 * @param newArray destination buffer
 * @return the number of bytes of escaped output
 */
public static int replaceBackslashRawCompressedTable3(byte[] original, byte[] newArray) {
int newArrayLength = 0;
int eightCharsBatches = original.length / 8;
for (int b = 0; b < eightCharsBatches; b++) {
int i = b * 8;
// read eight input bytes at once as a single long
long readChars = (long) LONG_WRITER.get(original, i);
// lower four bytes: the mask's bit count divided by 8 is the number of bytes emitted
int digits = Long.bitCount(writeReplacementPart(newArray, newArrayLength, (int) (readChars & 0xFFFF_FFFFL))) / 8;
// upper four bytes are written right after the output of the lower four
digits += (Long.bitCount(writeReplacementPart(newArray, newArrayLength + digits, (int)((readChars >>> 32) & 0xFFFFFFFFL))) / 8);
newArrayLength += digits;
}
int tail = original.length % 8;
if (tail > 0) {
int i = eightCharsBatches * 8;
// the unused high bytes of the value returned by readTailChars are zero
long readChars = readTailChars(original, i, tail);
// The table entry for byte 0 has mask 0xFF00, so the zero padding would be
// counted as one output byte each. Only the low byte of every 16-bit lane is
// counted here (it is set only for two-byte replacements), which yields the
// number of extra bytes; the tail size itself is added separately.
int additionalCharsRight = Long.bitCount(writeReplacementPart(newArray, newArrayLength, (int) (readChars & 0xFFFF_FFFFL)) & 0x00FF_00FF_00FF_00FFL) / 8;
// at most four of the tail bytes belong to the lower half
int charsWrittenRight = Math.min(tail, 4) + additionalCharsRight;
int additionalCharsLeft = Long.bitCount(writeReplacementPart(newArray, newArrayLength + charsWrittenRight, (int)((readChars >>> 32) & 0xFFFFFFFFL)) & 0x00FF_00FF_00FF_00FFL) / 8;
newArrayLength += (tail + additionalCharsRight + additionalCharsLeft);
}
return newArrayLength;
}

/**
 * Looks up the replacement and the compression mask for each of the four bytes
 * packed in {@code readChars}, places them in consecutive 16-bit lanes, compresses
 * the replacements with the combined mask and stores the resulting long in
 * {@code newArray} at offset {@code newArrayLength}.
 *
 * @param newArray destination buffer; a full 8-byte value is stored into it
 * @param newArrayLength offset in {@code newArray} at which the value is stored
 * @param readChars four input bytes, the first one in the lowest 8 bits
 * @return the combined compression mask of the four bytes
 */
private static long writeReplacementPart(byte[] newArray, int newArrayLength, int readChars) {
    long packedReplacements = 0L;
    long selectionMask = 0L;
    for (int lane = 0, shift = 0; lane < 4; lane++, shift += 16) {
        // Each byte value owns two consecutive table slots: replacement, then mask.
        final int tableIndex = ((readChars >>> (lane * 8)) & 0xFF) * 2;
        packedReplacements |= Short.toUnsignedLong(replacementsAndCompressionTables[tableIndex]) << shift;
        selectionMask |= Short.toUnsignedLong(replacementsAndCompressionTables[tableIndex + 1]) << shift;
    }
    // Keep only the bits selected by the mask, packed towards the low end.
    LONG_WRITER.set(newArray, newArrayLength, Long.compress(packedReplacements, selectionMask));
    return selectionMask;
}

private static long readTailChars(byte[] original, int index, int count) {
assert count < 8;
long chars = 0;
int idx = index;
long ch = original[idx] & 0xFFL;
chars |= ch;
// TODO this could be implemented with batch operations
// 7 bytes = 4 + 2 + 1 or 6 bytes = 4 + 2 or 5 bytes = 4 + 1
if (count > 1) {
ch = original[idx + 1] & 0xFFL;
chars |= ch << 8;
if (count > 2) {
ch = original[idx + 2] & 0xFFL;
chars |= ch << 16;
if (count > 3) {
ch = original[idx + 3] & 0xFFL;
chars |= ch << 24;
if (count > 4) {
ch = original[idx + 4] & 0xFFL;
chars |= ch << 32;
if (count > 5) {
ch = original[idx + 5] & 0xFFL;
chars |= ch << 40;
if (count > 6) {
ch = original[idx + 6] & 0xFFL;;
chars |= ch << 48;
}
}
}
}
}
int outputOffset = 0;
for (byte b : original) {
int latin = b & 0xFF;
short replacement = replacementsAndLengthsTables[2 * latin];
SHORT_WRITER.set(newArray, outputOffset, replacement);
outputOffset += replacementsAndLengthsTables[(2 * latin) + 1];
}
return chars;
return outputOffset;
}

@Benchmark
Expand Down

0 comments on commit f485dc2

Please sign in to comment.