TaskExecutor should not fork unnecessarily #13472

Merged
merged 11 commits into from
Jul 4, 2024
85 changes: 41 additions & 44 deletions lucene/core/src/java/org/apache/lucene/search/TaskExecutor.java
@@ -30,27 +30,21 @@
import java.util.concurrent.FutureTask;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.ThreadInterruptedException;

/**
* Executor wrapper responsible for the execution of concurrent tasks. Used to parallelize search
* across segments as well as query rewrite in some cases. Exposes a single {@link
* #invokeAll(Collection)} method that takes a collection of {@link Callable}s and executes them
* concurrently/ Once all tasks are submitted to the executor, it blocks and wait for all tasks to
* be completed, and then returns a list with the obtained results. Ensures that the underlying
* executor is only used for top-level {@link #invokeAll(Collection)} calls, and not for potential
* {@link #invokeAll(Collection)} calls made from one of the tasks. This is to prevent deadlock with
* certain types of pool based executors (e.g. {@link java.util.concurrent.ThreadPoolExecutor}).
* concurrently. Once all but one task have been submitted to the executor, it tries to run as many
* tasks as possible on the calling thread, then waits for all tasks that have been executed in
* parallel on the executor to be completed and then returns a list with the obtained results.
*
* @lucene.experimental
*/
public final class TaskExecutor {
// a static thread local is ok as long as we use a counter, which accounts for multiple
// searchers holding a different TaskExecutor all backed by the same executor
private static final ThreadLocal<Integer> numberOfRunningTasksInCurrentThread =
ThreadLocal.withInitial(() -> 0);

private final Executor executor;

/**
@@ -84,26 +78,21 @@ public String toString() {
/**
* Holds all the sub-tasks that a certain operation gets split into as it gets parallelized and
* exposes the ability to invoke such tasks and wait for them all to complete their execution and
* provide their results. Ensures that each task does not get parallelized further: this is
* important to avoid a deadlock in situations where one executor thread waits on other executor
* threads to complete before it can progress. This happens in situations where for instance
* {@link Query#createWeight(IndexSearcher, ScoreMode, float)} is called as part of searching each
* slice, like {@link TopFieldCollector#populateScores(ScoreDoc[], IndexSearcher, Query)} does.
* Additionally, if one task throws an exception, all other tasks from the same group are
* cancelled, to avoid needless computation as their results would not be exposed anyways. Creates
* one {@link FutureTask} for each {@link Callable} provided
* provide their results. Additionally, if one task throws an exception, all other tasks from the
* same group are cancelled, to avoid needless computation as their results would not be exposed
* anyways. Creates one {@link FutureTask} for each {@link Callable} provided
*
* @param <T> the return type of all the callables
*/
private static final class TaskGroup<T> {
private final Collection<RunnableFuture<T>> futures;
private final List<RunnableFuture<T>> futures;

TaskGroup(Collection<Callable<T>> callables) {
List<RunnableFuture<T>> tasks = new ArrayList<>(callables.size());
for (Callable<T> callable : callables) {
tasks.add(createTask(callable));
}
this.futures = Collections.unmodifiableCollection(tasks);
this.futures = Collections.unmodifiableList(tasks);
}

RunnableFuture<T> createTask(Callable<T> callable) {
@@ -112,15 +101,10 @@ RunnableFuture<T> createTask(Callable<T> callable) {
() -> {
if (startedOrCancelled.compareAndSet(false, true)) {
try {
Integer counter = numberOfRunningTasksInCurrentThread.get();
numberOfRunningTasksInCurrentThread.set(counter + 1);
return callable.call();
} catch (Throwable t) {
cancelAll();
throw t;
} finally {
Integer counter = numberOfRunningTasksInCurrentThread.get();
numberOfRunningTasksInCurrentThread.set(counter - 1);
}
}
// task is cancelled hence it has no results to return. That's fine: they would be
@@ -144,32 +128,45 @@ public boolean cancel(boolean mayInterruptIfRunning) {
}

List<T> invokeAll(Executor executor) throws IOException {
boolean runOnCallerThread = numberOfRunningTasksInCurrentThread.get() > 0;
for (Runnable runnable : futures) {
if (runOnCallerThread) {
runnable.run();
} else {
executor.execute(runnable);
final int count = futures.size();
// taskId provides the first index of an un-executed task in #futures
final AtomicInteger taskId = new AtomicInteger(0);
// we fork execution count - 1 tasks to execute at least one task on the current thread to
// minimize needless forking and blocking of the current thread
if (count > 1) {
final Runnable work =
() -> {
int id = taskId.getAndIncrement();
if (id < count) {
futures.get(id).run();
}
};
for (int j = 0; j < count - 1; j++) {
executor.execute(work);
}
}
// try to execute as many tasks as possible on the current thread to minimize context
// switching in case of long running concurrent
// tasks as well as dead-locking if the current thread is part of #executor for executors that
// have limited or no parallelism
int id;
while ((id = taskId.getAndIncrement()) < count) {
futures.get(id).run();
if (id >= count - 1) {
// save redundant CAS in case this was the last task
break;
}
}
Throwable exc = null;
List<T> results = new ArrayList<>(futures.size());
for (Future<T> future : futures) {
List<T> results = new ArrayList<>(count);
for (int i = 0; i < count; i++) {
Future<T> future = futures.get(i);
try {
results.add(future.get());
} catch (InterruptedException e) {
var newException = new ThreadInterruptedException(e);
if (exc == null) {
exc = newException;
} else {
exc.addSuppressed(newException);
}
exc = IOUtils.useOrSuppress(exc, new ThreadInterruptedException(e));
} catch (ExecutionException e) {
if (exc == null) {
exc = e.getCause();
} else {
exc.addSuppressed(e.getCause());
}
exc = IOUtils.useOrSuppress(exc, e.getCause());
}
}
assert assertAllFuturesCompleted() : "Some tasks are still running?";
@@ -286,7 +286,7 @@ protected LeafSlice[] slices(List<LeafReaderContext> leaves) {
}
};
searcher.search(new MatchAllDocsQuery(), 10);
assertEquals(leaves.size(), numExecutions.get());
assertEquals(leaves.size(), numExecutions.get() + 1);
Contributor

Minor nit: I think it would be more readable to adjust the expected result than to increment the actual side of the assert:

 assertEquals(leaves.size() - 1, numExecutions.get());

Contributor

Also, this test method was renamed when we introduced unconditional offloading, and it is now called testSlicesAllOffloadedToTheExecutor. Given that with this change we no longer offload all slices to the executor, we should probably rename the test method accordingly.

Member Author

Renamed and moved the 1 to the other side :)
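
For context on the `leaves.size() - 1` expectation, here is a minimal sketch, assuming a simplified stand-in for `TaskGroup#invokeAll`, of how forking `count - 1` work items and draining whatever remains on the caller thread results in one fewer executor submission than there are tasks. `ForkCountSketch`, its `invokeAll` helper and `submissions` are hypothetical names used only for illustration and are not Lucene API; the counting executor runs each command inline so the count is deterministic.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.Executor;
import java.util.concurrent.FutureTask;
import java.util.concurrent.atomic.AtomicInteger;

class ForkCountSketch {

  // Hypothetical helper (not the Lucene API): forks count - 1 work items,
  // then runs any still-unclaimed tasks on the calling thread.
  static <T> List<T> invokeAll(List<Callable<T>> callables, Executor executor) throws Exception {
    final int count = callables.size();
    final List<FutureTask<T>> futures = new ArrayList<>(count);
    for (Callable<T> callable : callables) {
      futures.add(new FutureTask<>(callable));
    }
    // index of the next task that nobody has claimed yet
    final AtomicInteger taskId = new AtomicInteger(0);
    if (count > 1) {
      final Runnable work =
          () -> {
            int id = taskId.getAndIncrement();
            if (id < count) {
              futures.get(id).run();
            }
          };
      for (int j = 0; j < count - 1; j++) {
        executor.execute(work);
      }
    }
    // the caller claims and runs whatever is left
    int id;
    while ((id = taskId.getAndIncrement()) < count) {
      futures.get(id).run();
    }
    final List<T> results = new ArrayList<>(count);
    for (FutureTask<T> future : futures) {
      results.add(future.get());
    }
    return results;
  }

  // Counts how many commands the executor receives for numTasks tasks.
  static int submissions(int numTasks) throws Exception {
    final AtomicInteger submitted = new AtomicInteger();
    // assumption: an inline executor, so the run is single-threaded and deterministic
    final Executor counting =
        command -> {
          submitted.incrementAndGet();
          command.run();
        };
    final List<Callable<Integer>> callables = new ArrayList<>();
    for (int i = 0; i < numTasks; i++) {
      final int value = i;
      callables.add(() -> value);
    }
    invokeAll(callables, counting);
    return submitted.get();
  }

  public static void main(String[] args) throws Exception {
    System.out.println(submissions(100)); // prints 99
    System.out.println(submissions(1)); // prints 0
  }
}
```

With a single task nothing is handed to the executor at all, which is the "should not fork unnecessarily" case from the PR title.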

}

public void testNullExecutorNonNullTaskExecutor() {
@@ -234,11 +234,8 @@ public void testInvokeAllDoesNotLeaveTasksBehind() {
TaskExecutor taskExecutor =
new TaskExecutor(
command -> {
executorService.execute(
() -> {
tasksStarted.incrementAndGet();
command.run();
});
tasksStarted.incrementAndGet();
command.run();
});
AtomicInteger tasksExecuted = new AtomicInteger(0);
List<Callable<Void>> callables = new ArrayList<>();
@@ -258,7 +255,8 @@ public void testInvokeAllDoesNotLeaveTasksBehind() {
expectThrows(RuntimeException.class, () -> taskExecutor.invokeAll(callables));
assertEquals(1, tasksExecuted.get());
// the callables are technically all run, but the cancelled ones will be no-op
assertEquals(100, tasksStarted.get());
// add one for the task that gets executed on the current thread
assertEquals(100, tasksStarted.get() + 1);
Contributor

Why did this need adapting? Is that expected?

Member Author

Yes, because N - 1 tasks are now started on the executor and 1 on the caller thread, so we need to add or subtract one. I found it easiest to read this way.

Contributor

This is perhaps subjective but I think we should be adapting the expectation as opposed to tweaking the actual value.

Additionally, I think that we should check that the task that gets executed by the caller thread is skipped when the first task throws an exception. Can you add that to the test?

Member Author

Adjusted the expectation. I also changed the test a little so that it simply fails if any of the additional tasks run, no matter on which thread; they should obviously all be skipped. Together with counting 99 tasks on the executor, I believe that tests exactly what you are looking for :)
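
The behaviour this thread settles on can be sketched as follows, under stated assumptions: the first task throws, a shared flag marks the group as cancelled, and every later task (including the one claimed by the caller thread) becomes a no-op. `CancelOnFailureSketch` and its `run` method are hypothetical names, and the single `AtomicBoolean` is a simplification of the per-task `startedOrCancelled` flag plus `cancelAll()` shown in the diff, not the actual Lucene implementation. The executor runs commands inline so the counts are deterministic.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.FutureTask;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

class CancelOnFailureSketch {

  /** Returns {work items handed to the executor, task bodies that actually ran}. */
  static int[] run(int numTasks) {
    // simplification: one group-wide flag instead of per-task flags plus cancelAll()
    final AtomicBoolean cancelled = new AtomicBoolean(false);
    final AtomicInteger submitted = new AtomicInteger();
    final AtomicInteger bodiesRun = new AtomicInteger();

    final List<FutureTask<Void>> futures = new ArrayList<>(numTasks);
    for (int i = 0; i < numTasks; i++) {
      final boolean throwing = i == 0; // only the first task fails
      futures.add(
          new FutureTask<Void>(
              () -> {
                if (cancelled.get()) {
                  return null; // group already failed: skip the body
                }
                bodiesRun.incrementAndGet();
                if (throwing) {
                  cancelled.set(true);
                  throw new RuntimeException("task 0 failed");
                }
                return null;
              }));
    }

    // assumption: inline executor that only counts what it is given
    final Executor counting =
        command -> {
          submitted.incrementAndGet();
          command.run();
        };

    final AtomicInteger taskId = new AtomicInteger(0);
    final Runnable work =
        () -> {
          int id = taskId.getAndIncrement();
          if (id < numTasks) {
            futures.get(id).run();
          }
        };
    for (int j = 0; j < numTasks - 1; j++) {
      counting.execute(work);
    }
    // the caller thread claims the remaining task, which is skipped as well
    int id;
    while ((id = taskId.getAndIncrement()) < numTasks) {
      futures.get(id).run();
    }
    return new int[] {submitted.get(), bodiesRun.get()};
  }

  public static void main(String[] args) {
    int[] counts = run(100);
    System.out.println(counts[0] + " submitted, " + counts[1] + " executed");
    // prints "99 submitted, 1 executed"
  }
}
```

With 100 tasks, 99 work items reach the executor and exactly one task body runs, which mirrors the two assertions discussed in this thread.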

}

/**
@@ -308,7 +306,7 @@ public void testInvokeAllCatchesMultipleExceptions() {
}

public void testCancelTasksOnException() {
TaskExecutor taskExecutor = new TaskExecutor(executorService);
TaskExecutor taskExecutor = new TaskExecutor(Runnable::run);
final int numTasks = random().nextInt(10, 50);
final int throwingTask = random().nextInt(numTasks);
boolean error = random().nextBoolean();