Skip to content

Commit

Permalink
Optimize StreamUtil by adding batching support (#543)
Browse files Browse the repository at this point in the history
* Optimize StreamUtil by adding batching support

* Fix javadoc

* Fix javadoc

---------

Co-authored-by: Karthik Ramgopal <[email protected]>
  • Loading branch information
karthikrg and li-kramgopa authored Jan 22, 2024
1 parent 775b5fb commit 8b490a1
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 8 deletions.
1 change: 0 additions & 1 deletion avro-builder/builder-spi/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ dependencies {
implementation "org.apache.logging.log4j:log4j-api:2.17.1"
implementation "commons-io:commons-io:2.11.0"
implementation "jakarta.json:jakarta.json-api:2.0.1"
implementation "com.pivovarit:parallel-collectors:2.5.0"

testImplementation "org.apache.avro:avro:1.9.2"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@

package com.linkedin.avroutil1.builder.util;

import com.pivovarit.collectors.ParallelCollectors;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Semaphore;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;


Expand All @@ -37,8 +40,8 @@ private StreamUtil() {
/**
 * A convenience {@link Collector} used for executing parallel computations on a custom {@link Executor}
 * and returning a {@link Stream} of the results.
 *
 * <p>This overload delegates to {@link #toParallelStream(Function, int, int)} with a batch size of 1.
 * For the parallelism of 1, the stream is executed by the calling thread.</p>
 *
 * @param mapper a transformation to be performed in parallel
 * @param parallelism the max parallelism level
 * @param <T> the type of the collected elements
 * @param <R> the result returned by {@code mapper}
 *
 * @return a {@code Collector} which collects all processed elements into a {@code Stream} in parallel.
 */
public static <T, R> Collector<T, ?, Stream<R>> toParallelStream(Function<T, R> mapper, int parallelism) {
  // Single delegation point: the stale ParallelCollectors-based return is gone, so there is no
  // unreachable statement and no reference to the removed parallel-collectors dependency.
  return toParallelStream(mapper, parallelism, 1);
}

/**
 * A convenience {@link Collector} used for executing parallel computations on a custom {@link Executor}
 * and returning a {@link Stream} of the results.
 *
 * <p>For the parallelism of 1 or if the size of the elements is &lt;= batchSize, the stream is executed by the
 * calling thread. Otherwise the collected elements are split into consecutive batches of at most
 * {@code batchSize} elements; every batch is submitted up front to a {@code LimitingExecutor} created with
 * {@code parallelism}, mapped eagerly on the worker that runs it, and the per-batch results are then
 * concatenated in the order of the input elements.</p>
 *
 * <p>If {@code mapper} throws a {@link RuntimeException} or an {@link Error} on a worker, that same throwable
 * is rethrown to the consumer of the returned stream (it is not left wrapped in a
 * {@code CompletionException}).</p>
 *
 * @param mapper a transformation to be performed in parallel
 * @param parallelism the max parallelism level
 * @param batchSize the size into which inputs should be batched before running the mapper.
 * @param <T> the type of the collected elements
 * @param <R> the result returned by {@code mapper}
 *
 * @return a {@code Collector} which collects all processed elements into a {@code Stream} in parallel.
 * @throws IllegalArgumentException if {@code parallelism} or {@code batchSize} is less than 1
 */
public static <T, R> Collector<T, ?, Stream<R>> toParallelStream(Function<T, R> mapper, int parallelism,
    int batchSize) {
  if (parallelism <= 0 || batchSize <= 0) {
    throw new IllegalArgumentException("Parallelism and batch size must be >= 1");
  }

  return Collectors.collectingAndThen(Collectors.toList(), list -> {
    if (list.isEmpty()) {
      return Stream.empty();
    }

    if (parallelism == 1 || list.size() <= batchSize) {
      return list.stream().map(mapper);
    }

    final Executor limitingExecutor = new LimitingExecutor(parallelism);
    // Index of the last batch; the last batch may be shorter than batchSize.
    final int batchCount = (list.size() - 1) / batchSize;
    return IntStream.rangeClosed(0, batchCount)
        .mapToObj(batch -> {
          int startIndex = batch * batchSize;
          int endIndex = (batch == batchCount) ? list.size() : (batch + 1) * batchSize;
          // Collect inside the task: a bare batch.map(mapper) is lazy and would defer the real
          // work to whichever thread later traverses the stream (i.e. the caller).
          return CompletableFuture.supplyAsync(
              () -> list.subList(startIndex, endIndex).stream().map(mapper).collect(Collectors.toList()),
              limitingExecutor);
        })
        // Materialize the futures first so that every batch is submitted before any join();
        // joining inside the same lazy pipeline would run the batches strictly one after another.
        .collect(Collectors.toList())
        .stream()
        .flatMap(future -> {
          try {
            return future.join().stream();
          } catch (java.util.concurrent.CompletionException e) {
            // Surface the mapper's own failure, as callers saw when the mapper ran on their thread.
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) {
              throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
              throw (Error) cause;
            }
            throw e;
          }
        });
  });
}

/**
 * An {@link Executor} that forwards commands to {@code WORK_EXECUTOR} while allowing at most
 * {@code maxParallelism} of the commands submitted through this instance to be in flight at once.
 *
 * <p>{@link #execute(Runnable)} blocks the submitting thread until a permit is available. The permit is
 * held for the whole lifetime of the command and released only when the command finishes (normally or
 * exceptionally), or immediately if the underlying executor refuses the command.</p>
 */
private static class LimitingExecutor implements Executor {

  private final Semaphore _limiter;

  private LimitingExecutor(int maxParallelism) {
    _limiter = new Semaphore(maxParallelism);
  }

  /**
   * Runs {@code command} on {@code WORK_EXECUTOR} once a permit has been acquired.
   *
   * @param command the task to run
   * @throws RuntimeException wrapping the {@link InterruptedException} if the calling thread is interrupted
   *                          while waiting for a permit; the thread's interrupt flag is restored first
   */
  @Override
  public void execute(Runnable command) {
    try {
      _limiter.acquire();
    } catch (InterruptedException e) {
      // No permit was obtained, so none must be released here; just restore the interrupt flag.
      Thread.currentThread().interrupt();
      throw new RuntimeException(e);
    }

    boolean handedOff = false;
    try {
      WORK_EXECUTOR.execute(() -> {
        try {
          command.run();
        } finally {
          // Release on completion (not on submission) so the permit bounds running commands.
          _limiter.release();
        }
      });
      handedOff = true;
    } finally {
      if (!handedOff) {
        // The executor rejected the command, so the wrapper will never run; return the permit.
        _limiter.release();
      }
    }
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ public static void main(String[] args) throws Exception {
}
OperationContext opContext = operationContextBuilder.buildOperationContext(opConfig);
long operationContextBuildEnd = System.currentTimeMillis();
LOGGER.info("Built operation context in {} millis.", operationContextBuildStart - operationContextBuildEnd);
LOGGER.info("Built operation context in {} millis.", operationContextBuildEnd - operationContextBuildStart);

BuilderPluginContext context = new BuilderPluginContext(opContext);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ public void createOperations(BuilderPluginContext context) {
}

private void generateCode(OperationContext opContext) {
LOGGER.info("Generating Avro Java bindings...");

// Make sure the output folder exists
File outputFolder = config.getOutputSpecificRecordClassesRoot();
if (!outputFolder.exists() && !outputFolder.mkdirs()) {
Expand Down Expand Up @@ -109,7 +111,7 @@ private void generateCode(OperationContext opContext) {
} catch (Exception e) {
throw new RuntimeException("failed to generate class for " + namedSchema.getFullName(), e);
}
}, 10)).collect(Collectors.toList());
}, 10, 10)).collect(Collectors.toList());
long genEnd = System.currentTimeMillis();
LOGGER.info("Generated {} java source files in {} millis", generatedClasses.size(), genEnd - genStart);

Expand Down Expand Up @@ -138,7 +140,7 @@ private void writeJavaFilesToDisk(Collection<JavaFile> javaFiles, Path outputFol
}

return 1;
}, 10)).reduce(0, Integer::sum);
}, 10, 10)).reduce(0, Integer::sum);

long writeEnd = System.currentTimeMillis();
LOGGER.info("Wrote out {} generated java source files under {} in {} millis", filesWritten, outputFolderPath,
Expand Down

0 comments on commit 8b490a1

Please sign in to comment.