Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improved 2nd and final submission #683

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions calculate_average_yourwass.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,8 @@
# source "$HOME/.sdkman/bin/sdkman-init.sh"
# sdk use java 21.0.1-graal 1>&2

JAVA_OPTS="--enable-preview --enable-native-access=ALL-UNNAMED --add-modules jdk.incubator.vector"
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_yourwass
JAVA_OPTS="-Xlog:all=off -Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0 --enable-preview --enable-native-access=ALL-UNNAMED --add-modules jdk.incubator.vector"

eval "exec 3< <({ java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_yourwass; })"
read <&3 result
echo -e "$result"
151 changes: 77 additions & 74 deletions src/main/java/dev/morling/onebrc/CalculateAverage_yourwass.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
package dev.morling.onebrc;

import java.util.TreeMap;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
Expand All @@ -31,18 +33,15 @@
import sun.misc.Unsafe;

public class CalculateAverage_yourwass {

static final class Record {
public String city;
public long cityAddr;
public long cityLength;
public int min;
public int max;
public int count;
public long sum;
private long cityAddr;
private long cityLength;
private int min;
private int max;
private int count;
private long sum;

Record(final long cityAddr, final long cityLength) {
this.city = null;
this.cityAddr = cityAddr;
this.cityLength = cityLength;
this.min = 1000;
Expand All @@ -62,6 +61,8 @@ private Record merge(Record r) {
}
}

private final static Lock _mutex = new ReentrantLock(true);
private final static TreeMap<String, Record> aggregateResults = new TreeMap<>();
private static short lookupDecimal[];
private static byte lookupFraction[];
private static byte lookupDotPositive[];
Expand All @@ -70,6 +71,8 @@ private Record merge(Record r) {
private static final VectorSpecies<Byte> SPECIES = ByteVector.SPECIES_PREFERRED;
private static final int MAXINDEX = (1 << 16) + 10000; // short hash + max allowed cities for collisions at the end :p
private static final String FILE = "measurements.txt";
private static long unsafeResults;
private static int RECORDSIZE = 36;
private static final Unsafe UNSAFE = getUnsafe();

private static Unsafe getUnsafe() {
Expand Down Expand Up @@ -113,11 +116,9 @@ public static void main(String[] args) throws IOException, Throwable {
}

// open file
final long fileSize, mmapAddr;
try (var fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
fileSize = fileChannel.size();
mmapAddr = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address();
}
final FileChannel fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ);
final long fileSize = fileChannel.size();
final long mmapAddr = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address();
// VAS: Virtual Address Space, as a MemorySegment up to and including the mmapped file.
// If the mmapped MemorySegment is used for Vector creation as is, then there are two problems:
// 1) fromMemorySegment takes an offset and not an address, so we have to do arithmetic
Expand All @@ -127,36 +128,24 @@ public static void main(String[] args) throws IOException, Throwable {
// XXX an out-of-bounds read is possible at the end of the file; this case is not handled here.
VAS = MemorySegment.ofAddress(0).reinterpret(mmapAddr + fileSize + SPECIES.length());

// start and wait for threads to finish
// allocate memory for results
final int nThreads = Runtime.getRuntime().availableProcessors();
unsafeResults = UNSAFE.allocateMemory(RECORDSIZE * MAXINDEX * nThreads);
UNSAFE.setMemory(unsafeResults, RECORDSIZE * MAXINDEX * nThreads, (byte) 0);

// start and wait for threads to finish
Thread[] threadList = new Thread[nThreads];
final Record[][] results = new Record[nThreads][];
final long chunkSize = fileSize / nThreads;
for (int i = 0; i < nThreads; i++) {
final int threadIndex = i;
final long startAddr = mmapAddr + i * chunkSize;
final long endAddr = (i == nThreads - 1) ? mmapAddr + fileSize : mmapAddr + (i + 1) * chunkSize;
threadList[i] = new Thread(() -> results[threadIndex] = threadMain(threadIndex, startAddr, endAddr, nThreads));
threadList[i] = new Thread(() -> threadMain(threadIndex, startAddr, endAddr, nThreads));
threadList[i].start();
}
for (int i = 0; i < nThreads; i++)
threadList[i].join();

// aggregate results and sort
// TODO have to compare with concurrent-parallel stream structures:
// * concurrent hashtable that has to be sorted afterwards
// * concurrent skiplist that is sorted but has O(n) insert
// * ..other?
final TreeMap<String, Record> aggregateResults = new TreeMap<>();
for (int thread = 0; thread < nThreads; thread++) {
for (int index = 0; index < MAXINDEX; index++) {
Record record = results[thread][index];
if (record == null)
continue;
aggregateResults.compute(record.city, (k, v) -> (v == null) ? record : v.merge(record));
}
}

// prepare string and print
StringBuilder sb = new StringBuilder();
sb.append("{");
Expand All @@ -167,12 +156,13 @@ public static void main(String[] args) throws IOException, Throwable {
float max = record.max;
max /= 10.f;
double avg = Math.round((record.sum * 1.0) / record.count) / 10.;
sb.append(record.city).append("=").append(min).append("/").append(avg).append("/").append(max).append(", ");
sb.append(entry.getKey()).append("=").append(min).append("/").append(avg).append("/").append(max).append(", ");
}
int stringLength = sb.length();
sb.setCharAt(stringLength - 2, '}');
sb.setCharAt(stringLength - 1, '\n');
System.out.print(sb.toString());
System.out.close();
}

private static final boolean citiesDiffer(final long a, final long b, final long len) {
Expand All @@ -185,7 +175,7 @@ private static final boolean citiesDiffer(final long a, final long b, final long
return false;
}

private static Record[] threadMain(int id, long startAddr, long endAddr, long nThreads) {
private static void threadMain(int id, long startAddr, long endAddr, long nThreads) {
// snap to newlines
if (id != 0)
while (UNSAFE.getByte(startAddr++) != '\n')
Expand All @@ -194,23 +184,24 @@ private static Record[] threadMain(int id, long startAddr, long endAddr, long nT
while (UNSAFE.getByte(endAddr++) != '\n')
;

final long threadResults = unsafeResults + id * MAXINDEX * RECORDSIZE;
final Record[] results = new Record[MAXINDEX];
final long VECTORBYTESIZE = SPECIES.length();
final ByteOrder BYTEORDER = ByteOrder.nativeOrder();
final ByteVector delim = ByteVector.broadcast(SPECIES, ';');
long nextCityAddr = startAddr; // XXX from these three variables,
long cityAddr = nextCityAddr; // only two are necessary, but if one
long ptr = 0; // is eliminated, on my pc the benchmark gets worse..
while (nextCityAddr < endAddr) {
long cityAddr = startAddr;
long ptr = 0;
while (cityAddr < endAddr) {
// parse city
long mask = ByteVector.fromMemorySegment(SPECIES, VAS, nextCityAddr + ptr, BYTEORDER)
.compare(VectorOperators.EQ, delim).toLong();
if (mask == 0) {
ByteVector parsed = ByteVector.fromMemorySegment(SPECIES, VAS, cityAddr, BYTEORDER);
long mask = parsed.compare(VectorOperators.EQ, delim).toLong();
while (mask == 0) {
ptr += VECTORBYTESIZE;
continue;
mask = ByteVector.fromMemorySegment(SPECIES, VAS, cityAddr + ptr, BYTEORDER).compare(VectorOperators.EQ, delim).toLong();
}
final long cityLength = ptr + Long.numberOfTrailingZeros(mask);
final long tempAddr = cityAddr + cityLength + 1;
ptr = 0;

// compute hash table index
int index;
Expand All @@ -222,67 +213,79 @@ private static Record[] threadMain(int id, long startAddr, long endAddr, long nT
& 0xFFFF;
else
index = (UNSAFE.getByte(cityAddr) << 8) & 0xFF00;

// resolve collisions with linear probing
// use vector api here also, but only if city name fits in one vector length, for faster default case
Record record = results[index];
long record = threadResults + index * RECORDSIZE;
long recordCityLength = UNSAFE.getLong(record);
if (cityLength <= VECTORBYTESIZE) {
ByteVector parsed = ByteVector.fromMemorySegment(SPECIES, VAS, cityAddr, BYTEORDER);
while (record != null) {
if (cityLength == record.cityLength) {
long sameMask = ByteVector.fromMemorySegment(SPECIES, VAS, record.cityAddr, BYTEORDER)
while (recordCityLength > 0) {
if (cityLength == recordCityLength) {
long sameMask = ByteVector.fromMemorySegment(SPECIES, VAS, UNSAFE.getLong(record + 8), BYTEORDER)
.compare(VectorOperators.EQ, parsed).toLong();
if (Long.numberOfTrailingZeros(~sameMask) >= cityLength)
break;
}
record = results[++index];
index++;
record = threadResults + index * RECORDSIZE;
recordCityLength = UNSAFE.getLong(record);
}
}
else { // slower normal case for city names with length > VECTORBYTESIZE
while (record != null && (cityLength != record.cityLength || citiesDiffer(record.cityAddr, cityAddr, cityLength)))
record = results[++index];
while (recordCityLength > 0 && (cityLength != recordCityLength || citiesDiffer(UNSAFE.getLong(record + 8), cityAddr, cityLength))) {
index++;
record = threadResults + index * RECORDSIZE;
recordCityLength = UNSAFE.getLong(record);
}
}

// add record for new keys
// TODO have to avoid memory allocations on hot path
if (record == null) {
results[index] = new Record(cityAddr, cityLength);
record = results[index];
// add record for new key
if (recordCityLength == 0) {
UNSAFE.putLong(record, cityLength);
UNSAFE.putLong(record + 8, cityAddr);
UNSAFE.putInt(record + 16, 1000);
UNSAFE.putInt(record + 20, -1000);
}

// parse temp with lookup tables
int temp;
if (UNSAFE.getByte(tempAddr) == '-') {
temp = -lookupDecimal[UNSAFE.getShort(tempAddr + 1)] - lookupFraction[UNSAFE.getShort(tempAddr + 3)];
nextCityAddr = tempAddr + lookupDotNegative[UNSAFE.getShort(tempAddr + 3)];
cityAddr = tempAddr + lookupDotNegative[UNSAFE.getShort(tempAddr + 3)];
}
else {
temp = lookupDecimal[UNSAFE.getShort(tempAddr)] + lookupFraction[UNSAFE.getShort(tempAddr + 2)];
nextCityAddr = tempAddr + lookupDotPositive[UNSAFE.getShort(tempAddr + 2)];
cityAddr = tempAddr + lookupDotPositive[UNSAFE.getShort(tempAddr + 2)];
}
cityAddr = nextCityAddr;
ptr = 0;

// merge record
if (temp < record.min)
record.min = temp;
if (temp > record.max)
record.max = temp;
record.sum += temp;
record.count += 1;
// merge
if (temp < UNSAFE.getInt(record + 16))
UNSAFE.putInt(record + 16, temp);
if (temp > UNSAFE.getInt(record + 20))
UNSAFE.putInt(record + 20, temp);
UNSAFE.putLong(record + 24, UNSAFE.getLong(record + 24) + temp);
UNSAFE.putInt(record + 32, UNSAFE.getInt(record + 32) + 1);
}

// create strings from raw data
// TODO should avoid this copy
// and aggregate results onto TreeMap
int idx = 0;
byte b[] = new byte[100];
_mutex.lock();
for (int i = 0; i < MAXINDEX; i++) {
Record r = results[i];
if (r == null)
if (UNSAFE.getLong(threadResults + i * RECORDSIZE) == 0)
continue;
UNSAFE.copyMemory(null, r.cityAddr, b, Unsafe.ARRAY_BYTE_BASE_OFFSET, r.cityLength);
r.city = new String(b, 0, (int) r.cityLength, StandardCharsets.UTF_8);
final long recordAddress = threadResults + i * RECORDSIZE;

results[idx] = new Record(UNSAFE.getLong(recordAddress + 8), UNSAFE.getLong(recordAddress));
results[idx].min = UNSAFE.getInt(recordAddress + 16);
results[idx].max = UNSAFE.getInt(recordAddress + 20);
results[idx].sum = UNSAFE.getLong(recordAddress + 24);
results[idx].count = UNSAFE.getInt(recordAddress + 32);
UNSAFE.copyMemory(null, UNSAFE.getLong(recordAddress + 8), b, Unsafe.ARRAY_BYTE_BASE_OFFSET, UNSAFE.getLong(recordAddress));
final Record record = results[idx];
aggregateResults.compute(new String(b, 0, (int) results[idx].cityLength, StandardCharsets.UTF_8), (k, v) -> (v == null) ? record : v.merge(record));
idx++;
}
return results;
_mutex.unlock();
}

}
Loading