diff --git a/src/main/java/io/lettuce/core/internal/AsyncConnectionProvider.java b/src/main/java/io/lettuce/core/internal/AsyncConnectionProvider.java index 4614405ade..02e5b36133 100644 --- a/src/main/java/io/lettuce/core/internal/AsyncConnectionProvider.java +++ b/src/main/java/io/lettuce/core/internal/AsyncConnectionProvider.java @@ -145,7 +145,9 @@ public CompletableFuture close() { for (K k : connections.keySet()) { Sync remove = connections.remove(k); if (remove != null) { - remove.doWithConnection(e -> futures.add(e.closeAsync())); + CompletionStage closeFuture = remove.future.thenAccept(AsyncCloseable::closeAsync); + // always synchronously add the future, made it immutably in Futures.allOf() + futures.add(closeFuture.toCompletableFuture()); } } @@ -218,7 +220,6 @@ static class Sync> { @SuppressWarnings("unchecked") public Sync(K key, F future) { - this.key = key; this.future = (F) future.whenComplete((connection, throwable) -> { diff --git a/src/test/java/io/lettuce/core/internal/AsyncConnectionProviderTest.java b/src/test/java/io/lettuce/core/internal/AsyncConnectionProviderTest.java new file mode 100644 index 0000000000..80e7f9ee10 --- /dev/null +++ b/src/test/java/io/lettuce/core/internal/AsyncConnectionProviderTest.java @@ -0,0 +1,99 @@ +package io.lettuce.core.internal; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class AsyncConnectionProviderTest { + + @Test + public void testFutureListLength() throws InterruptedException, ExecutionException, TimeoutException { + + CountDownLatch slowCreate = new CountDownLatch(1); + CountDownLatch slowShutdown = new CountDownLatch(1); + + // create a provider with a slow connection creation + AsyncConnectionProvider> provider = new AsyncConnectionProvider<>( + key -> { + return countDownFuture(slowCreate, new io.lettuce.core.api.AsyncCloseable() { + + @Override + public CompletableFuture closeAsync() { + return CompletableFuture.completedFuture(null); + } + + }); + }); + + // add slow shutdown connection first + SlowCloseFuture slowCloseFuture = new SlowCloseFuture(slowShutdown); + provider.register("slowShutdown", new io.lettuce.core.api.AsyncCloseable() { + + @Override + public CompletableFuture closeAsync() { + return slowCloseFuture; + } + + }); + + // add slow creation connection + CompletableFuture createFuture = provider.getConnection("slowCreate"); + + // close the connection. + CompletableFuture closeFuture = provider.close(); + + // the connection has not been created yet, so the close futures array always has 1 element + // we block the iterator on the slowCloseFuture + // then we count down the creation, the close future will be added to the list + slowCreate.countDown(); + + // the close future is added to the list, we unblock the iterator + slowShutdown.countDown(); + + // assert close future is completed, and no exceptions are thrown + closeFuture.get(10, TimeUnit.SECONDS); + Assert.assertTrue(createFuture.isDone()); + } + + private CompletableFuture countDownFuture(CountDownLatch countDownLatch, T value) { + return CompletableFuture.runAsync(() -> { + try { + countDownLatch.await(1, TimeUnit.SECONDS); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }).thenApply(v -> value); + } + + static class SlowCloseFuture extends CompletableFuture { + + private final CountDownLatch countDownLatch; + + SlowCloseFuture(CountDownLatch countDownLatch) { + this.countDownLatch = countDownLatch; + } + + @Override + public CompletableFuture toCompletableFuture() { + // we block the iterator on here + try { + countDownLatch.await(1, TimeUnit.SECONDS); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return super.toCompletableFuture(); + } + + @Override + public Void get() { + return null; + } + + } + +}