From 06c0d61a164b8b2e3d0e24121075aaf53ddad447 Mon Sep 17 00:00:00 2001 From: John Blum Date: Wed, 8 Mar 2023 15:34:31 -0800 Subject: [PATCH] Conditionally wrap Thread.sleep(..) call in collectResults(..). Keep Thread.sleep(..) to avoid the busy (hot) loop when the Future executions are not yet complete. Add test coverage for ClusterCommandExecutor collectResults(..) method. Cleanup compiler warnings in ClusterCommandExecutorUnitTests. Closes #2518 --- .../connection/ClusterCommandExecutor.java | 243 ++++++---- .../redis/connection/ClusterTopology.java | 104 ++--- .../ClusterCommandExecutorUnitTests.java | 438 ++++++++++++++---- .../data/redis/test/util/MockitoUtils.java | 142 +++++- 4 files changed, 679 insertions(+), 248 deletions(-) diff --git a/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java b/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java index f016dfc33f..f9ef2405f4 100644 --- a/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java +++ b/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java @@ -15,10 +15,25 @@ */ package org.springframework.data.redis.connection; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.Map.Entry; +import java.util.Objects; +import java.util.Random; +import java.util.Set; +import java.util.TreeMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; +import java.util.function.BiConsumer; import java.util.function.Function; import java.util.stream.Collectors; @@ -43,6 +58,7 @@ * * @author Christoph Strobl * @author Mark Paluch + * @author John Blum * @since 1.7 */ public class ClusterCommandExecutor implements DisposableBean { @@ -58,7 +74,7 @@ public class ClusterCommandExecutor implements DisposableBean { private final ExceptionTranslationStrategy exceptionTranslationStrategy; /** - * Create a new instance of {@link ClusterCommandExecutor}. + * Create a new {@link ClusterCommandExecutor}. * * @param topologyProvider must not be {@literal null}. * @param resourceProvider must not be {@literal null}. @@ -92,40 +108,47 @@ public ClusterCommandExecutor(ClusterTopologyProvider topologyProvider, ClusterN /** * Run {@link ClusterCommandCallback} on a random node. * - * @param commandCallback must not be {@literal null}. + * @param clusterCommand must not be {@literal null}. * @return never {@literal null}. */ - public NodeResult executeCommandOnArbitraryNode(ClusterCommandCallback commandCallback) { + public NodeResult executeCommandOnArbitraryNode(ClusterCommandCallback clusterCommand) { - Assert.notNull(commandCallback, "ClusterCommandCallback must not be null"); + Assert.notNull(clusterCommand, "ClusterCommandCallback must not be null"); List nodes = new ArrayList<>(getClusterTopology().getActiveNodes()); - return executeCommandOnSingleNode(commandCallback, nodes.get(new Random().nextInt(nodes.size()))); + RedisClusterNode arbitraryNode = nodes.get(new Random().nextInt(nodes.size())); + + return executeCommandOnSingleNode(clusterCommand, arbitraryNode); } /** * Run {@link ClusterCommandCallback} on given {@link RedisClusterNode}. * - * @param cmd must not be {@literal null}. + * @param clusterCommand must not be {@literal null}. * @param node must not be {@literal null}. * @return the {@link NodeResult} from the single, targeted {@link RedisClusterNode}. * @throws IllegalArgumentException in case no resource can be acquired for given node. */ - public NodeResult executeCommandOnSingleNode(ClusterCommandCallback cmd, RedisClusterNode node) { - return executeCommandOnSingleNode(cmd, node, 0); + public NodeResult executeCommandOnSingleNode(ClusterCommandCallback clusterCommand, + RedisClusterNode node) { + + return executeCommandOnSingleNode(clusterCommand, node, 0); } - private NodeResult executeCommandOnSingleNode(ClusterCommandCallback cmd, RedisClusterNode node, - int redirectCount) { + private NodeResult executeCommandOnSingleNode(ClusterCommandCallback clusterCommand, + RedisClusterNode node, int redirectCount) { - Assert.notNull(cmd, "ClusterCommandCallback must not be null"); + Assert.notNull(clusterCommand, "ClusterCommandCallback must not be null"); Assert.notNull(node, "RedisClusterNode must not be null"); - if (redirectCount > maxRedirects) { - throw new TooManyClusterRedirectionsException(String.format( - "Cannot follow Cluster Redirects over more than %s legs; Please consider increasing the number of redirects to follow; Current value is: %s.", - redirectCount, maxRedirects)); + if (redirectCount > this.maxRedirects) { + + String message = String.format("Cannot follow Cluster Redirects over more than %s legs;" + + " Please consider increasing the number of redirects to follow; Current value is: %s.", + redirectCount, this.maxRedirects); + + throw new TooManyClusterRedirectionsException(message); } RedisClusterNode nodeToUse = lookupNode(node); @@ -135,14 +158,15 @@ private NodeResult executeCommandOnSingleNode(ClusterCommandCallback(node, cmd.doInCluster(client)); + return new NodeResult<>(node, clusterCommand.doInCluster(client)); } catch (RuntimeException cause) { RuntimeException translatedException = convertToDataAccessException(cause); if (translatedException instanceof ClusterRedirectException clusterRedirectException) { - return executeCommandOnSingleNode(cmd, topologyProvider.getTopology().lookup( - clusterRedirectException.getTargetHost(), clusterRedirectException.getTargetPort()), redirectCount + 1); + return executeCommandOnSingleNode(clusterCommand, topologyProvider.getTopology() + .lookup(clusterRedirectException.getTargetHost(), clusterRedirectException.getTargetPort()), + redirectCount + 1); } else { throw translatedException != null ? translatedException : cause; } @@ -152,10 +176,10 @@ private NodeResult executeCommandOnSingleNode(ClusterCommandCallback MultiNodeResult executeCommandOnAllNodes(final ClusterCommandCallback cmd) { - return executeCommandAsyncOnNodes(cmd, getClusterTopology().getActiveMasterNodes()); + public MultiNodeResult executeCommandOnAllNodes(ClusterCommandCallback clusterCommand) { + return executeCommandAsyncOnNodes(clusterCommand, getClusterTopology().getActiveMasterNodes()); } /** - * @param callback must not be {@literal null}. + * @param clusterCommand must not be {@literal null}. * @param nodes must not be {@literal null}. * @return never {@literal null}. * @throws ClusterCommandExecutionFailureException if a failure occurs while executing the given * {@link ClusterCommandCallback command} on any given {@link RedisClusterNode node}. * @throws IllegalArgumentException in case the node could not be resolved to a topology-known node */ - public MultiNodeResult executeCommandAsyncOnNodes(ClusterCommandCallback callback, + public MultiNodeResult executeCommandAsyncOnNodes(ClusterCommandCallback clusterCommand, Iterable nodes) { - Assert.notNull(callback, "Callback must not be null"); + Assert.notNull(clusterCommand, "Callback must not be null"); Assert.notNull(nodes, "Nodes must not be null"); + ClusterTopology topology = this.topologyProvider.getTopology(); List resolvedRedisClusterNodes = new ArrayList<>(); - ClusterTopology topology = topologyProvider.getTopology(); for (RedisClusterNode node : nodes) { try { resolvedRedisClusterNodes.add(topology.lookup(node)); - } catch (ClusterStateFailureException e) { - throw new IllegalArgumentException(String.format("Node %s is unknown to cluster", node), e); + } catch (ClusterStateFailureException cause) { + throw new IllegalArgumentException(String.format("Node %s is unknown to cluster", node), cause); } } Map>> futures = new LinkedHashMap<>(); for (RedisClusterNode node : resolvedRedisClusterNodes) { - futures.put(new NodeExecution(node), executor.submit(() -> executeCommandOnSingleNode(callback, node))); + futures.put(new NodeExecution(node), executor.submit(() -> executeCommandOnSingleNode(clusterCommand, node))); } return collectResults(futures); } - private MultiNodeResult collectResults(Map>> futures) { + MultiNodeResult collectResults(Map>> futures) { boolean done = false; Map exceptions = new HashMap<>(); MultiNodeResult result = new MultiNodeResult<>(); - Set saveGuard = new HashSet<>(); + Set safeguard = new HashSet<>(); + + BiConsumer exceptionHandler = (execution, throwable) -> { + + RuntimeException dataAccessException = convertToDataAccessException((Exception) throwable); + + exceptions.putIfAbsent(execution.getNode(), dataAccessException != null ? dataAccessException : throwable); + }; while (!done) { @@ -227,50 +258,40 @@ private MultiNodeResult collectResults(Map>> entry : futures.entrySet()) { - if (!entry.getValue().isDone() && !entry.getValue().isCancelled()) { + Future> futureNodeResult = entry.getValue(); + + if (isNotFinished(futureNodeResult)) { done = false; } else { - NodeExecution execution = entry.getKey(); + NodeExecution nodeExecution = entry.getKey(); + String futureId = ObjectUtils.getIdentityHexString(futureNodeResult); try { + if (!safeguard.contains(futureId)) { - String futureId = ObjectUtils.getIdentityHexString(entry.getValue()); - - if (!saveGuard.contains(futureId)) { + NodeResult nodeResult = futureNodeResult.get(); - if (execution.isPositional()) { - result.add(execution.getPositionalKey(), entry.getValue().get()); + if (nodeExecution.isPositional()) { + result.add(nodeExecution.getPositionalKey(), nodeResult); } else { - result.add(entry.getValue().get()); + result.add(nodeResult); } - saveGuard.add(futureId); + safeguard.add(futureId); } - } catch (ExecutionException cause) { - - RuntimeException exception = convertToDataAccessException((Exception) cause.getCause()); - - exceptions.put(execution.getNode(), exception != null ? exception : cause.getCause()); + } catch (ExecutionException exception) { + exceptionHandler.accept(nodeExecution, exception.getCause()); + safeguard.add(futureId); } catch (InterruptedException cause) { - Thread.currentThread().interrupt(); - - RuntimeException exception = convertToDataAccessException((Exception) cause.getCause()); - - exceptions.put(execution.getNode(), exception != null ? exception : cause.getCause()); - + exceptionHandler.accept(nodeExecution, cause); break; } } } - try { - Thread.sleep(10); - } catch (InterruptedException e) { - done = true; - Thread.currentThread().interrupt(); - } + done = pause(done); } if (!exceptions.isEmpty()) { @@ -280,6 +301,25 @@ private MultiNodeResult collectResults(Map future) { + return !(future.isDone() || future.isCancelled()); + } + + private boolean pause(boolean done) { + + if (!done) { + try { + Thread.sleep(1); + return false; + } catch (InterruptedException ignore) { + Thread.currentThread().interrupt(); + } + } + + return true; + } + /** * Run {@link MultiKeyClusterCommandCallback} with on a curated set of nodes serving one or more keys. * @@ -306,8 +346,8 @@ public MultiNodeResult executeMultiKeyCommand(MultiKeyClusterCommandCa if (entry.getKey().isMaster()) { for (PositionalKey key : entry.getValue()) { - futures.put(new NodeExecution(entry.getKey(), key), this.executor - .submit(() -> executeMultiKeyCommandOnSingleNode(commandCallback, entry.getKey(), key.getBytes()))); + futures.put(new NodeExecution(entry.getKey(), key), this.executor.submit(() -> + executeMultiKeyCommandOnSingleNode(commandCallback, entry.getKey(), key.getBytes()))); } } } @@ -328,10 +368,11 @@ private NodeResult executeMultiKeyCommandOnSingleNode(MultiKeyClusterC try { return new NodeResult<>(node, commandCallback.doInCluster(client, key), key); - } catch (RuntimeException ex) { + } catch (RuntimeException cause) { - RuntimeException translatedException = convertToDataAccessException(ex); - throw translatedException != null ? translatedException : ex; + RuntimeException translatedException = convertToDataAccessException(cause); + + throw translatedException != null ? translatedException : cause; } finally { this.resourceProvider.returnResourceForSpecificNode(node, client); } @@ -395,7 +436,7 @@ public interface MultiKeyClusterCommandCallback { * @author Mark Paluch * @since 1.7 */ - private static class NodeExecution { + static class NodeExecution { private final RedisClusterNode node; private final @Nullable PositionalKey positionalKey; @@ -432,21 +473,22 @@ boolean isPositional() { } /** - * {@link NodeResult} encapsulates the actual value returned by a {@link ClusterCommandCallback} on a given - * {@link RedisClusterNode}. + * {@link NodeResult} encapsulates the actual {@link T value} returned by a {@link ClusterCommandCallback} + * on a given {@link RedisClusterNode}. * + * @param {@link Class Type} of the {@link Object value} returned in the result. * @author Christoph Strobl - * @param + * @author John Blum * @since 1.7 */ public static class NodeResult { private RedisClusterNode node; - private @Nullable T value; private ByteArrayWrapper key; + private @Nullable T value; /** - * Create new {@link NodeResult}. + * Create a new {@link NodeResult}. * * @param node must not be {@literal null}. * @param value can be {@literal null}. @@ -456,7 +498,7 @@ public NodeResult(RedisClusterNode node, @Nullable T value) { } /** - * Create new {@link NodeResult}. + * Create a new {@link NodeResult}. * * @param node must not be {@literal null}. * @param value can be {@literal null}. @@ -465,37 +507,36 @@ public NodeResult(RedisClusterNode node, @Nullable T value) { public NodeResult(RedisClusterNode node, @Nullable T value, byte[] key) { this.node = node; - this.value = value; - this.key = new ByteArrayWrapper(key); + this.value = value; } /** - * Get the actual value of the command execution. + * Get the {@link RedisClusterNode} the command was executed on. * - * @return can be {@literal null}. + * @return never {@literal null}. */ - @Nullable - public T getValue() { - return value; + public RedisClusterNode getNode() { + return this.node; } /** - * Get the {@link RedisClusterNode} the command was executed on. + * Return the {@link byte[] key} mapped to the value stored in Redis. * - * @return never {@literal null}. + * @return a {@link byte[] byte array} of the key mapped to the value stored in Redis. */ - public RedisClusterNode getNode() { - return node; + public byte[] getKey() { + return this.key.getArray(); } /** - * Returns the key as an array of bytes. + * Get the actual value of the command execution. * - * @return the key as an array of bytes. + * @return can be {@literal null}. */ - public byte[] getKey() { - return key.getArray(); + @Nullable + public T getValue() { + return this.value; } /** @@ -513,6 +554,34 @@ public U mapValue(Function mapper) { return mapper.apply(getValue()); } + + @Override + public boolean equals(@Nullable Object obj) { + + if (obj == this) { + return true; + } + + if (!(obj instanceof NodeResult that)) { + return false; + } + + return ObjectUtils.nullSafeEquals(this.getNode(), that.getNode()) + && Objects.equals(this.key, that.key) + && Objects.equals(this.getValue(), that.getValue()); + } + + @Override + public int hashCode() { + + int hashValue = 17; + + hashValue = 37 * hashValue + ObjectUtils.nullSafeHashCode(getNode()); + hashValue = 37 * hashValue + ObjectUtils.nullSafeHashCode(this.key); + hashValue = 37 * hashValue + ObjectUtils.nullSafeHashCode(getValue()); + + return hashValue; + } } /** diff --git a/src/main/java/org/springframework/data/redis/connection/ClusterTopology.java b/src/main/java/org/springframework/data/redis/connection/ClusterTopology.java index f49236b188..8c4f6b1427 100644 --- a/src/main/java/org/springframework/data/redis/connection/ClusterTopology.java +++ b/src/main/java/org/springframework/data/redis/connection/ClusterTopology.java @@ -17,8 +17,10 @@ import java.util.Arrays; import java.util.Collections; -import java.util.LinkedHashSet; +import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; +import java.util.stream.Collectors; import org.springframework.data.redis.ClusterStateFailureException; import org.springframework.lang.Nullable; @@ -26,20 +28,27 @@ import org.springframework.util.StringUtils; /** - * {@link ClusterTopology} holds snapshot like information about {@link RedisClusterNode}s. + * Holder of snapshot-like information about {@link RedisClusterNode nodes} in the Redis cluster. * * @author Christoph Strobl * @author Mark Paluch + * @author John Blum + * @see org.springframework.data.redis.connection.RedisClusterNode * @since 1.7 */ public class ClusterTopology { + protected static final Predicate ACTIVE_NODE_PREDICATE = + node -> node.isConnected() && !node.isMarkedAsFail(); + + protected static final Predicate MASTER_NODE_PREDICATE = RedisNode::isMaster; + private final Set nodes; /** - * Creates new instance of {@link ClusterTopology}. + * Creates a new {@link ClusterTopology} from the given {@link Set} of {@link RedisClusterNode}s. * - * @param nodes can be {@literal null}. + * @param nodes {@link RedisClusterNode nodes} forming the topology of the Redis cluster; can be {@literal null}. */ public ClusterTopology(@Nullable Set nodes) { this.nodes = nodes != null ? nodes : Collections.emptySet(); @@ -61,14 +70,7 @@ public Set getNodes() { * @return never {@literal null}. */ public Set getActiveNodes() { - - Set activeNodes = new LinkedHashSet<>(nodes.size()); - for (RedisClusterNode node : nodes) { - if (node.isConnected() && !node.isMarkedAsFail()) { - activeNodes.add(node); - } - } - return activeNodes; + return getNodes().stream().filter(ACTIVE_NODE_PREDICATE).collect(Collectors.toSet()); } /** @@ -79,13 +81,9 @@ public Set getActiveNodes() { */ public Set getActiveMasterNodes() { - Set activeMasterNodes = new LinkedHashSet<>(nodes.size()); - for (RedisClusterNode node : nodes) { - if (node.isMaster() && node.isConnected() && !node.isMarkedAsFail()) { - activeMasterNodes.add(node); - } - } - return activeMasterNodes; + return getNodes().stream() + .filter(MASTER_NODE_PREDICATE.and(ACTIVE_NODE_PREDICATE)) + .collect(Collectors.toSet()); } /** @@ -94,14 +92,7 @@ public Set getActiveMasterNodes() { * @return never {@literal null}. */ public Set getMasterNodes() { - - Set masterNodes = new LinkedHashSet<>(nodes.size()); - for (RedisClusterNode node : nodes) { - if (node.isMaster()) { - masterNodes.add(node); - } - } - return masterNodes; + return getNodes().stream().filter(MASTER_NODE_PREDICATE).collect(Collectors.toSet()); } /** @@ -111,14 +102,7 @@ public Set getMasterNodes() { * @return never {@literal null}. */ public Set getSlotServingNodes(int slot) { - - Set slotServingNodes = new LinkedHashSet<>(nodes.size()); - for (RedisClusterNode node : nodes) { - if (node.servesSlot(slot)) { - slotServingNodes.add(node); - } - } - return slotServingNodes; + return getNodes().stream().filter(node -> node.servesSlot(slot)).collect(Collectors.toSet()); } /** @@ -134,14 +118,12 @@ public RedisClusterNode getKeyServingMasterNode(byte[] key) { int slot = ClusterSlotHashUtil.calculateSlot(key); - for (RedisClusterNode node : nodes) { - if (node.isMaster() && node.servesSlot(slot)) { - return node; - } - } - - throw new ClusterStateFailureException( - String.format("Could not find master node serving slot %s for key '%s',", slot, Arrays.toString(key))); + return getNodes().stream() + .filter(MASTER_NODE_PREDICATE.and(node -> node.servesSlot(slot))) + .findFirst() + .orElseThrow(() -> new ClusterStateFailureException( + String.format("Could not find master node serving slot %s for key '%s',", + slot, Arrays.toString(key)))); } /** @@ -154,14 +136,20 @@ public RedisClusterNode getKeyServingMasterNode(byte[] key) { */ public RedisClusterNode lookup(String host, int port) { - for (RedisClusterNode node : nodes) { - if (host.equals(node.getHost()) && (node.getPort() != null && port == node.getPort())) { - return node; - } - } + return getNodes().stream() + .filter(isNodeMatch(host, port)) + .findFirst() + .orElseThrow(() -> new ClusterStateFailureException( + String.format("Could not find node at %s:%s; Is your cluster info up to date", host, port))); + } - throw new ClusterStateFailureException( - String.format("Could not find node at %s:%s; Is your cluster info up to date", host, port)); + private Predicate isNodeMatch(String host, int port) { + return node -> Objects.equals(node.getHost(), host) && resolvePort(node) == port; + } + + private int resolvePort(RedisClusterNode node) { + Integer port = node.getPort(); + return port != null ? port : -1; } /** @@ -175,14 +163,11 @@ public RedisClusterNode lookup(String nodeId) { Assert.notNull(nodeId, "NodeId must not be null"); - for (RedisClusterNode node : nodes) { - if (nodeId.equals(node.getId())) { - return node; - } - } - - throw new ClusterStateFailureException( - String.format("Could not find node at %s; Is your cluster info up to date", nodeId)); + return getNodes().stream() + .filter(node -> nodeId.equals(node.getId())) + .findFirst() + .orElseThrow(() -> new ClusterStateFailureException( + String.format("Could not find node at %s; Is your cluster info up to date", nodeId))); } /** @@ -219,7 +204,8 @@ public RedisClusterNode lookup(RedisClusterNode node) { */ public Set getKeyServingNodes(byte[] key) { - Assert.notNull(key, "Key must not be null for Cluster Node lookup."); + Assert.notNull(key, "Key must not be null for Cluster Node lookup"); + return getSlotServingNodes(ClusterSlotHashUtil.calculateSlot(key)); } } diff --git a/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java b/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java index f99e786ce4..1e514be173 100644 --- a/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java +++ b/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java @@ -15,16 +15,34 @@ */ package org.springframework.data.redis.connection; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; -import static org.springframework.data.redis.test.util.MockitoUtils.*; - +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.springframework.data.redis.test.util.MockitoUtils.verifyInvocationsAcross; + +import java.time.Instant; import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; import java.util.HashSet; -import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Predicate; +import java.util.function.Supplier; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -45,14 +63,24 @@ import org.springframework.data.redis.connection.ClusterCommandExecutor.ClusterCommandCallback; import org.springframework.data.redis.connection.ClusterCommandExecutor.MultiKeyClusterCommandCallback; import org.springframework.data.redis.connection.ClusterCommandExecutor.MultiNodeResult; +import org.springframework.data.redis.connection.ClusterCommandExecutor.NodeExecution; +import org.springframework.data.redis.connection.ClusterCommandExecutor.NodeResult; import org.springframework.data.redis.connection.RedisClusterNode.LinkState; import org.springframework.data.redis.connection.RedisClusterNode.SlotRange; import org.springframework.data.redis.connection.RedisNode.NodeType; +import org.springframework.data.redis.test.util.MockitoUtils; import org.springframework.scheduling.concurrent.ConcurrentTaskExecutor; +import edu.umd.cs.mtc.MultithreadedTestCase; +import edu.umd.cs.mtc.TestFramework; + /** + * Unit Tests for {@link ClusterCommandExecutor}. + * * @author Christoph Strobl * @author Mark Paluch + * @author John Blum + * @since 1.7 */ @ExtendWith(MockitoExtension.class) class ClusterCommandExecutorUnitTests { @@ -66,17 +94,32 @@ class ClusterCommandExecutorUnitTests { private static final int CLUSTER_NODE_3_PORT = 7381; private static final RedisClusterNode CLUSTER_NODE_1 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT).serving(new SlotRange(0, 5460)) - .withId("ef570f86c7b1a953846668debc177a3a16733420").promotedAs(NodeType.MASTER).linkState(LinkState.CONNECTED) + .listeningAt(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT) + .serving(new SlotRange(0, 5460)) + .withId("ef570f86c7b1a953846668debc177a3a16733420") + .promotedAs(NodeType.MASTER) + .linkState(LinkState.CONNECTED) + .withName("ClusterNodeX") .build(); + private static final RedisClusterNode CLUSTER_NODE_2 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT).serving(new SlotRange(5461, 10922)) - .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84").promotedAs(NodeType.MASTER).linkState(LinkState.CONNECTED) + .listeningAt(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT) + .serving(new SlotRange(5461, 10922)) + .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84") + .promotedAs(NodeType.MASTER) + .linkState(LinkState.CONNECTED) + .withName("ClusterNodeY") .build(); + private static final RedisClusterNode CLUSTER_NODE_3 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT).serving(new SlotRange(10923, 16383)) - .withId("3b9b8192a874fa8f1f09dbc0ee20afab5738eee7").promotedAs(NodeType.MASTER).linkState(LinkState.CONNECTED) + .listeningAt(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT) + .serving(new SlotRange(10923, 16383)) + .withId("3b9b8192a874fa8f1f09dbc0ee20afab5738eee7") + .promotedAs(NodeType.MASTER) + .linkState(LinkState.CONNECTED) + .withName("ClusterNodeZ") .build(); + private static final RedisClusterNode CLUSTER_NODE_2_LOOKUP = RedisClusterNode.newRedisClusterNode() .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84").build(); @@ -88,8 +131,8 @@ class ClusterCommandExecutorUnitTests { private static final Converter exceptionConverter = source -> { - if (source instanceof MovedException) { - return new ClusterRedirectException(1000, ((MovedException) source).host, ((MovedException) source).port, source); + if (source instanceof MovedException movedException) { + return new ClusterRedirectException(1000, movedException.host, movedException.port, source); } return new InvalidDataAccessApiUsageException(source.getMessage(), source); @@ -97,14 +140,14 @@ class ClusterCommandExecutorUnitTests { private static final MultiKeyConnectionCommandCallback MULTIKEY_CALLBACK = Connection::bloodAndAshes; - @Mock Connection con1; - @Mock Connection con2; - @Mock Connection con3; + @Mock Connection connection1; + @Mock Connection connection2; + @Mock Connection connection3; @BeforeEach void setUp() { - this.executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), new MockClusterResourceProvider(), + this.executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ImmediateExecutor()); } @@ -118,7 +161,7 @@ void executeCommandOnSingleNodeShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, CLUSTER_NODE_2); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -127,7 +170,7 @@ void executeCommandOnSingleNodeByHostAndPortShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, new RedisClusterNode(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -135,15 +178,17 @@ void executeCommandOnSingleNodeByNodeIdShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, new RedisClusterNode(CLUSTER_NODE_2.id)); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 + @SuppressWarnings("all") void executeCommandOnSingleNodeShouldThrowExceptionWhenNodeIsNull() { assertThatIllegalArgumentException().isThrownBy(() -> executor.executeCommandOnSingleNode(COMMAND_CALLBACK, null)); } @Test // DATAREDIS-315 + @SuppressWarnings("all") void executeCommandOnSingleNodeShouldThrowExceptionWhenCommandCallbackIsNull() { assertThatIllegalArgumentException().isThrownBy(() -> executor.executeCommandOnSingleNode(null, CLUSTER_NODE_1)); } @@ -158,52 +203,52 @@ void executeCommandOnSingleNodeShouldThrowExceptionWhenNodeIsUnknown() { void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodes() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(CLUSTER_NODE_1, CLUSTER_NODE_2)); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, never()).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodesByHostAndPort() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(new RedisClusterNode(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT), new RedisClusterNode(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT))); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, never()).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodesByNodeId() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(new RedisClusterNode(CLUSTER_NODE_1.id), CLUSTER_NODE_2_LOOKUP)); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, never()).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldFailOnGivenUnknownNodes() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); assertThatIllegalArgumentException().isThrownBy(() -> executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, @@ -214,42 +259,42 @@ void executeCommandAsyncOnNodesShouldFailOnGivenUnknownNodes() { void executeCommandOnAllNodesShouldExecuteCommandOnEveryKnownClusterNode() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); executor.executeCommandOnAllNodes(COMMAND_CALLBACK); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldCompleteAndCollectErrorsOfAllNodes() { - when(con1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); - when(con2.theWheelWeavesAsTheWheelWills()).thenThrow(new IllegalStateException("(error) mat lost the dagger...")); - when(con3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); + when(connection1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); + when(connection2.theWheelWeavesAsTheWheelWills()).thenThrow(new IllegalStateException("(error) mat lost the dagger...")); + when(connection3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); try { executor.executeCommandOnAllNodes(COMMAND_CALLBACK); - } catch (ClusterCommandExecutionFailureException e) { + } catch (ClusterCommandExecutionFailureException cause) { - assertThat(e.getSuppressed()).hasSize(1); - assertThat(e.getSuppressed()[0]).isInstanceOf(DataAccessException.class); + assertThat(cause.getSuppressed()).hasSize(1); + assertThat(cause.getSuppressed()[0]).isInstanceOf(DataAccessException.class); } - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldCollectResultsCorrectly() { - when(con1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); - when(con2.theWheelWeavesAsTheWheelWills()).thenReturn("mat"); - when(con3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); + when(connection1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); + when(connection2.theWheelWeavesAsTheWheelWills()).thenReturn("mat"); + when(connection3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); MultiNodeResult result = executor.executeCommandOnAllNodes(COMMAND_CALLBACK); @@ -261,10 +306,10 @@ void executeMultikeyCommandShouldRunCommandAcrossCluster() { // key-1 and key-9 map both to node1 ArgumentCaptor captor = ArgumentCaptor.forClass(byte[].class); - when(con1.bloodAndAshes(captor.capture())).thenReturn("rand").thenReturn("egwene"); - when(con2.bloodAndAshes(any(byte[].class))).thenReturn("mat"); - when(con3.bloodAndAshes(any(byte[].class))).thenReturn("perrin"); + when(connection1.bloodAndAshes(captor.capture())).thenReturn("rand").thenReturn("egwene"); + when(connection2.bloodAndAshes(any(byte[].class))).thenReturn("mat"); + when(connection3.bloodAndAshes(any(byte[].class))).thenReturn("perrin"); MultiNodeResult result = executor.executeMultiKeyCommand(MULTIKEY_CALLBACK, new HashSet<>( @@ -279,21 +324,21 @@ void executeMultikeyCommandShouldRunCommandAcrossCluster() { @Test // DATAREDIS-315 void executeCommandOnSingleNodeAndFollowRedirect() { - when(con1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); + when(connection1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); executor.executeCommandOnSingleNode(COMMAND_CALLBACK, CLUSTER_NODE_1); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, never()).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandOnSingleNodeAndFollowRedirectButStopsAfterMaxRedirects() { - when(con1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); - when(con3.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); - when(con2.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT)); + when(connection1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); + when(connection3.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); + when(connection2.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT)); try { executor.setMaxRedirects(4); @@ -302,9 +347,9 @@ void executeCommandOnSingleNodeAndFollowRedirectButStopsAfterMaxRedirects() { assertThat(e).isInstanceOf(TooManyClusterRedirectionsException.class); } - verify(con1, times(2)).theWheelWeavesAsTheWheelWills(); - verify(con3, times(2)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(2)).theWheelWeavesAsTheWheelWills(); + verify(connection3, times(2)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -312,53 +357,178 @@ void executeCommandOnArbitraryNodeShouldPickARandomNode() { executor.executeCommandOnArbitraryNode(COMMAND_CALLBACK); - verifyInvocationsAcross("theWheelWeavesAsTheWheelWills", times(1), con1, con2, con3); + verifyInvocationsAcross("theWheelWeavesAsTheWheelWills", times(1), connection1, connection2, connection3); + } + + @Test // GH-2518 + void collectResultsCompletesSuccessfully() { + + Instant done = Instant.now().plusMillis(5); + + Predicate>> isDone = future -> Instant.now().isAfter(done); + + Map>> futures = new HashMap<>(); + + NodeResult nodeOneA = newNodeResult(CLUSTER_NODE_1, "A"); + NodeResult nodeTwoB = newNodeResult(CLUSTER_NODE_2, "B"); + NodeResult nodeThreeC = newNodeResult(CLUSTER_NODE_3, "C"); + + futures.put(newNodeExecution(CLUSTER_NODE_1), mockFutureAndIsDone(nodeOneA, isDone)); + futures.put(newNodeExecution(CLUSTER_NODE_2), mockFutureAndIsDone(nodeTwoB, isDone)); + futures.put(newNodeExecution(CLUSTER_NODE_3), mockFutureAndIsDone(nodeThreeC, isDone)); + + MultiNodeResult results = this.executor.collectResults(futures); + + assertThat(results).isNotNull(); + assertThat(results.getResults()).containsExactlyInAnyOrder(nodeOneA, nodeTwoB, nodeThreeC); + } + + @Test // GH-2518 + void collectResultsFailsWithExecutionException() { + + Map>> futures = new HashMap<>(); + + NodeResult nodeOneA = newNodeResult(CLUSTER_NODE_1, "A"); + + futures.put(newNodeExecution(CLUSTER_NODE_1), mockFutureAndIsDone(nodeOneA, future -> true)); + futures.put(newNodeExecution(CLUSTER_NODE_2), mockFutureThrowingExecutionException( + new ExecutionException("TestError", new IllegalArgumentException("MockError")))); + + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> this.executor.collectResults(futures)) + .withMessage("MockError") + .withCauseInstanceOf(InvalidDataAccessApiUsageException.class) + .extracting(Throwable::getCause) + .extracting(Throwable::getCause) + .isInstanceOf(IllegalArgumentException.class) + .extracting(Throwable::getMessage) + .isEqualTo("MockError"); + } + + @Test // GH-2518 + void collectResultsFailsWithInterruptedException() throws Throwable { + TestFramework.runOnce(new CollectResultsInterruptedMultithreadedTestCase(this.executor)); + } + + // Future.get() for X will get called twice if at least one other Future is not done and Future.get() for X + // threw an ExecutionException in the previous iteration, thereby marking it as done! + @Test // GH-2518 + @SuppressWarnings("all") + void collectResultsCallsFutureGetOnlyOnce() throws Exception { + + AtomicInteger count = new AtomicInteger(0); + Map>> futures = new HashMap<>(); + + Future> clusterNodeOneFutureResult = mockFutureAndIsDone(null, future -> + count.incrementAndGet() % 2 == 0); + + Future> clusterNodeTwoFutureResult = mockFutureThrowingExecutionException( + new ExecutionException("TestError", new IllegalArgumentException("MockError"))); + + futures.put(newNodeExecution(CLUSTER_NODE_1), clusterNodeOneFutureResult); + futures.put(newNodeExecution(CLUSTER_NODE_2), clusterNodeTwoFutureResult); + + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> this.executor.collectResults(futures)); + + verify(clusterNodeOneFutureResult, times(1)).get(); + verify(clusterNodeTwoFutureResult, times(1)).get(); + } + + // Covers the case where Future.get() is mistakenly called multiple times, or if the Future.isDone() implementation + // does not properly take into account Future.get() throwing an ExecutionException during computation subsequently + // returning false instead of true. + // This should be properly handled by the "safeguard" (see collectResultsCallsFutureGetOnlyOnce()), but... + // just in case! The ExecutionException handler now stores the [DataAccess]Exception with Map.putIfAbsent(..). + @Test // GH-2518 + @SuppressWarnings("all") + void collectResultsCapturesFirstExecutionExceptionOnly() { + + AtomicInteger count = new AtomicInteger(0); + AtomicInteger exceptionCount = new AtomicInteger(0); + + Map>> futures = new HashMap<>(); + + futures.put(newNodeExecution(CLUSTER_NODE_1), + mockFutureAndIsDone(null, future -> count.incrementAndGet() % 2 == 0)); + + futures.put(newNodeExecution(CLUSTER_NODE_2), mockFutureThrowingExecutionException(() -> + new ExecutionException("TestError", new IllegalStateException("MockError" + exceptionCount.getAndIncrement())))); + + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> this.executor.collectResults(futures)) + .withMessage("MockError0") + .withCauseInstanceOf(InvalidDataAccessApiUsageException.class) + .extracting(Throwable::getCause) + .extracting(Throwable::getCause) + .isInstanceOf(IllegalStateException.class) + .extracting(Throwable::getMessage) + .isEqualTo("MockError0"); + } + + private Future mockFutureAndIsDone(T result, Predicate> isDone) { + + return MockitoUtils.mockFuture(result, future -> { + doAnswer(invocation -> isDone.test(future)).when(future).isDone(); + return future; + }); } - class MockClusterNodeProvider implements ClusterTopologyProvider { + private Future mockFutureThrowingExecutionException(ExecutionException exception) { + return mockFutureThrowingExecutionException(() -> exception); + } + + private Future mockFutureThrowingExecutionException(Supplier exceptionSupplier) { + + return MockitoUtils.mockFuture(null, future -> { + doReturn(true).when(future).isDone(); + doAnswer(invocationOnMock -> { throw exceptionSupplier.get(); }).when(future).get(); + return future; + }); + } + + private NodeExecution newNodeExecution(RedisClusterNode clusterNode) { + return new NodeExecution(clusterNode); + } + + private NodeResult newNodeResult(RedisClusterNode clusterNode, T value) { + return new NodeResult<>(clusterNode, value); + } + + static class MockClusterNodeProvider implements ClusterTopologyProvider { @Override public ClusterTopology getTopology() { - return new ClusterTopology( - new LinkedHashSet<>(Arrays.asList(CLUSTER_NODE_1, CLUSTER_NODE_2, CLUSTER_NODE_3))); + return new ClusterTopology(Set.of(CLUSTER_NODE_1, CLUSTER_NODE_2, CLUSTER_NODE_3)); } - } - class MockClusterResourceProvider implements ClusterNodeResourceProvider { + class MockClusterNodeResourceProvider implements ClusterNodeResourceProvider { @Override - public Connection getResourceForSpecificNode(RedisClusterNode node) { + @SuppressWarnings("all") + public Connection getResourceForSpecificNode(RedisClusterNode clusterNode) { - if (CLUSTER_NODE_1.equals(node)) { - return con1; - } - if (CLUSTER_NODE_2.equals(node)) { - return con2; - } - if (CLUSTER_NODE_3.equals(node)) { - return con3; - } - - return null; + return CLUSTER_NODE_1.equals(clusterNode) ? connection1 + : CLUSTER_NODE_2.equals(clusterNode) ? connection2 + : CLUSTER_NODE_3.equals(clusterNode) ? connection3 + : null; } @Override public void returnResourceForSpecificNode(RedisClusterNode node, Object resource) { - // TODO Auto-generated method stub } - } - static interface ConnectionCommandCallback extends ClusterCommandCallback { + interface ConnectionCommandCallback extends ClusterCommandCallback { } - static interface MultiKeyConnectionCommandCallback extends MultiKeyClusterCommandCallback { + interface MultiKeyConnectionCommandCallback extends MultiKeyClusterCommandCallback { } - static interface Connection { + interface Connection { String theWheelWeavesAsTheWheelWills(); @@ -374,19 +544,20 @@ static class MovedException extends RuntimeException { this.host = host; this.port = port; } - } static class ImmediateExecutor implements AsyncTaskExecutor { @Override - public void execute(Runnable runnable, long l) { + public void execute(Runnable runnable) { runnable.run(); } @Override public Future submit(Runnable runnable) { + return submit(() -> { + runnable.run(); return null; @@ -395,19 +566,104 @@ public Future submit(Runnable runnable) { @Override public Future submit(Callable callable) { + try { return CompletableFuture.completedFuture(callable.call()); - } catch (Exception e) { + } catch (Exception cause) { CompletableFuture future = new CompletableFuture<>(); - future.completeExceptionally(e); + + future.completeExceptionally(cause); + return future; } } + } + + @SuppressWarnings("all") + private static class CollectResultsInterruptedMultithreadedTestCase extends MultithreadedTestCase { + + private static final CountDownLatch latch = new CountDownLatch(1); + + private static final Comparator NODE_COMPARATOR = + Comparator.comparing(nodeExecution -> nodeExecution.getNode().getName()); + + private final ClusterCommandExecutor clusterCommandExecutor; + + private final Map>> futureNodeResults; + + private Future> mockNodeOneFutureResult; + private Future> mockNodeTwoFutureResult; + + private volatile Thread collectResultsThread; + + private CollectResultsInterruptedMultithreadedTestCase(ClusterCommandExecutor clusterCommandExecutor) { + this.clusterCommandExecutor = clusterCommandExecutor; + this.futureNodeResults = new ConcurrentSkipListMap<>(NODE_COMPARATOR); + } @Override - public void execute(Runnable runnable) { - runnable.run(); + public void initialize() { + + super.initialize(); + + this.mockNodeOneFutureResult = this.futureNodeResults.computeIfAbsent(new NodeExecution(CLUSTER_NODE_1), + nodeExecution -> MockitoUtils.mockFuture(null, mockFuture -> { + doReturn(false).when(mockFuture).isDone(); + return mockFuture; + })); + + this.mockNodeTwoFutureResult = this.futureNodeResults.computeIfAbsent(new NodeExecution(CLUSTER_NODE_2), + nodeExecution -> MockitoUtils.mockFuture(null, mockFuture -> { + + doReturn(true).when(mockFuture).isDone(); + + doAnswer(invocation -> { + latch.await(); + return null; + }).when(mockFuture).get(); + + return mockFuture; + })); + } + + public void thread1() { + + assertTick(0); + + this.collectResultsThread = Thread.currentThread(); + this.collectResultsThread.setName("CollectResults Thread"); + + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> this.clusterCommandExecutor.collectResults(this.futureNodeResults)); + + assertThat(this.collectResultsThread.isInterrupted()).isTrue(); + } + + public void thread2() { + + assertTick(0); + + Thread.currentThread().setName("Interrupting Thread"); + + waitForTick(1); + + assertThat(this.collectResultsThread).isNotNull(); + assertThat(this.collectResultsThread.getName()).isEqualTo("CollectResults Thread"); + + this.collectResultsThread.interrupt(); + } + + @Override + public void finish() { + + try { + verify(this.mockNodeOneFutureResult, times(1)).isDone(); + verify(this.mockNodeOneFutureResult, never()).get(); + verify(this.mockNodeTwoFutureResult, times(1)).isDone(); + verify(this.mockNodeTwoFutureResult, times(1)).get(); + } catch (Throwable ignore) { + } } } } diff --git a/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java b/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java index e42f0749d8..318d47ec85 100644 --- a/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java +++ b/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java @@ -15,34 +15,131 @@ */ package org.springframework.data.redis.test.util; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.anyBoolean; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockingDetails; +import static org.mockito.Mockito.withSettings; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; import org.mockito.internal.invocation.InvocationMatcher; import org.mockito.internal.verification.api.VerificationData; import org.mockito.invocation.Invocation; import org.mockito.invocation.MatchableInvocation; +import org.mockito.quality.Strictness; +import org.mockito.stubbing.Answer; import org.mockito.verification.VerificationMode; + +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; /** + * Utilities for using {@literal Mockito} and creating {@link Object mock objects} in {@literal unit tests}. + * * @author Christoph Strobl + * @author John Blum + * @see org.mockito.Mockito * @since 1.7 */ -public class MockitoUtils { +@SuppressWarnings("unused") +public abstract class MockitoUtils { + + /** + * Creates a mock {@link Future} returning the given {@link Object result}. + * + * @param {@link Class type} of {@link Object result} returned by the mock {@link Future}. + * @param result {@link Object value} returned as the {@literal result} of the mock {@link Future}. + * @return a new mock {@link Future}. + * @see java.util.concurrent.Future + */ + @SuppressWarnings("unchecked") + public static @NonNull Future mockFuture(@Nullable T result) { + + try { + + AtomicBoolean cancelled = new AtomicBoolean(false); + AtomicBoolean done = new AtomicBoolean(false); + + Future mockFuture = mock(Future.class, withSettings().strictness(Strictness.LENIENT)); + + // A Future can only be cancelled if not done, it was not already cancelled, and no error occurred. + // The cancel(..) logic is not Thread-safe due to compound actions involving multiple variables. + // However, the cancel(..) logic does not necessarily need to be Thread-safe given the task execution + // of a Future is asynchronous and cancellation is driven by Thread interrupt from another Thread. + Answer cancelAnswer = invocation -> !done.get() + && cancelled.compareAndSet(done.get(), true) + && done.compareAndSet(done.get(), true); + + Answer getAnswer = invocation -> { + + // The Future is done no matter if it returns the result or was cancelled/interrupted. + done.set(true); + + if (Thread.currentThread().isInterrupted()) { + throw new InterruptedException("Thread was interrupted"); + } + + if (cancelled.get()) { + throw new CancellationException("Task was cancelled"); + } + + return result; + }; + + doAnswer(invocation -> cancelled.get()).when(mockFuture).isCancelled(); + doAnswer(invocation -> done.get()).when(mockFuture).isDone(); + doAnswer(cancelAnswer).when(mockFuture).cancel(anyBoolean()); + doAnswer(getAnswer).when(mockFuture).get(); + doAnswer(getAnswer).when(mockFuture).get(anyLong(), isA(TimeUnit.class)); + + return mockFuture; + } + catch (Exception cause) { + String message = String.format("Failed to create a mock of Future having result [%s]", result); + throw new IllegalStateException(message, cause); + } + } + + /** + * Creates a mock {@link Future} returning the given {@link Object result}, customized with the given, + * required {@link Function}. + * + * @param {@link Class type} of {@link Object result} returned by the mock {@link Future}. + * @param result {@link Object value} returned as the {@literal result} of the mock {@link Future}. + * @param futureFunction {@link Function} used to customize the mock {@link Future} on creation; + * must not be {@literal null}. + * @return a new mock {@link Future}. + * @see java.util.concurrent.Future + * @see java.util.function.Function + * @see #mockFuture(Object) + */ + public static @NonNull Future mockFuture(@Nullable T result, + @NonNull ThrowableFunction, Future> futureFunction) { + + Future mockFuture = mockFuture(result); + + return futureFunction.apply(mockFuture); + } /** * Verifies a given method is called a total number of times across all given mocks. * - * @param method - * @param mode - * @param mocks + * @param method {@link String name} of a {@link java.lang.reflect.Method} on the {@link Object mock object}. + * @param mode mode of verification used by {@literal Mockito} to verify invocations on {@link Object mock objects}. + * @param mocks array of {@link Object mock objects} to verify. */ - @SuppressWarnings({ "rawtypes", "serial" }) - public static void verifyInvocationsAcross(final String method, final VerificationMode mode, Object... mocks) { + public static void verifyInvocationsAcross(String method, VerificationMode mode, Object... mocks) { mode.verify(new VerificationDataImpl(getInvocations(method, mocks), new InvocationMatcher(null, Collections .singletonList(org.mockito.internal.matchers.Any.ANY)) { @@ -56,17 +153,15 @@ public boolean matches(Invocation actual) { public String toString() { return String.format("%s for method: %s", mode, method); } - })); } private static List getInvocations(String method, Object... mocks) { List invocations = new ArrayList<>(); - for (Object mock : mocks) { + for (Object mock : mocks) { if (StringUtils.hasText(method)) { - for (Invocation invocation : mockingDetails(mock).getInvocations()) { if (invocation.getMethod().getName().equals(method)) { invocations.add(invocation); @@ -76,6 +171,7 @@ private static List getInvocations(String method, Object... mocks) { invocations.addAll(mockingDetails(mock).getInvocations()); } } + return invocations; } @@ -98,7 +194,31 @@ public List getAllInvocations() { public MatchableInvocation getTarget() { return wanted; } - } + @FunctionalInterface + public interface ThrowableFunction extends Function { + + @Override + default R apply(T target) { + + try { + return applyThrowingException(target); + } + catch (Throwable cause) { + String message = String.format("Failed to apply Function [%s] to target [%s]", this, target); + throw new IllegalStateException(message, cause); + } + } + + R applyThrowingException(T target) throws Throwable; + + @SuppressWarnings("unchecked") + default @NonNull ThrowableFunction andThen( + @Nullable ThrowableFunction after) { + + return after == null ? (ThrowableFunction) this + : target -> after.apply(this.apply(target)); + } + } }