diff --git a/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java b/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java index 22ad65cdd0..bfcf3a7853 100644 --- a/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java +++ b/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java @@ -15,28 +15,13 @@ */ package org.springframework.data.redis.connection; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.Map.Entry; -import java.util.Objects; -import java.util.Random; -import java.util.Set; -import java.util.TreeMap; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import java.util.function.BiConsumer; import java.util.function.Function; import java.util.stream.Collectors; @@ -111,47 +96,46 @@ public ClusterCommandExecutor(ClusterTopologyProvider topologyProvider, ClusterN /** * Run {@link ClusterCommandCallback} on a random node. * - * @param clusterCommand must not be {@literal null}. + * @param commandCallback must not be {@literal null}. * @return never {@literal null}. */ - public NodeResult executeCommandOnArbitraryNode(ClusterCommandCallback clusterCommand) { + public NodeResult executeCommandOnArbitraryNode(ClusterCommandCallback commandCallback) { - Assert.notNull(clusterCommand, "ClusterCommandCallback must not be null"); + Assert.notNull(commandCallback, "ClusterCommandCallback must not be null"); List nodes = new ArrayList<>(getClusterTopology().getActiveNodes()); RedisClusterNode arbitraryNode = nodes.get(new Random().nextInt(nodes.size())); - return executeCommandOnSingleNode(clusterCommand, arbitraryNode); + return executeCommandOnSingleNode(commandCallback, arbitraryNode); } /** * Run {@link ClusterCommandCallback} on given {@link RedisClusterNode}. * - * @param clusterCommand must not be {@literal null}. + * @param commandCallback must not be {@literal null}. * @param node must not be {@literal null}. * @return the {@link NodeResult} from the single, targeted {@link RedisClusterNode}. * @throws IllegalArgumentException in case no resource can be acquired for given node. */ - public NodeResult executeCommandOnSingleNode(ClusterCommandCallback clusterCommand, + public NodeResult executeCommandOnSingleNode(ClusterCommandCallback commandCallback, RedisClusterNode node) { - return executeCommandOnSingleNode(clusterCommand, node, 0); + return executeCommandOnSingleNode(commandCallback, node, 0); } - private NodeResult executeCommandOnSingleNode(ClusterCommandCallback clusterCommand, + private NodeResult executeCommandOnSingleNode(ClusterCommandCallback commandCallback, RedisClusterNode node, int redirectCount) { - Assert.notNull(clusterCommand, "ClusterCommandCallback must not be null"); + Assert.notNull(commandCallback, "ClusterCommandCallback must not be null"); Assert.notNull(node, "RedisClusterNode must not be null"); if (redirectCount > this.maxRedirects) { - String message = String.format("Cannot follow Cluster Redirects over more than %s legs;" - + " Please consider increasing the number of redirects to follow; Current value is: %s.", - redirectCount, this.maxRedirects); - - throw new TooManyClusterRedirectionsException(message); + throw new TooManyClusterRedirectionsException(String.format( + "Cannot follow Cluster Redirects over more than %s legs; " + + "Consider increasing the number of redirects to follow; Current value is: %s.", + redirectCount, this.maxRedirects)); } RedisClusterNode nodeToUse = lookupNode(node); @@ -161,15 +145,14 @@ private NodeResult executeCommandOnSingleNode(ClusterCommandCallback(node, clusterCommand.doInCluster(client)); + return new NodeResult<>(node, commandCallback.doInCluster(client)); } catch (RuntimeException cause) { RuntimeException translatedException = convertToDataAccessException(cause); if (translatedException instanceof ClusterRedirectException clusterRedirectException) { - return executeCommandOnSingleNode(clusterCommand, topologyProvider.getTopology() - .lookup(clusterRedirectException.getTargetHost(), clusterRedirectException.getTargetPort()), - redirectCount + 1); + return executeCommandOnSingleNode(commandCallback, topologyProvider.getTopology().lookup( + clusterRedirectException.getTargetHost(), clusterRedirectException.getTargetPort()), redirectCount + 1); } else { throw translatedException != null ? translatedException : cause; } @@ -182,7 +165,8 @@ private NodeResult executeCommandOnSingleNode(ClusterCommandCallback MultiNodeResult executeCommandOnAllNodes(ClusterCommandCallback clusterCommand) { - return executeCommandAsyncOnNodes(clusterCommand, getClusterTopology().getActiveMasterNodes()); + public MultiNodeResult executeCommandOnAllNodes(ClusterCommandCallback commandCallback) { + return executeCommandAsyncOnNodes(commandCallback, getClusterTopology().getActiveMasterNodes()); } /** - * @param clusterCommand must not be {@literal null}. + * @param commandCallback must not be {@literal null}. * @param nodes must not be {@literal null}. * @return never {@literal null}. * @throws ClusterCommandExecutionFailureException if a failure occurs while executing the given * {@link ClusterCommandCallback command} on any given {@link RedisClusterNode node}. * @throws IllegalArgumentException in case the node could not be resolved to a topology-known node */ - public MultiNodeResult executeCommandAsyncOnNodes(ClusterCommandCallback clusterCommand, + public MultiNodeResult executeCommandAsyncOnNodes(ClusterCommandCallback commandCallback, Iterable nodes) { - Assert.notNull(clusterCommand, "Callback must not be null"); + Assert.notNull(commandCallback, "Callback must not be null"); Assert.notNull(nodes, "Nodes must not be null"); ClusterTopology topology = this.topologyProvider.getTopology(); @@ -234,7 +218,7 @@ public MultiNodeResult executeCommandAsyncOnNodes(ClusterCommandCallba Map>> futures = new LinkedHashMap<>(); for (RedisClusterNode node : resolvedRedisClusterNodes) { - Callable> nodeCommandExecution = () -> executeCommandOnSingleNode(clusterCommand, node); + Callable> nodeCommandExecution = () -> executeCommandOnSingleNode(commandCallback, node); futures.put(new NodeExecution(node), executor.submit(nodeCommandExecution)); } @@ -243,26 +227,22 @@ public MultiNodeResult executeCommandAsyncOnNodes(ClusterCommandCallba MultiNodeResult collectResults(Map>> futures) { - Map exceptions = new HashMap<>(); + NodeExceptionCollector exceptionCollector = new NodeExceptionCollector(); MultiNodeResult result = new MultiNodeResult<>(); - Set safeguard = new HashSet<>(); + Object placeholder = new Object(); + Map>, Object> safeguard = new IdentityHashMap<>(); - BiConsumer exceptionHandler = getExceptionHandlerFunction(exceptions); - - boolean done = false; - - while (!done) { - - done = true; + for (;;) { + boolean timeout = false; for (Map.Entry>> entry : futures.entrySet()) { NodeExecution nodeExecution = entry.getKey(); Future> futureNodeResult = entry.getValue(); - String futureId = ObjectUtils.getIdentityHexString(futureNodeResult); try { - if (!safeguard.contains(futureId)) { + + if (!safeguard.containsKey(futureNodeResult)) { NodeResult nodeResult = futureNodeResult.get(10L, TimeUnit.MICROSECONDS); @@ -272,39 +252,32 @@ MultiNodeResult collectResults(Map>> result.add(nodeResult); } - safeguard.add(futureId); + safeguard.put(futureNodeResult, placeholder); } } catch (ExecutionException exception) { - safeguard.add(futureId); - exceptionHandler.accept(nodeExecution, exception.getCause()); + safeguard.put(futureNodeResult, placeholder); + exceptionCollector.addException(nodeExecution, exception.getCause()); } catch (TimeoutException ignore) { - done = false; - } catch (InterruptedException cause) { + timeout = true; + } catch (InterruptedException exception) { Thread.currentThread().interrupt(); - exceptionHandler.accept(nodeExecution, cause); + exceptionCollector.addException(nodeExecution, exception); break; } } + + if (!timeout) { + break; + } } - if (!exceptions.isEmpty()) { - throw new ClusterCommandExecutionFailureException(new ArrayList<>(exceptions.values())); + if (exceptionCollector.hasExceptions()) { + throw new ClusterCommandExecutionFailureException(exceptionCollector.getExceptions()); } return result; } - private BiConsumer getExceptionHandlerFunction(Map exceptions) { - - return (nodeExecution, throwable) -> { - - DataAccessException dataAccessException = convertToDataAccessException((Exception) throwable); - Throwable resolvedException = dataAccessException != null ? dataAccessException : throwable; - - exceptions.putIfAbsent(nodeExecution.getNode(), resolvedException); - }; - } - /** * Run {@link MultiKeyClusterCommandCallback} with on a curated set of nodes serving one or more keys. * @@ -331,8 +304,8 @@ public MultiNodeResult executeMultiKeyCommand(MultiKeyClusterCommandCa if (entry.getKey().isMaster()) { for (PositionalKey key : entry.getValue()) { - futures.put(new NodeExecution(entry.getKey(), key), this.executor.submit(() -> - executeMultiKeyCommandOnSingleNode(commandCallback, entry.getKey(), key.getBytes()))); + futures.put(new NodeExecution(entry.getKey(), key), this.executor + .submit(() -> executeMultiKeyCommandOnSingleNode(commandCallback, entry.getKey(), key.getBytes()))); } } } @@ -458,8 +431,8 @@ boolean isPositional() { } /** - * {@link NodeResult} encapsulates the actual {@link T value} returned by a {@link ClusterCommandCallback} - * on a given {@link RedisClusterNode}. + * {@link NodeResult} encapsulates the actual {@link T value} returned by a {@link ClusterCommandCallback} on a given + * {@link RedisClusterNode}. * * @param {@link Class Type} of the {@link Object value} returned in the result. * @author Christoph Strobl @@ -468,9 +441,9 @@ boolean isPositional() { */ public static class NodeResult { - private RedisClusterNode node; - private ByteArrayWrapper key; - private @Nullable T value; + private final RedisClusterNode node; + private final ByteArrayWrapper key; + private final @Nullable T value; /** * Create a new {@link NodeResult}. @@ -551,9 +524,8 @@ public boolean equals(@Nullable Object obj) { return false; } - return ObjectUtils.nullSafeEquals(this.getNode(), that.getNode()) - && Objects.equals(this.key, that.key) - && Objects.equals(this.getValue(), that.getValue()); + return ObjectUtils.nullSafeEquals(this.getNode(), that.getNode()) && Objects.equals(this.key, that.key) + && Objects.equals(this.getValue(), that.getValue()); } @Override @@ -757,8 +729,7 @@ public boolean equals(@Nullable Object obj) { if (!(obj instanceof PositionalKey that)) return false; - return this.getPosition() == that.getPosition() - && ObjectUtils.nullSafeEquals(this.getKey(), that.getKey()); + return this.getPosition() == that.getPosition() && ObjectUtils.nullSafeEquals(this.getKey(), that.getKey()); } @Override @@ -836,4 +807,34 @@ public Iterator iterator() { return this.keys.iterator(); } } + + /** + * Collector for exceptions. Applies translation of exceptions if possible. + */ + private class NodeExceptionCollector { + + private final Map exceptions = new HashMap<>(); + + /** + * @return {@code true} if the collector contains at least one exception. + */ + public boolean hasExceptions() { + return !exceptions.isEmpty(); + } + + public void addException(NodeExecution execution, Throwable throwable) { + + Throwable translated = throwable instanceof Exception e ? convertToDataAccessException(e) : throwable; + Throwable resolvedException = translated != null ? translated : throwable; + + exceptions.putIfAbsent(execution.getNode(), resolvedException); + } + + /** + * @return the collected exceptions. + */ + public List getExceptions() { + return new ArrayList<>(exceptions.values()); + } + } } diff --git a/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java b/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java index 25e3d47562..5d23c25fec 100644 --- a/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java +++ b/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java @@ -15,38 +15,30 @@ */ package org.springframework.data.redis.connection; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; -import static org.springframework.data.redis.test.util.MockitoUtils.verifyInvocationsAcross; - -import java.time.Instant; +import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.*; +import static org.springframework.data.redis.test.util.MockitoUtils.*; + +import edu.umd.cs.mtc.MultithreadedTestCase; +import edu.umd.cs.mtc.TestFramework; + import java.util.Arrays; import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; -import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentSkipListMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Predicate; -import java.util.function.Supplier; +import java.util.function.Consumer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -55,11 +47,9 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.mockito.stubbing.Answer; - import org.springframework.core.convert.converter.Converter; -import org.springframework.core.task.AsyncTaskExecutor; import org.springframework.core.task.SyncTaskExecutor; +import org.springframework.core.task.support.TaskExecutorAdapter; import org.springframework.dao.DataAccessException; import org.springframework.dao.InvalidDataAccessApiUsageException; import org.springframework.data.redis.ClusterRedirectException; @@ -73,11 +63,6 @@ import org.springframework.data.redis.connection.RedisClusterNode.LinkState; import org.springframework.data.redis.connection.RedisClusterNode.SlotRange; import org.springframework.data.redis.connection.RedisNode.NodeType; -import org.springframework.data.redis.test.util.MockitoUtils; -import org.springframework.scheduling.concurrent.ConcurrentTaskExecutor; - -import edu.umd.cs.mtc.MultithreadedTestCase; -import edu.umd.cs.mtc.TestFramework; /** * Unit Tests for {@link ClusterCommandExecutor}. @@ -85,7 +70,6 @@ * @author Christoph Strobl * @author Mark Paluch * @author John Blum - * @since 1.7 */ @ExtendWith(MockitoExtension.class) class ClusterCommandExecutorUnitTests { @@ -99,34 +83,35 @@ class ClusterCommandExecutorUnitTests { private static final int CLUSTER_NODE_3_PORT = 7381; private static final RedisClusterNode CLUSTER_NODE_1 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT) - .serving(new SlotRange(0, 5460)) - .withId("ef570f86c7b1a953846668debc177a3a16733420") - .promotedAs(NodeType.MASTER) - .linkState(LinkState.CONNECTED) - .withName("ClusterNodeX") + .listeningAt(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT) // + .serving(new SlotRange(0, 5460)) // + .withId("ef570f86c7b1a953846668debc177a3a16733420") // + .promotedAs(NodeType.MASTER) // + .linkState(LinkState.CONNECTED) // + .withName("ClusterNodeX") // .build(); private static final RedisClusterNode CLUSTER_NODE_2 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT) - .serving(new SlotRange(5461, 10922)) - .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84") - .promotedAs(NodeType.MASTER) - .linkState(LinkState.CONNECTED) - .withName("ClusterNodeY") + .listeningAt(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT) // + .serving(new SlotRange(5461, 10922)) // + .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84") // + .promotedAs(NodeType.MASTER) // + .linkState(LinkState.CONNECTED) // + .withName("ClusterNodeY") // .build(); private static final RedisClusterNode CLUSTER_NODE_3 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT) - .serving(new SlotRange(10923, 16383)) - .withId("3b9b8192a874fa8f1f09dbc0ee20afab5738eee7") - .promotedAs(NodeType.MASTER) - .linkState(LinkState.CONNECTED) - .withName("ClusterNodeZ") + .listeningAt(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT) // + .serving(new SlotRange(10923, 16383)) // + .withId("3b9b8192a874fa8f1f09dbc0ee20afab5738eee7") // + .promotedAs(NodeType.MASTER) // + .linkState(LinkState.CONNECTED) // + .withName("ClusterNodeZ") // .build(); private static final RedisClusterNode CLUSTER_NODE_2_LOOKUP = RedisClusterNode.newRedisClusterNode() - .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84").build(); + .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84") // + .build(); private static final RedisClusterNode UNKNOWN_CLUSTER_NODE = new RedisClusterNode("8.8.8.8", 7379, SlotRange.empty()); @@ -153,7 +138,8 @@ class ClusterCommandExecutorUnitTests { void setUp() { this.executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), new MockClusterNodeResourceProvider(), - new PassThroughExceptionTranslationStrategy(exceptionConverter), new ImmediateExecutor()); + new PassThroughExceptionTranslationStrategy(exceptionConverter), + new TaskExecutorAdapter(new SyncTaskExecutor())); } @AfterEach @@ -166,7 +152,7 @@ void executeCommandOnSingleNodeShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, CLUSTER_NODE_2); - verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -175,25 +161,26 @@ void executeCommandOnSingleNodeByHostAndPortShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, new RedisClusterNode(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); - verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 + @SuppressWarnings("ConstantConditions") void executeCommandOnSingleNodeByNodeIdShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, new RedisClusterNode(CLUSTER_NODE_2.id)); - verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 - @SuppressWarnings("all") + @SuppressWarnings("ConstantConditions") void executeCommandOnSingleNodeShouldThrowExceptionWhenNodeIsNull() { assertThatIllegalArgumentException().isThrownBy(() -> executor.executeCommandOnSingleNode(COMMAND_CALLBACK, null)); } @Test // DATAREDIS-315 - @SuppressWarnings("all") + @SuppressWarnings("ConstantConditions") void executeCommandOnSingleNodeShouldThrowExceptionWhenCommandCallbackIsNull() { assertThatIllegalArgumentException().isThrownBy(() -> executor.executeCommandOnSingleNode(null, CLUSTER_NODE_1)); } @@ -207,55 +194,40 @@ void executeCommandOnSingleNodeShouldThrowExceptionWhenNodeIsUnknown() { @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodes() { - ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), - new ConcurrentTaskExecutor(new SyncTaskExecutor())); - executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(CLUSTER_NODE_1, CLUSTER_NODE_2)); - verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); verify(connection3, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodesByHostAndPort() { - ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), - new ConcurrentTaskExecutor(new SyncTaskExecutor())); - executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(new RedisClusterNode(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT), new RedisClusterNode(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT))); - verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); verify(connection3, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 + @SuppressWarnings("ConstantConditions") void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodesByNodeId() { - ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), - new ConcurrentTaskExecutor(new SyncTaskExecutor())); - executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(new RedisClusterNode(CLUSTER_NODE_1.id), CLUSTER_NODE_2_LOOKUP)); - verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); verify(connection3, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldFailOnGivenUnknownNodes() { - ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), - new ConcurrentTaskExecutor(new SyncTaskExecutor())); - assertThatIllegalArgumentException().isThrownBy(() -> executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(new RedisClusterNode("unknown"), CLUSTER_NODE_2_LOOKUP))); } @@ -263,22 +235,19 @@ void executeCommandAsyncOnNodesShouldFailOnGivenUnknownNodes() { @Test // DATAREDIS-315 void executeCommandOnAllNodesShouldExecuteCommandOnEveryKnownClusterNode() { - ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), - new ConcurrentTaskExecutor(new SyncTaskExecutor())); - executor.executeCommandOnAllNodes(COMMAND_CALLBACK); - verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(connection3, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); + verify(connection3).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldCompleteAndCollectErrorsOfAllNodes() { when(connection1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); - when(connection2.theWheelWeavesAsTheWheelWills()).thenThrow(new IllegalStateException("(error) mat lost the dagger...")); + when(connection2.theWheelWeavesAsTheWheelWills()) + .thenThrow(new IllegalStateException("(error) mat lost the dagger...")); when(connection3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); try { @@ -289,9 +258,9 @@ void executeCommandAsyncOnNodesShouldCompleteAndCollectErrorsOfAllNodes() { assertThat(cause.getSuppressed()[0]).isInstanceOf(DataAccessException.class); } - verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(connection3, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); + verify(connection3).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -317,8 +286,7 @@ void executeMultikeyCommandShouldRunCommandAcrossCluster() { when(connection3.bloodAndAshes(any(byte[].class))).thenReturn("perrin"); MultiNodeResult result = executor.executeMultiKeyCommand(MULTIKEY_CALLBACK, - new HashSet<>( - Arrays.asList("key-1".getBytes(), "key-2".getBytes(), "key-3".getBytes(), "key-9".getBytes()))); + new HashSet<>(Arrays.asList("key-1".getBytes(), "key-2".getBytes(), "key-3".getBytes(), "key-9".getBytes()))); assertThat(result.resultsAsList()).contains("rand", "mat", "perrin", "egwene"); @@ -329,32 +297,35 @@ void executeMultikeyCommandShouldRunCommandAcrossCluster() { @Test // DATAREDIS-315 void executeCommandOnSingleNodeAndFollowRedirect() { - when(connection1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); + when(connection1.theWheelWeavesAsTheWheelWills()) + .thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); executor.executeCommandOnSingleNode(COMMAND_CALLBACK, CLUSTER_NODE_1); - verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(connection3, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1).theWheelWeavesAsTheWheelWills(); + verify(connection3).theWheelWeavesAsTheWheelWills(); verify(connection2, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandOnSingleNodeAndFollowRedirectButStopsAfterMaxRedirects() { - when(connection1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); - when(connection3.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); - when(connection2.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT)); + when(connection1.theWheelWeavesAsTheWheelWills()) + .thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); + when(connection3.theWheelWeavesAsTheWheelWills()) + .thenThrow(new MovedException(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); + when(connection2.theWheelWeavesAsTheWheelWills()) + .thenThrow(new MovedException(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT)); - try { - executor.setMaxRedirects(4); + executor.setMaxRedirects(4); + + assertThatExceptionOfType(TooManyClusterRedirectionsException.class).isThrownBy(() -> { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, CLUSTER_NODE_1); - } catch (Exception e) { - assertThat(e).isInstanceOf(TooManyClusterRedirectionsException.class); - } + }); verify(connection1, times(2)).theWheelWeavesAsTheWheelWills(); verify(connection3, times(2)).theWheelWeavesAsTheWheelWills(); - verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -362,83 +333,45 @@ void executeCommandOnArbitraryNodeShouldPickARandomNode() { executor.executeCommandOnArbitraryNode(COMMAND_CALLBACK); - verifyInvocationsAcross("theWheelWeavesAsTheWheelWills", times(1), connection1, connection2, connection3); + verifyInvocationsAcross("theWheelWeavesAsTheWheelWills", connection1, connection2, connection3); } @Test // GH-2518 - void collectResultsCompletesSuccessfully() { - - Instant done = Instant.now().plusMillis(5); - - Predicate>> isDone = future -> Instant.now().isAfter(done); + void collectResultsCompletesSuccessfullyAfterTimeouts() { Map>> futures = new HashMap<>(); - NodeResult nodeOneA = newNodeResult(CLUSTER_NODE_1, "A"); - NodeResult nodeTwoB = newNodeResult(CLUSTER_NODE_2, "B"); - NodeResult nodeThreeC = newNodeResult(CLUSTER_NODE_3, "C"); - - futures.put(newNodeExecution(CLUSTER_NODE_1), mockFutureAndIsDone(nodeOneA, isDone)); - futures.put(newNodeExecution(CLUSTER_NODE_2), mockFutureAndIsDone(nodeTwoB, isDone)); - futures.put(newNodeExecution(CLUSTER_NODE_3), mockFutureAndIsDone(nodeThreeC, isDone)); + NodeResult nodeOneA = new NodeResult<>(CLUSTER_NODE_1, "A"); + NodeResult nodeTwoB = new NodeResult<>(CLUSTER_NODE_2, "B"); + NodeResult nodeThreeC = new NodeResult<>(CLUSTER_NODE_3, "C"); - MultiNodeResult results = this.executor.collectResults(futures); + doWithScheduler(scheduler -> { - assertThat(results).isNotNull(); - assertThat(results.getResults()).containsExactlyInAnyOrder(nodeOneA, nodeTwoB, nodeThreeC); - - futures.values().forEach(future -> - runsSafely(() -> verify(future, times(1)).get(anyLong(), any(TimeUnit.class)))); - } + futures.put(new NodeExecution(CLUSTER_NODE_1), scheduler.schedule(() -> nodeOneA, 15, TimeUnit.MILLISECONDS)); + futures.put(new NodeExecution(CLUSTER_NODE_2), scheduler.schedule(() -> nodeTwoB, 15, TimeUnit.MILLISECONDS)); + futures.put(new NodeExecution(CLUSTER_NODE_3), scheduler.schedule(() -> nodeThreeC, 15, TimeUnit.MILLISECONDS)); - @Test // GH-2518 - void collectResultsCompletesSuccessfullyEvenWithTimeouts() throws Exception { - - Map>> futures = new HashMap<>(); + MultiNodeResult results = this.executor.collectResults(futures); - NodeResult nodeOneA = newNodeResult(CLUSTER_NODE_1, "A"); - NodeResult nodeTwoB = newNodeResult(CLUSTER_NODE_2, "B"); - NodeResult nodeThreeC = newNodeResult(CLUSTER_NODE_3, "C"); - - Future> nodeOneFutureResult = mockFutureThrowingTimeoutException(nodeOneA, 4); - Future> nodeTwoFutureResult = mockFutureThrowingTimeoutException(nodeTwoB, 1); - Future> nodeThreeFutureResult = mockFutureThrowingTimeoutException(nodeThreeC, 2); - - futures.put(newNodeExecution(CLUSTER_NODE_1), nodeOneFutureResult); - futures.put(newNodeExecution(CLUSTER_NODE_2), nodeTwoFutureResult); - futures.put(newNodeExecution(CLUSTER_NODE_3), nodeThreeFutureResult); - - MultiNodeResult results = this.executor.collectResults(futures); - - assertThat(results).isNotNull(); - assertThat(results.getResults()).containsExactlyInAnyOrder(nodeOneA, nodeTwoB, nodeThreeC); - - verify(nodeOneFutureResult, times(4)).get(anyLong(), any(TimeUnit.class)); - verify(nodeTwoFutureResult, times(1)).get(anyLong(), any(TimeUnit.class)); - verify(nodeThreeFutureResult, times(2)).get(anyLong(), any(TimeUnit.class)); - verifyNoMoreInteractions(nodeOneFutureResult, nodeTwoFutureResult, nodeThreeFutureResult); + assertThat(results).isNotNull(); + assertThat(results.getResults()).containsExactlyInAnyOrder(nodeOneA, nodeTwoB, nodeThreeC); + }); } @Test // GH-2518 void collectResultsFailsWithExecutionException() { Map>> futures = new HashMap<>(); + NodeResult nodeOneA = new NodeResult<>(CLUSTER_NODE_1, "A"); - NodeResult nodeOneA = newNodeResult(CLUSTER_NODE_1, "A"); - - futures.put(newNodeExecution(CLUSTER_NODE_1), mockFutureAndIsDone(nodeOneA, future -> true)); - futures.put(newNodeExecution(CLUSTER_NODE_2), mockFutureThrowingExecutionException( - new ExecutionException("TestError", new IllegalArgumentException("MockError")))); + futures.put(new NodeExecution(CLUSTER_NODE_1), CompletableFuture.completedFuture(nodeOneA)); + futures.put(new NodeExecution(CLUSTER_NODE_2), + CompletableFuture.failedFuture(new IllegalArgumentException("MockError"))); assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) - .isThrownBy(() -> this.executor.collectResults(futures)) - .withMessage("MockError") - .withCauseInstanceOf(InvalidDataAccessApiUsageException.class) - .extracting(Throwable::getCause) - .extracting(Throwable::getCause) - .isInstanceOf(IllegalArgumentException.class) - .extracting(Throwable::getMessage) - .isEqualTo("MockError"); + .isThrownBy(() -> this.executor.collectResults(futures)) // + .withMessage("MockError") // + .withRootCauseInstanceOf(IllegalArgumentException.class); } @Test // GH-2518 @@ -446,38 +379,8 @@ void collectResultsFailsWithInterruptedException() throws Throwable { TestFramework.runOnce(new CollectResultsInterruptedMultithreadedTestCase(this.executor)); } - // Future.get() for X will get called twice if at least one other Future is not done and Future.get() for X - // threw an ExecutionException in the previous iteration, thereby marking it as done! @Test // GH-2518 - @SuppressWarnings("all") - void collectResultsCallsFutureGetOnlyOnce() throws Exception { - - AtomicInteger count = new AtomicInteger(0); - Map>> futures = new HashMap<>(); - - Future> clusterNodeOneFutureResult = mockFutureAndIsDone(null, future -> - count.incrementAndGet() % 2 == 0); - - Future> clusterNodeTwoFutureResult = mockFutureThrowingExecutionException( - new ExecutionException("TestError", new IllegalArgumentException("MockError"))); - - futures.put(newNodeExecution(CLUSTER_NODE_1), clusterNodeOneFutureResult); - futures.put(newNodeExecution(CLUSTER_NODE_2), clusterNodeTwoFutureResult); - - assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) - .isThrownBy(() -> this.executor.collectResults(futures)); - - verify(clusterNodeOneFutureResult, times(1)).get(anyLong(), any()); - verify(clusterNodeTwoFutureResult, times(1)).get(anyLong(), any()); - } - - // Covers the case where Future.get() is mistakenly called multiple times, or if the Future.isDone() implementation - // does not properly take into account Future.get() throwing an ExecutionException during computation subsequently - // returning false instead of true. - // This should be properly handled by the "safeguard" (see collectResultsCallsFutureGetOnlyOnce()), but... - // just in case! The ExecutionException handler now stores the [DataAccess]Exception with Map.putIfAbsent(..). - @Test // GH-2518 - @SuppressWarnings("all") + @SuppressWarnings("ConstantConditions") void collectResultsCapturesFirstExecutionExceptionOnly() { AtomicInteger count = new AtomicInteger(0); @@ -485,85 +388,54 @@ void collectResultsCapturesFirstExecutionExceptionOnly() { Map>> futures = new HashMap<>(); - futures.put(newNodeExecution(CLUSTER_NODE_1), - mockFutureAndIsDone(null, future -> count.incrementAndGet() % 2 == 0)); + CompletableFuture> doneLater = new CompletableFuture<>() { - futures.put(newNodeExecution(CLUSTER_NODE_2), mockFutureThrowingExecutionException(() -> - new ExecutionException("TestError", new IllegalStateException("MockError" + exceptionCount.getAndIncrement())))); + @Override + public NodeResult get(long timeout, TimeUnit unit) throws TimeoutException { - assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) - .isThrownBy(() -> this.executor.collectResults(futures)) - .withMessage("MockError0") - .withCauseInstanceOf(InvalidDataAccessApiUsageException.class) - .extracting(Throwable::getCause) - .extracting(Throwable::getCause) - .isInstanceOf(IllegalStateException.class) - .extracting(Throwable::getMessage) - .isEqualTo("MockError0"); - } - - private Future mockFutureAndIsDone(T result, Predicate> isDone) { - - return MockitoUtils.mockFuture(result, future -> { - doAnswer(invocation -> isDone.test(future)).when(future).isDone(); - return future; - }); - } - - private Future mockFutureThrowingExecutionException(ExecutionException exception) { - return mockFutureThrowingExecutionException(() -> exception); - } - - private Future mockFutureThrowingExecutionException(Supplier exceptionSupplier) { - - Answer getAnswer = invocationOnMock -> { throw exceptionSupplier.get(); }; - - return MockitoUtils.mockFuture(null, future -> { - doReturn(true).when(future).isDone(); - doAnswer(getAnswer).when(future).get(); - doAnswer(getAnswer).when(future).get(anyLong(), any()); - return future; - }); - } - - @SuppressWarnings("unchecked") - private Future mockFutureThrowingTimeoutException(T result, int timeoutCount) { + if (count.incrementAndGet() % 2 == 0) { + return null; + } + throw new TimeoutException(); + } + }; - AtomicInteger counter = new AtomicInteger(timeoutCount); + CompletableFuture> alwaysFail = new CompletableFuture<>() { - Answer getAnswer = invocationOnMock -> { + @Override + public NodeResult get(long timeout, TimeUnit unit) throws ExecutionException { - if (counter.decrementAndGet() > 0) { - throw new TimeoutException("TIMES UP"); + throw new ExecutionException("TestError", + new IllegalStateException("MockError" + exceptionCount.getAndIncrement())); } - - doReturn(true).when((Future>) invocationOnMock.getMock()).isDone(); - - return result; }; - return MockitoUtils.mockFuture(result, future -> { - - doAnswer(getAnswer).when(future).get(); - doAnswer(getAnswer).when(future).get(anyLong(), any()); + futures.put(new NodeExecution(CLUSTER_NODE_1), doneLater); + futures.put(new NodeExecution(CLUSTER_NODE_2), alwaysFail); - return future; - }); - } - - private NodeExecution newNodeExecution(RedisClusterNode clusterNode) { - return new NodeExecution(clusterNode); - } + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> this.executor.collectResults(futures)) // + .withMessage("MockError0") // + .withCauseInstanceOf(InvalidDataAccessApiUsageException.class) + .withRootCauseInstanceOf(IllegalStateException.class); - private NodeResult newNodeResult(RedisClusterNode clusterNode, T value) { - return new NodeResult<>(clusterNode, value); + assertThat(exceptionCount).hasValue(1); } - private void runsSafely(ThrowableOperation operation) { + /** + * Performs the given action within the scope of a running {@link ScheduledExecutorService}. The scheduler is only + * valid during the callback and shut down after this method returns. + * + * @param callback the action to invoke. + */ + private void doWithScheduler(Consumer callback) { + ScheduledExecutorService scheduler = new ScheduledThreadPoolExecutor(3); try { - operation.run(); - } catch (Throwable ignore) { } + callback.accept(scheduler); + } finally { + scheduler.shutdown(); + } } static class MockClusterNodeProvider implements ClusterTopologyProvider { @@ -581,28 +453,16 @@ class MockClusterNodeResourceProvider implements ClusterNodeResourceProvider { public Connection getResourceForSpecificNode(RedisClusterNode clusterNode) { return CLUSTER_NODE_1.equals(clusterNode) ? connection1 - : CLUSTER_NODE_2.equals(clusterNode) ? connection2 - : CLUSTER_NODE_3.equals(clusterNode) ? connection3 - : null; + : CLUSTER_NODE_2.equals(clusterNode) ? connection2 : CLUSTER_NODE_3.equals(clusterNode) ? connection3 : null; } @Override - public void returnResourceForSpecificNode(RedisClusterNode node, Object resource) { - } + public void returnResourceForSpecificNode(RedisClusterNode node, Object resource) {} } - interface ConnectionCommandCallback extends ClusterCommandCallback { + interface ConnectionCommandCallback extends ClusterCommandCallback {} - } - - interface MultiKeyConnectionCommandCallback extends MultiKeyClusterCommandCallback { - - } - - @FunctionalInterface - interface ThrowableOperation { - void run() throws Throwable; - } + interface MultiKeyConnectionCommandCallback extends MultiKeyClusterCommandCallback {} interface Connection { @@ -622,47 +482,13 @@ static class MovedException extends RuntimeException { } } - static class ImmediateExecutor implements AsyncTaskExecutor { - - @Override - public void execute(Runnable runnable) { - runnable.run(); - } - - @Override - public Future submit(Runnable runnable) { - - return submit(() -> { - - runnable.run(); - - return null; - }); - } - - @Override - public Future submit(Callable callable) { - - try { - return CompletableFuture.completedFuture(callable.call()); - } catch (Exception cause) { - - CompletableFuture future = new CompletableFuture<>(); - - future.completeExceptionally(cause); - - return future; - } - } - } - @SuppressWarnings("all") static class CollectResultsInterruptedMultithreadedTestCase extends MultithreadedTestCase { private static final CountDownLatch latch = new CountDownLatch(1); - private static final Comparator NODE_COMPARATOR = - Comparator.comparing(nodeExecution -> nodeExecution.getNode().getName()); + private static final Comparator NODE_COMPARATOR = Comparator + .comparing(nodeExecution -> nodeExecution.getNode().getName()); private final ClusterCommandExecutor clusterCommandExecutor; @@ -681,26 +507,18 @@ private CollectResultsInterruptedMultithreadedTestCase(ClusterCommandExecutor cl @Override public void initialize() { - super.initialize(); - this.mockNodeOneFutureResult = this.futureNodeResults.computeIfAbsent(new NodeExecution(CLUSTER_NODE_1), - nodeExecution -> MockitoUtils.mockFuture(null, mockFuture -> { - doReturn(false).when(mockFuture).isDone(); - return mockFuture; - })); + nodeExecution -> new CompletableFuture<>()); this.mockNodeTwoFutureResult = this.futureNodeResults.computeIfAbsent(new NodeExecution(CLUSTER_NODE_2), - nodeExecution -> MockitoUtils.mockFuture(null, mockFuture -> { - - doReturn(true).when(mockFuture).isDone(); - - doAnswer(invocation -> { + nodeExecution -> new CompletableFuture<>() { + @Override + public NodeResult get(long timeout, TimeUnit unit) + throws InterruptedException, TimeoutException, ExecutionException { latch.await(); - return null; - }).when(mockFuture).get(anyLong(), any()); - - return mockFuture; - })); + return super.get(timeout, unit); + } + }); } public void thread1() { @@ -711,7 +529,7 @@ public void thread1() { this.collectResultsThread.setName("CollectResults Thread"); assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) - .isThrownBy(() -> this.clusterCommandExecutor.collectResults(this.futureNodeResults)); + .isThrownBy(() -> this.clusterCommandExecutor.collectResults(this.futureNodeResults)); assertThat(this.collectResultsThread.isInterrupted()).isTrue(); } @@ -729,16 +547,5 @@ public void thread2() { this.collectResultsThread.interrupt(); } - - @Override - public void finish() { - - try { - verify(this.mockNodeOneFutureResult, never()).get(); - verify(this.mockNodeTwoFutureResult, times(1)).get(anyLong(), any()); - } catch (ExecutionException | InterruptedException | TimeoutException cause) { - throw new RuntimeException(cause); - } - } } } diff --git a/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java b/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java index 318d47ec85..c70d6e8ad5 100644 --- a/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java +++ b/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java @@ -15,33 +15,17 @@ */ package org.springframework.data.redis.test.util; -import static org.mockito.Mockito.anyBoolean; -import static org.mockito.Mockito.anyLong; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.isA; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockingDetails; -import static org.mockito.Mockito.withSettings; +import static org.mockito.Mockito.*; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.concurrent.CancellationException; -import java.util.concurrent.Future; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Function; import org.mockito.internal.invocation.InvocationMatcher; import org.mockito.internal.verification.api.VerificationData; import org.mockito.invocation.Invocation; import org.mockito.invocation.MatchableInvocation; -import org.mockito.quality.Strictness; -import org.mockito.stubbing.Answer; import org.mockito.verification.VerificationMode; - -import org.springframework.lang.NonNull; -import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; /** @@ -56,80 +40,13 @@ public abstract class MockitoUtils { /** - * Creates a mock {@link Future} returning the given {@link Object result}. - * - * @param {@link Class type} of {@link Object result} returned by the mock {@link Future}. - * @param result {@link Object value} returned as the {@literal result} of the mock {@link Future}. - * @return a new mock {@link Future}. - * @see java.util.concurrent.Future - */ - @SuppressWarnings("unchecked") - public static @NonNull Future mockFuture(@Nullable T result) { - - try { - - AtomicBoolean cancelled = new AtomicBoolean(false); - AtomicBoolean done = new AtomicBoolean(false); - - Future mockFuture = mock(Future.class, withSettings().strictness(Strictness.LENIENT)); - - // A Future can only be cancelled if not done, it was not already cancelled, and no error occurred. - // The cancel(..) logic is not Thread-safe due to compound actions involving multiple variables. - // However, the cancel(..) logic does not necessarily need to be Thread-safe given the task execution - // of a Future is asynchronous and cancellation is driven by Thread interrupt from another Thread. - Answer cancelAnswer = invocation -> !done.get() - && cancelled.compareAndSet(done.get(), true) - && done.compareAndSet(done.get(), true); - - Answer getAnswer = invocation -> { - - // The Future is done no matter if it returns the result or was cancelled/interrupted. - done.set(true); - - if (Thread.currentThread().isInterrupted()) { - throw new InterruptedException("Thread was interrupted"); - } - - if (cancelled.get()) { - throw new CancellationException("Task was cancelled"); - } - - return result; - }; - - doAnswer(invocation -> cancelled.get()).when(mockFuture).isCancelled(); - doAnswer(invocation -> done.get()).when(mockFuture).isDone(); - doAnswer(cancelAnswer).when(mockFuture).cancel(anyBoolean()); - doAnswer(getAnswer).when(mockFuture).get(); - doAnswer(getAnswer).when(mockFuture).get(anyLong(), isA(TimeUnit.class)); - - return mockFuture; - } - catch (Exception cause) { - String message = String.format("Failed to create a mock of Future having result [%s]", result); - throw new IllegalStateException(message, cause); - } - } - - /** - * Creates a mock {@link Future} returning the given {@link Object result}, customized with the given, - * required {@link Function}. + * Verifies a given method is called once across all given mocks. * - * @param {@link Class type} of {@link Object result} returned by the mock {@link Future}. - * @param result {@link Object value} returned as the {@literal result} of the mock {@link Future}. - * @param futureFunction {@link Function} used to customize the mock {@link Future} on creation; - * must not be {@literal null}. - * @return a new mock {@link Future}. - * @see java.util.concurrent.Future - * @see java.util.function.Function - * @see #mockFuture(Object) + * @param method {@link String name} of a {@link java.lang.reflect.Method} on the {@link Object mock object}. + * @param mocks array of {@link Object mock objects} to verify. */ - public static @NonNull Future mockFuture(@Nullable T result, - @NonNull ThrowableFunction, Future> futureFunction) { - - Future mockFuture = mockFuture(result); - - return futureFunction.apply(mockFuture); + public static void verifyInvocationsAcross(String method, Object... mocks) { + verifyInvocationsAcross(method, times(1), mocks); } /** @@ -141,19 +58,19 @@ public abstract class MockitoUtils { */ public static void verifyInvocationsAcross(String method, VerificationMode mode, Object... mocks) { - mode.verify(new VerificationDataImpl(getInvocations(method, mocks), new InvocationMatcher(null, Collections - .singletonList(org.mockito.internal.matchers.Any.ANY)) { + mode.verify(new VerificationDataImpl(getInvocations(method, mocks), + new InvocationMatcher(null, Collections.singletonList(org.mockito.internal.matchers.Any.ANY)) { - @Override - public boolean matches(Invocation actual) { - return true; - } + @Override + public boolean matches(Invocation actual) { + return true; + } - @Override - public String toString() { - return String.format("%s for method: %s", mode, method); - } - })); + @Override + public String toString() { + return String.format("%s for method: %s", mode, method); + } + })); } private static List getInvocations(String method, Object... mocks) { @@ -196,29 +113,4 @@ public MatchableInvocation getTarget() { } } - @FunctionalInterface - public interface ThrowableFunction extends Function { - - @Override - default R apply(T target) { - - try { - return applyThrowingException(target); - } - catch (Throwable cause) { - String message = String.format("Failed to apply Function [%s] to target [%s]", this, target); - throw new IllegalStateException(message, cause); - } - } - - R applyThrowingException(T target) throws Throwable; - - @SuppressWarnings("unchecked") - default @NonNull ThrowableFunction andThen( - @Nullable ThrowableFunction after) { - - return after == null ? (ThrowableFunction) this - : target -> after.apply(this.apply(target)); - } - } }