Skip to content

Commit

Permalink
sudhirtumati implementation - improve to reduce thread contention
Browse files Browse the repository at this point in the history
  • Loading branch information
sudhirtumati committed Jan 29, 2024
1 parent aacc608 commit 8ad72af
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 103 deletions.
2 changes: 1 addition & 1 deletion calculate_average_sudhirtumati.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@
# limitations under the License.
#

JAVA_OPTS="-XX:+UseZGC -XX:-TieredCompilation"
JAVA_OPTS="--enable-preview -XX:+UseZGC -XX:-TieredCompilation"
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_sudhirtumati
199 changes: 97 additions & 102 deletions src/main/java/dev/morling/onebrc/CalculateAverage_sudhirtumati.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,153 +15,136 @@
*/
package dev.morling.onebrc;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.Semaphore;

public class CalculateAverage_sudhirtumati {

private static final String FILE = "./measurements.txt";

private static final int bufferSize = 8192;
private static final byte SEMICOLON = (byte) ';';
private static final byte NEW_LINE = (byte) '\n';
private static final int THREAD_COUNT = Runtime.getRuntime().availableProcessors();

private final BlockingQueue<Runnable> queue = new LinkedBlockingQueue<>(100);

private final ThreadPoolExecutor executor = new ThreadPoolExecutor(THREAD_COUNT, THREAD_COUNT, 100L, TimeUnit.MILLISECONDS, queue);
private static final Semaphore PERMITS = new Semaphore(THREAD_COUNT);
private static final MeasurementAggregator globalAggregator = new MeasurementAggregator();
private static final Semaphore AGGREGATOR_PERMITS = new Semaphore(1);

public static void main(String[] args) throws IOException, InterruptedException {
CalculateAverage_sudhirtumati instance = new CalculateAverage_sudhirtumati();
instance.process();
System.out.println(MeasurementAggregator.getInstance().getResult());
instance.chunkProcess();
}

private void process() throws IOException, InterruptedException {
int bufferSize = 8192 * 128;
MeasurementAggregator aggregator = MeasurementAggregator.getInstance();
executor.prestartAllCoreThreads();
try (RandomAccessFile raf = new RandomAccessFile(FILE, "r");
FileChannel fc = raf.getChannel()) {
ByteBuffer buffer = ByteBuffer.allocate(bufferSize);
ByteBuffer leftoverBuf = ByteBuffer.allocate(0);
while (fc.read(buffer) > 0) {
buffer.flip();
ByteBuffer resultBuf = ByteBuffer.allocate(leftoverBuf.capacity() + buffer.limit()).put(leftoverBuf).put(buffer);
int endIndex = findIndexOfValidEnd(resultBuf);
ByteBuffer toProcessBuf = resultBuf.slice(0, endIndex);
queue.offer(new BufferChunkProcessor(deepCopy(toProcessBuf), aggregator), 10, TimeUnit.SECONDS);
resultBuf.position(0);
endIndex++;
leftoverBuf = resultBuf.slice(endIndex, resultBuf.limit() - endIndex);
buffer.clear();
}
}
executor.shutdown();
try {
if (!executor.awaitTermination(10, TimeUnit.SECONDS)) {
executor.shutdownNow();
}
}
catch (InterruptedException e) {
executor.shutdownNow();
private void chunkProcess() throws InterruptedException {
for (int i = 0; i < THREAD_COUNT; i++) {
PERMITS.acquire();
Thread t = new ChunkProcessingThread(i);
t.setName(STR."T\{i}");
t.start();
}
do {
Thread.sleep(100);
} while (PERMITS.availablePermits() != THREAD_COUNT);
System.out.println(globalAggregator.getResult());
}

private ByteBuffer deepCopy(ByteBuffer original) {
int pos = original.position();
int lim = original.limit();
try {
original.position(0).limit(original.capacity());
ByteBuffer copy = doDeepCopy(original);
copy.position(pos).limit(lim);
return copy;
}
finally {
original.position(pos).limit(lim);
}
}
static class ChunkProcessingThread extends Thread {

private ByteBuffer doDeepCopy(ByteBuffer original) {
int pos = original.position();
try {
ByteBuffer copy = ByteBuffer.allocate(original.remaining());
copy.put(original);
copy.order(original.order());
return copy.position(0);
}
finally {
original.position(pos);
}
}

int findIndexOfValidEnd(ByteBuffer buffer) {
int endIndex = -1;
int pos = buffer.limit() - 1;
while (endIndex == -1 && pos > -1) {
if ((char) buffer.get(pos) == '\n') {
endIndex = pos;
}
pos--;
}
return endIndex;
}

static final class BufferChunkProcessor implements Runnable {

private static final byte SEMICOLON = (byte) ';';
private static final byte NEW_LINE = (byte) '\n';
private final ByteBuffer buffer;
private final MeasurementAggregator measurementAggregator;
private int index;
private final MeasurementAggregator aggregator;

BufferChunkProcessor(ByteBuffer buffer, MeasurementAggregator measurementAggregator) {
this.buffer = buffer;
this.measurementAggregator = measurementAggregator;
ChunkProcessingThread(int index) {
this.index = index;
aggregator = new MeasurementAggregator();
}

@Override
public void run() {
int mStartMark = 0;
try (FileInputStream is = new FileInputStream(FILE);
FileChannel fc = is.getChannel()) {
ByteBuffer buffer = ByteBuffer.allocate(index == 0 ? bufferSize : bufferSize + 50);
fc.position(index == 0 ? 0 : (((long) index * bufferSize) - 50));
while (fc.read(buffer) != -1) {
buffer.flip();
if (index != 0 && fc.position() != bufferSize) {
seekStartPos(buffer);
}
processBuffer(buffer);
index += THREAD_COUNT;
fc.position(((long) index * bufferSize) - 50L);
if (buffer.capacity() == 8192) {
buffer = ByteBuffer.allocate(bufferSize + 50);
}
buffer.position(0);
}
}
catch (IOException e) {
throw new RuntimeException(e);
}
try {
AGGREGATOR_PERMITS.acquire();
globalAggregator.process(aggregator);
AGGREGATOR_PERMITS.release();
}
catch (InterruptedException e) {
throw new RuntimeException(e);
}
PERMITS.release();
}

private void processBuffer(ByteBuffer buffer) throws IOException {
int mStartMark = buffer.position();
int tStartMark = -1;
buffer.position(0);
int count = 0;
int count = buffer.position();
do {
byte b = buffer.get(count);
if (b == SEMICOLON) {
tStartMark = count;
}
else if (b == NEW_LINE || count == buffer.limit() - 1) {
else if (b == NEW_LINE) {
byte[] locArr = new byte[tStartMark - mStartMark];
byte[] tempArr = new byte[count - tStartMark];
buffer.get(mStartMark, locArr);
buffer.get(mStartMark + locArr.length + 1, tempArr);
measurementAggregator.process(locArr, tempArr);
aggregator.process(locArr, tempArr);
mStartMark = count + 1;
}
count++;
} while (count < buffer.limit());
}

private void seekStartPos(ByteBuffer buffer) {
int i = buffer.limit() > 50 ? 49 : buffer.limit() - 2;
for (; i >= 0; i--) {
if (buffer.get(i) == NEW_LINE) {
buffer.position(i + 1);
break;
}
}
}
}

static final class MeasurementAggregator {
private static final MeasurementAggregator instance = new MeasurementAggregator();
private static final long MAX_VALUE_DIVIDE_10 = Long.MAX_VALUE / 10;
private final Map<String, Measurement> store = new ConcurrentHashMap<>();
private final Map<String, Measurement> store = new HashMap<>();

private MeasurementAggregator() {
}

public static MeasurementAggregator getInstance() {
return instance;
public void process(MeasurementAggregator other) {
other.store.forEach((k, v) -> {
Measurement m = store.get(k);
if (m == null) {
m = new Measurement();
store.put(k, m);
}
m.process(v);
});
}

public void process(byte[] location, byte[] temperature) {
public void process(byte[] location, byte[] temperature) throws IOException {
String loc = new String(location);
Measurement measurement = store.get(loc);
if (measurement == null) {
Expand Down Expand Up @@ -285,16 +268,28 @@ public void process(double value) {
count++;
}

public void process(Measurement other) {
if (other.min < min) {
this.min = other.min;
}
if (other.max > max) {
this.max = other.max;
}
this.sum += other.sum;
this.count += other.count;
}

public String toString() {
ResultRow result = new ResultRow(min, sum, count, max);
return result.toString();
}

}

private record ResultRow(double min, double sum, double count, double max) {

public String toString() {
return round(min) + "/" + round((Math.round(sum * 10.0) / 10.0) / count) + "/" + round(max);
return STR."\{round(min)}/\{round((Math.round(sum * 10.0) / 10.0) / count)}/\{round(max)}";
}

private double round(double value) {
Expand Down

0 comments on commit 8ad72af

Please sign in to comment.