From e70d4b452701b28d2b6d7f13d5f242498101db91 Mon Sep 17 00:00:00 2001 From: John Blum Date: Wed, 8 Mar 2023 15:34:31 -0800 Subject: [PATCH] Remove Thread.sleep(..) from collectResults(..). Replace Thread.sleep(..) with Future.get(timeout, :TimeUnit) for 10 microseconds. As a result, Future.isDone() and Future.isCancelled() are no longer necessary. Simply try to get the results within 10 us, and if a TimeoutException is thrown, then set done to false. 10 microseconds is 1/1000 of 10 milliseconds. This means a Redis cluster with 1000 nodes will run in a similar time to Thread.sleep(10L) if all Futures are blocked waiting for the compuptation to complete and take an equal amount of time to compute the result, which is rarely the case in practice, given different hardware configurations, data access patterns, load balancing/requst routing etc. However, using Future.get(timeout, :TimeUnit) is more fair than Future.get(), which blocks until a result is returned or an ExecutionException is thrown, thereby starving computationally faster nodes vs. other nodes in the cluster that might be overloaded. In the meantime, some nodes may even complete in the short amount of time while waiting for other nodes to complete. 10 microseconds was partially arbitrary, but not anymore so than Thread.sleep(10L) (10 millisconds). The objective was to give each node a chance to complete the computation in a moments notice balanced with the need to quickly check if the computation is "done", hence Future.get(timeout, TimeUnit.MICROSECONDS), or sub-millisecond response times. This may need to be futher tuned over time, but should serve as a reasonable baseline for the time being. This was also based on https://redis.io/docs/reference/cluster-spec/#overview-of-redis-cluster-main-components in the Redis documentation, "recommending" a cluster size of no more than 1000 nodes. One optimimzation might be to reorder the Map of Futures at the end of each iteration organizing Futures that are done first. Furthermore, Futures that have already completed could be removed from the Map. Of course, there is little harm in keeping the future in the Map with the "safeguard" in place. This optimization was not included in the changes simply because the optimization is most likely neglible and should be measured. Reconstructing a TreeMap should run mostly within log(n) time, but memory consumption should also be taken into consideration. Add test coverage for ClusterCommandExecutor collectResults(..) method. Cleanup compiler warnings in ClusterCommandExecutorUnitTests. Closes #2518 --- .../connection/ClusterCommandExecutor.java | 283 ++++++---- .../redis/connection/ClusterTopology.java | 104 ++-- .../ClusterCommandExecutorUnitTests.java | 513 ++++++++++++++---- .../data/redis/test/util/MockitoUtils.java | 142 ++++- 4 files changed, 765 insertions(+), 277 deletions(-) diff --git a/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java b/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java index f016dfc33f..d495ade564 100644 --- a/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java +++ b/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java @@ -15,10 +15,27 @@ */ package org.springframework.data.redis.connection; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.Map.Entry; +import java.util.Objects; +import java.util.Random; +import java.util.Set; +import java.util.TreeMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.function.BiConsumer; import java.util.function.Function; import java.util.stream.Collectors; @@ -43,6 +60,7 @@ * * @author Christoph Strobl * @author Mark Paluch + * @author John Blum * @since 1.7 */ public class ClusterCommandExecutor implements DisposableBean { @@ -58,7 +76,7 @@ public class ClusterCommandExecutor implements DisposableBean { private final ExceptionTranslationStrategy exceptionTranslationStrategy; /** - * Create a new instance of {@link ClusterCommandExecutor}. + * Create a new {@link ClusterCommandExecutor}. * * @param topologyProvider must not be {@literal null}. * @param resourceProvider must not be {@literal null}. @@ -92,40 +110,47 @@ public ClusterCommandExecutor(ClusterTopologyProvider topologyProvider, ClusterN /** * Run {@link ClusterCommandCallback} on a random node. * - * @param commandCallback must not be {@literal null}. + * @param clusterCommand must not be {@literal null}. * @return never {@literal null}. */ - public NodeResult executeCommandOnArbitraryNode(ClusterCommandCallback commandCallback) { + public NodeResult executeCommandOnArbitraryNode(ClusterCommandCallback clusterCommand) { - Assert.notNull(commandCallback, "ClusterCommandCallback must not be null"); + Assert.notNull(clusterCommand, "ClusterCommandCallback must not be null"); List nodes = new ArrayList<>(getClusterTopology().getActiveNodes()); - return executeCommandOnSingleNode(commandCallback, nodes.get(new Random().nextInt(nodes.size()))); + RedisClusterNode arbitraryNode = nodes.get(new Random().nextInt(nodes.size())); + + return executeCommandOnSingleNode(clusterCommand, arbitraryNode); } /** * Run {@link ClusterCommandCallback} on given {@link RedisClusterNode}. * - * @param cmd must not be {@literal null}. + * @param clusterCommand must not be {@literal null}. * @param node must not be {@literal null}. * @return the {@link NodeResult} from the single, targeted {@link RedisClusterNode}. * @throws IllegalArgumentException in case no resource can be acquired for given node. */ - public NodeResult executeCommandOnSingleNode(ClusterCommandCallback cmd, RedisClusterNode node) { - return executeCommandOnSingleNode(cmd, node, 0); + public NodeResult executeCommandOnSingleNode(ClusterCommandCallback clusterCommand, + RedisClusterNode node) { + + return executeCommandOnSingleNode(clusterCommand, node, 0); } - private NodeResult executeCommandOnSingleNode(ClusterCommandCallback cmd, RedisClusterNode node, - int redirectCount) { + private NodeResult executeCommandOnSingleNode(ClusterCommandCallback clusterCommand, + RedisClusterNode node, int redirectCount) { - Assert.notNull(cmd, "ClusterCommandCallback must not be null"); + Assert.notNull(clusterCommand, "ClusterCommandCallback must not be null"); Assert.notNull(node, "RedisClusterNode must not be null"); - if (redirectCount > maxRedirects) { - throw new TooManyClusterRedirectionsException(String.format( - "Cannot follow Cluster Redirects over more than %s legs; Please consider increasing the number of redirects to follow; Current value is: %s.", - redirectCount, maxRedirects)); + if (redirectCount > this.maxRedirects) { + + String message = String.format("Cannot follow Cluster Redirects over more than %s legs;" + + " Please consider increasing the number of redirects to follow; Current value is: %s.", + redirectCount, this.maxRedirects); + + throw new TooManyClusterRedirectionsException(message); } RedisClusterNode nodeToUse = lookupNode(node); @@ -135,14 +160,15 @@ private NodeResult executeCommandOnSingleNode(ClusterCommandCallback(node, cmd.doInCluster(client)); + return new NodeResult<>(node, clusterCommand.doInCluster(client)); } catch (RuntimeException cause) { RuntimeException translatedException = convertToDataAccessException(cause); if (translatedException instanceof ClusterRedirectException clusterRedirectException) { - return executeCommandOnSingleNode(cmd, topologyProvider.getTopology().lookup( - clusterRedirectException.getTargetHost(), clusterRedirectException.getTargetPort()), redirectCount + 1); + return executeCommandOnSingleNode(clusterCommand, topologyProvider.getTopology() + .lookup(clusterRedirectException.getTargetHost(), clusterRedirectException.getTargetPort()), + redirectCount + 1); } else { throw translatedException != null ? translatedException : cause; } @@ -152,10 +178,10 @@ private NodeResult executeCommandOnSingleNode(ClusterCommandCallback MultiNodeResult executeCommandOnAllNodes(final ClusterCommandCallback cmd) { - return executeCommandAsyncOnNodes(cmd, getClusterTopology().getActiveMasterNodes()); + public MultiNodeResult executeCommandOnAllNodes(ClusterCommandCallback clusterCommand) { + return executeCommandAsyncOnNodes(clusterCommand, getClusterTopology().getActiveMasterNodes()); } /** - * @param callback must not be {@literal null}. + * @param clusterCommand must not be {@literal null}. * @param nodes must not be {@literal null}. * @return never {@literal null}. * @throws ClusterCommandExecutionFailureException if a failure occurs while executing the given * {@link ClusterCommandCallback command} on any given {@link RedisClusterNode node}. * @throws IllegalArgumentException in case the node could not be resolved to a topology-known node */ - public MultiNodeResult executeCommandAsyncOnNodes(ClusterCommandCallback callback, + public MultiNodeResult executeCommandAsyncOnNodes(ClusterCommandCallback clusterCommand, Iterable nodes) { - Assert.notNull(callback, "Callback must not be null"); + Assert.notNull(clusterCommand, "Callback must not be null"); Assert.notNull(nodes, "Nodes must not be null"); + ClusterTopology topology = this.topologyProvider.getTopology(); List resolvedRedisClusterNodes = new ArrayList<>(); - ClusterTopology topology = topologyProvider.getTopology(); for (RedisClusterNode node : nodes) { try { resolvedRedisClusterNodes.add(topology.lookup(node)); - } catch (ClusterStateFailureException e) { - throw new IllegalArgumentException(String.format("Node %s is unknown to cluster", node), e); + } catch (ClusterStateFailureException cause) { + throw new IllegalArgumentException(String.format("Node %s is unknown to cluster", node), cause); } } Map>> futures = new LinkedHashMap<>(); for (RedisClusterNode node : resolvedRedisClusterNodes) { - futures.put(new NodeExecution(node), executor.submit(() -> executeCommandOnSingleNode(callback, node))); + futures.put(new NodeExecution(node), executor.submit(() -> executeCommandOnSingleNode(clusterCommand, node))); } return collectResults(futures); } - private MultiNodeResult collectResults(Map>> futures) { - - boolean done = false; + MultiNodeResult collectResults(Map>> futures) { Map exceptions = new HashMap<>(); MultiNodeResult result = new MultiNodeResult<>(); - Set saveGuard = new HashSet<>(); - - while (!done) { - - done = true; - - for (Map.Entry>> entry : futures.entrySet()) { + Set safeguard = new HashSet<>(); - if (!entry.getValue().isDone() && !entry.getValue().isCancelled()) { - done = false; - } else { - - NodeExecution execution = entry.getKey(); + BiConsumer exceptionHandler = (execution, throwable) -> { - try { + RuntimeException dataAccessException = convertToDataAccessException((Exception) throwable); - String futureId = ObjectUtils.getIdentityHexString(entry.getValue()); + exceptions.putIfAbsent(execution.getNode(), dataAccessException != null ? dataAccessException : throwable); + }; - if (!saveGuard.contains(futureId)) { + boolean done = false; - if (execution.isPositional()) { - result.add(execution.getPositionalKey(), entry.getValue().get()); - } else { - result.add(entry.getValue().get()); - } + while (!done) { - saveGuard.add(futureId); - } - } catch (ExecutionException cause) { + done = true; - RuntimeException exception = convertToDataAccessException((Exception) cause.getCause()); + for (Map.Entry>> entry : futures.entrySet()) { - exceptions.put(execution.getNode(), exception != null ? exception : cause.getCause()); - } catch (InterruptedException cause) { + NodeExecution nodeExecution = entry.getKey(); + Future> futureNodeResult = entry.getValue(); + String futureId = ObjectUtils.getIdentityHexString(futureNodeResult); - Thread.currentThread().interrupt(); + try { + if (!safeguard.contains(futureId)) { - RuntimeException exception = convertToDataAccessException((Exception) cause.getCause()); + NodeResult nodeResult = futureNodeResult.get(10L, TimeUnit.MICROSECONDS); - exceptions.put(execution.getNode(), exception != null ? exception : cause.getCause()); + if (nodeExecution.isPositional()) { + result.add(nodeExecution.getPositionalKey(), nodeResult); + } else { + result.add(nodeResult); + } - break; + safeguard.add(futureId); } + } catch (ExecutionException exception) { + safeguard.add(futureId); + exceptionHandler.accept(nodeExecution, exception.getCause()); + } catch (TimeoutException ignore) { + done = false; + } catch (InterruptedException cause) { + Thread.currentThread().interrupt(); + exceptionHandler.accept(nodeExecution, cause); + break; } } - - try { - Thread.sleep(10); - } catch (InterruptedException e) { - done = true; - Thread.currentThread().interrupt(); - } } if (!exceptions.isEmpty()) { @@ -306,8 +323,8 @@ public MultiNodeResult executeMultiKeyCommand(MultiKeyClusterCommandCa if (entry.getKey().isMaster()) { for (PositionalKey key : entry.getValue()) { - futures.put(new NodeExecution(entry.getKey(), key), this.executor - .submit(() -> executeMultiKeyCommandOnSingleNode(commandCallback, entry.getKey(), key.getBytes()))); + futures.put(new NodeExecution(entry.getKey(), key), this.executor.submit(() -> + executeMultiKeyCommandOnSingleNode(commandCallback, entry.getKey(), key.getBytes()))); } } } @@ -328,10 +345,11 @@ private NodeResult executeMultiKeyCommandOnSingleNode(MultiKeyClusterC try { return new NodeResult<>(node, commandCallback.doInCluster(client, key), key); - } catch (RuntimeException ex) { + } catch (RuntimeException cause) { + + RuntimeException translatedException = convertToDataAccessException(cause); - RuntimeException translatedException = convertToDataAccessException(ex); - throw translatedException != null ? translatedException : ex; + throw translatedException != null ? translatedException : cause; } finally { this.resourceProvider.returnResourceForSpecificNode(node, client); } @@ -343,7 +361,7 @@ private ClusterTopology getClusterTopology() { @Nullable private DataAccessException convertToDataAccessException(Exception cause) { - return exceptionTranslationStrategy.translate(cause); + return this.exceptionTranslationStrategy.translate(cause); } /** @@ -395,7 +413,7 @@ public interface MultiKeyClusterCommandCallback { * @author Mark Paluch * @since 1.7 */ - private static class NodeExecution { + static class NodeExecution { private final RedisClusterNode node; private final @Nullable PositionalKey positionalKey; @@ -414,7 +432,7 @@ private static class NodeExecution { * Get the {@link RedisClusterNode} the execution happens on. */ RedisClusterNode getNode() { - return node; + return this.node; } /** @@ -423,30 +441,31 @@ RedisClusterNode getNode() { * @since 2.0.3 */ PositionalKey getPositionalKey() { - return positionalKey; + return this.positionalKey; } boolean isPositional() { - return positionalKey != null; + return this.positionalKey != null; } } /** - * {@link NodeResult} encapsulates the actual value returned by a {@link ClusterCommandCallback} on a given - * {@link RedisClusterNode}. + * {@link NodeResult} encapsulates the actual {@link T value} returned by a {@link ClusterCommandCallback} + * on a given {@link RedisClusterNode}. * + * @param {@link Class Type} of the {@link Object value} returned in the result. * @author Christoph Strobl - * @param + * @author John Blum * @since 1.7 */ public static class NodeResult { private RedisClusterNode node; - private @Nullable T value; private ByteArrayWrapper key; + private @Nullable T value; /** - * Create new {@link NodeResult}. + * Create a new {@link NodeResult}. * * @param node must not be {@literal null}. * @param value can be {@literal null}. @@ -456,7 +475,7 @@ public NodeResult(RedisClusterNode node, @Nullable T value) { } /** - * Create new {@link NodeResult}. + * Create a new {@link NodeResult}. * * @param node must not be {@literal null}. * @param value can be {@literal null}. @@ -465,37 +484,36 @@ public NodeResult(RedisClusterNode node, @Nullable T value) { public NodeResult(RedisClusterNode node, @Nullable T value, byte[] key) { this.node = node; - this.value = value; - this.key = new ByteArrayWrapper(key); + this.value = value; } /** - * Get the actual value of the command execution. + * Get the {@link RedisClusterNode} the command was executed on. * - * @return can be {@literal null}. + * @return never {@literal null}. */ - @Nullable - public T getValue() { - return value; + public RedisClusterNode getNode() { + return this.node; } /** - * Get the {@link RedisClusterNode} the command was executed on. + * Return the {@link byte[] key} mapped to the value stored in Redis. * - * @return never {@literal null}. + * @return a {@link byte[] byte array} of the key mapped to the value stored in Redis. */ - public RedisClusterNode getNode() { - return node; + public byte[] getKey() { + return this.key.getArray(); } /** - * Returns the key as an array of bytes. + * Get the actual value of the command execution. * - * @return the key as an array of bytes. + * @return can be {@literal null}. */ - public byte[] getKey() { - return key.getArray(); + @Nullable + public T getValue() { + return this.value; } /** @@ -513,6 +531,34 @@ public U mapValue(Function mapper) { return mapper.apply(getValue()); } + + @Override + public boolean equals(@Nullable Object obj) { + + if (obj == this) { + return true; + } + + if (!(obj instanceof NodeResult that)) { + return false; + } + + return ObjectUtils.nullSafeEquals(this.getNode(), that.getNode()) + && Objects.equals(this.key, that.key) + && Objects.equals(this.getValue(), that.getValue()); + } + + @Override + public int hashCode() { + + int hashValue = 17; + + hashValue = 37 * hashValue + ObjectUtils.nullSafeHashCode(getNode()); + hashValue = 37 * hashValue + ObjectUtils.nullSafeHashCode(this.key); + hashValue = 37 * hashValue + ObjectUtils.nullSafeHashCode(getValue()); + + return hashValue; + } } /** @@ -566,12 +612,14 @@ public List resultsAsListSortBy(byte[]... keys) { if (positionalResults.isEmpty()) { List> clone = new ArrayList<>(nodeResults); + clone.sort(new ResultByReferenceKeyPositionComparator(keys)); return toList(clone); } Map> result = new TreeMap<>(new ResultByKeyPositionComparator(keys)); + result.putAll(positionalResults); return result.values().stream().map(tNodeResult -> tNodeResult.value).collect(Collectors.toList()); @@ -605,9 +653,11 @@ public T getFirstNonNullNotEmptyOrDefault(@Nullable T returnValue) { private List toList(Collection> source) { ArrayList result = new ArrayList<>(); + for (NodeResult nodeResult : source) { result.add(nodeResult.getValue()); } + return result; } @@ -678,7 +728,7 @@ static PositionalKey of(byte[] key, int index) { * @return binary key. */ byte[] getBytes() { - return key.getArray(); + return getKey().getArray(); } public ByteArrayWrapper getKey() { @@ -690,23 +740,23 @@ public int getPosition() { } @Override - public boolean equals(@Nullable Object o) { - if (this == o) - return true; - if (o == null || getClass() != o.getClass()) - return false; + public boolean equals(@Nullable Object obj) { - PositionalKey that = (PositionalKey) o; + if (this == obj) { + return true; + } - if (position != that.position) + if (!(obj instanceof PositionalKey that)) return false; - return ObjectUtils.nullSafeEquals(key, that.key); + + return this.getPosition() == that.getPosition() + && ObjectUtils.nullSafeEquals(this.getKey(), that.getKey()); } @Override public int hashCode() { - int result = ObjectUtils.nullSafeHashCode(key); - result = 31 * result + position; + int result = ObjectUtils.nullSafeHashCode(getKey()); + result = 31 * result + ObjectUtils.nullSafeHashCode(getPosition()); return result; } } @@ -753,6 +803,7 @@ static PositionalKeys of(byte[]... keys) { static PositionalKeys of(PositionalKey... keys) { PositionalKeys result = PositionalKeys.empty(); + result.append(keys); return result; @@ -769,12 +820,12 @@ void append(PositionalKey... keys) { * @return index of the {@link PositionalKey}. */ int indexOf(PositionalKey key) { - return keys.indexOf(key); + return this.keys.indexOf(key); } @Override public Iterator iterator() { - return keys.iterator(); + return this.keys.iterator(); } } } diff --git a/src/main/java/org/springframework/data/redis/connection/ClusterTopology.java b/src/main/java/org/springframework/data/redis/connection/ClusterTopology.java index f49236b188..8c4f6b1427 100644 --- a/src/main/java/org/springframework/data/redis/connection/ClusterTopology.java +++ b/src/main/java/org/springframework/data/redis/connection/ClusterTopology.java @@ -17,8 +17,10 @@ import java.util.Arrays; import java.util.Collections; -import java.util.LinkedHashSet; +import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; +import java.util.stream.Collectors; import org.springframework.data.redis.ClusterStateFailureException; import org.springframework.lang.Nullable; @@ -26,20 +28,27 @@ import org.springframework.util.StringUtils; /** - * {@link ClusterTopology} holds snapshot like information about {@link RedisClusterNode}s. + * Holder of snapshot-like information about {@link RedisClusterNode nodes} in the Redis cluster. * * @author Christoph Strobl * @author Mark Paluch + * @author John Blum + * @see org.springframework.data.redis.connection.RedisClusterNode * @since 1.7 */ public class ClusterTopology { + protected static final Predicate ACTIVE_NODE_PREDICATE = + node -> node.isConnected() && !node.isMarkedAsFail(); + + protected static final Predicate MASTER_NODE_PREDICATE = RedisNode::isMaster; + private final Set nodes; /** - * Creates new instance of {@link ClusterTopology}. + * Creates a new {@link ClusterTopology} from the given {@link Set} of {@link RedisClusterNode}s. * - * @param nodes can be {@literal null}. + * @param nodes {@link RedisClusterNode nodes} forming the topology of the Redis cluster; can be {@literal null}. */ public ClusterTopology(@Nullable Set nodes) { this.nodes = nodes != null ? nodes : Collections.emptySet(); @@ -61,14 +70,7 @@ public Set getNodes() { * @return never {@literal null}. */ public Set getActiveNodes() { - - Set activeNodes = new LinkedHashSet<>(nodes.size()); - for (RedisClusterNode node : nodes) { - if (node.isConnected() && !node.isMarkedAsFail()) { - activeNodes.add(node); - } - } - return activeNodes; + return getNodes().stream().filter(ACTIVE_NODE_PREDICATE).collect(Collectors.toSet()); } /** @@ -79,13 +81,9 @@ public Set getActiveNodes() { */ public Set getActiveMasterNodes() { - Set activeMasterNodes = new LinkedHashSet<>(nodes.size()); - for (RedisClusterNode node : nodes) { - if (node.isMaster() && node.isConnected() && !node.isMarkedAsFail()) { - activeMasterNodes.add(node); - } - } - return activeMasterNodes; + return getNodes().stream() + .filter(MASTER_NODE_PREDICATE.and(ACTIVE_NODE_PREDICATE)) + .collect(Collectors.toSet()); } /** @@ -94,14 +92,7 @@ public Set getActiveMasterNodes() { * @return never {@literal null}. */ public Set getMasterNodes() { - - Set masterNodes = new LinkedHashSet<>(nodes.size()); - for (RedisClusterNode node : nodes) { - if (node.isMaster()) { - masterNodes.add(node); - } - } - return masterNodes; + return getNodes().stream().filter(MASTER_NODE_PREDICATE).collect(Collectors.toSet()); } /** @@ -111,14 +102,7 @@ public Set getMasterNodes() { * @return never {@literal null}. */ public Set getSlotServingNodes(int slot) { - - Set slotServingNodes = new LinkedHashSet<>(nodes.size()); - for (RedisClusterNode node : nodes) { - if (node.servesSlot(slot)) { - slotServingNodes.add(node); - } - } - return slotServingNodes; + return getNodes().stream().filter(node -> node.servesSlot(slot)).collect(Collectors.toSet()); } /** @@ -134,14 +118,12 @@ public RedisClusterNode getKeyServingMasterNode(byte[] key) { int slot = ClusterSlotHashUtil.calculateSlot(key); - for (RedisClusterNode node : nodes) { - if (node.isMaster() && node.servesSlot(slot)) { - return node; - } - } - - throw new ClusterStateFailureException( - String.format("Could not find master node serving slot %s for key '%s',", slot, Arrays.toString(key))); + return getNodes().stream() + .filter(MASTER_NODE_PREDICATE.and(node -> node.servesSlot(slot))) + .findFirst() + .orElseThrow(() -> new ClusterStateFailureException( + String.format("Could not find master node serving slot %s for key '%s',", + slot, Arrays.toString(key)))); } /** @@ -154,14 +136,20 @@ public RedisClusterNode getKeyServingMasterNode(byte[] key) { */ public RedisClusterNode lookup(String host, int port) { - for (RedisClusterNode node : nodes) { - if (host.equals(node.getHost()) && (node.getPort() != null && port == node.getPort())) { - return node; - } - } + return getNodes().stream() + .filter(isNodeMatch(host, port)) + .findFirst() + .orElseThrow(() -> new ClusterStateFailureException( + String.format("Could not find node at %s:%s; Is your cluster info up to date", host, port))); + } - throw new ClusterStateFailureException( - String.format("Could not find node at %s:%s; Is your cluster info up to date", host, port)); + private Predicate isNodeMatch(String host, int port) { + return node -> Objects.equals(node.getHost(), host) && resolvePort(node) == port; + } + + private int resolvePort(RedisClusterNode node) { + Integer port = node.getPort(); + return port != null ? port : -1; } /** @@ -175,14 +163,11 @@ public RedisClusterNode lookup(String nodeId) { Assert.notNull(nodeId, "NodeId must not be null"); - for (RedisClusterNode node : nodes) { - if (nodeId.equals(node.getId())) { - return node; - } - } - - throw new ClusterStateFailureException( - String.format("Could not find node at %s; Is your cluster info up to date", nodeId)); + return getNodes().stream() + .filter(node -> nodeId.equals(node.getId())) + .findFirst() + .orElseThrow(() -> new ClusterStateFailureException( + String.format("Could not find node at %s; Is your cluster info up to date", nodeId))); } /** @@ -219,7 +204,8 @@ public RedisClusterNode lookup(RedisClusterNode node) { */ public Set getKeyServingNodes(byte[] key) { - Assert.notNull(key, "Key must not be null for Cluster Node lookup."); + Assert.notNull(key, "Key must not be null for Cluster Node lookup"); + return getSlotServingNodes(ClusterSlotHashUtil.calculateSlot(key)); } } diff --git a/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java b/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java index f99e786ce4..25e3d47562 100644 --- a/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java +++ b/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java @@ -15,16 +15,38 @@ */ package org.springframework.data.redis.connection; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; -import static org.springframework.data.redis.test.util.MockitoUtils.*; - +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; +import static org.springframework.data.redis.test.util.MockitoUtils.verifyInvocationsAcross; + +import java.time.Instant; import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; import java.util.HashSet; -import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Predicate; +import java.util.function.Supplier; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -33,6 +55,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.stubbing.Answer; import org.springframework.core.convert.converter.Converter; import org.springframework.core.task.AsyncTaskExecutor; @@ -45,14 +68,24 @@ import org.springframework.data.redis.connection.ClusterCommandExecutor.ClusterCommandCallback; import org.springframework.data.redis.connection.ClusterCommandExecutor.MultiKeyClusterCommandCallback; import org.springframework.data.redis.connection.ClusterCommandExecutor.MultiNodeResult; +import org.springframework.data.redis.connection.ClusterCommandExecutor.NodeExecution; +import org.springframework.data.redis.connection.ClusterCommandExecutor.NodeResult; import org.springframework.data.redis.connection.RedisClusterNode.LinkState; import org.springframework.data.redis.connection.RedisClusterNode.SlotRange; import org.springframework.data.redis.connection.RedisNode.NodeType; +import org.springframework.data.redis.test.util.MockitoUtils; import org.springframework.scheduling.concurrent.ConcurrentTaskExecutor; +import edu.umd.cs.mtc.MultithreadedTestCase; +import edu.umd.cs.mtc.TestFramework; + /** + * Unit Tests for {@link ClusterCommandExecutor}. + * * @author Christoph Strobl * @author Mark Paluch + * @author John Blum + * @since 1.7 */ @ExtendWith(MockitoExtension.class) class ClusterCommandExecutorUnitTests { @@ -66,17 +99,32 @@ class ClusterCommandExecutorUnitTests { private static final int CLUSTER_NODE_3_PORT = 7381; private static final RedisClusterNode CLUSTER_NODE_1 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT).serving(new SlotRange(0, 5460)) - .withId("ef570f86c7b1a953846668debc177a3a16733420").promotedAs(NodeType.MASTER).linkState(LinkState.CONNECTED) + .listeningAt(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT) + .serving(new SlotRange(0, 5460)) + .withId("ef570f86c7b1a953846668debc177a3a16733420") + .promotedAs(NodeType.MASTER) + .linkState(LinkState.CONNECTED) + .withName("ClusterNodeX") .build(); + private static final RedisClusterNode CLUSTER_NODE_2 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT).serving(new SlotRange(5461, 10922)) - .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84").promotedAs(NodeType.MASTER).linkState(LinkState.CONNECTED) + .listeningAt(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT) + .serving(new SlotRange(5461, 10922)) + .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84") + .promotedAs(NodeType.MASTER) + .linkState(LinkState.CONNECTED) + .withName("ClusterNodeY") .build(); + private static final RedisClusterNode CLUSTER_NODE_3 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT).serving(new SlotRange(10923, 16383)) - .withId("3b9b8192a874fa8f1f09dbc0ee20afab5738eee7").promotedAs(NodeType.MASTER).linkState(LinkState.CONNECTED) + .listeningAt(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT) + .serving(new SlotRange(10923, 16383)) + .withId("3b9b8192a874fa8f1f09dbc0ee20afab5738eee7") + .promotedAs(NodeType.MASTER) + .linkState(LinkState.CONNECTED) + .withName("ClusterNodeZ") .build(); + private static final RedisClusterNode CLUSTER_NODE_2_LOOKUP = RedisClusterNode.newRedisClusterNode() .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84").build(); @@ -88,8 +136,8 @@ class ClusterCommandExecutorUnitTests { private static final Converter exceptionConverter = source -> { - if (source instanceof MovedException) { - return new ClusterRedirectException(1000, ((MovedException) source).host, ((MovedException) source).port, source); + if (source instanceof MovedException movedException) { + return new ClusterRedirectException(1000, movedException.host, movedException.port, source); } return new InvalidDataAccessApiUsageException(source.getMessage(), source); @@ -97,14 +145,14 @@ class ClusterCommandExecutorUnitTests { private static final MultiKeyConnectionCommandCallback MULTIKEY_CALLBACK = Connection::bloodAndAshes; - @Mock Connection con1; - @Mock Connection con2; - @Mock Connection con3; + @Mock Connection connection1; + @Mock Connection connection2; + @Mock Connection connection3; @BeforeEach void setUp() { - this.executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), new MockClusterResourceProvider(), + this.executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ImmediateExecutor()); } @@ -118,7 +166,7 @@ void executeCommandOnSingleNodeShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, CLUSTER_NODE_2); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -127,7 +175,7 @@ void executeCommandOnSingleNodeByHostAndPortShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, new RedisClusterNode(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -135,15 +183,17 @@ void executeCommandOnSingleNodeByNodeIdShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, new RedisClusterNode(CLUSTER_NODE_2.id)); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 + @SuppressWarnings("all") void executeCommandOnSingleNodeShouldThrowExceptionWhenNodeIsNull() { assertThatIllegalArgumentException().isThrownBy(() -> executor.executeCommandOnSingleNode(COMMAND_CALLBACK, null)); } @Test // DATAREDIS-315 + @SuppressWarnings("all") void executeCommandOnSingleNodeShouldThrowExceptionWhenCommandCallbackIsNull() { assertThatIllegalArgumentException().isThrownBy(() -> executor.executeCommandOnSingleNode(null, CLUSTER_NODE_1)); } @@ -158,52 +208,52 @@ void executeCommandOnSingleNodeShouldThrowExceptionWhenNodeIsUnknown() { void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodes() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(CLUSTER_NODE_1, CLUSTER_NODE_2)); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, never()).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodesByHostAndPort() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(new RedisClusterNode(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT), new RedisClusterNode(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT))); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, never()).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodesByNodeId() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(new RedisClusterNode(CLUSTER_NODE_1.id), CLUSTER_NODE_2_LOOKUP)); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, never()).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldFailOnGivenUnknownNodes() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); assertThatIllegalArgumentException().isThrownBy(() -> executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, @@ -214,42 +264,42 @@ void executeCommandAsyncOnNodesShouldFailOnGivenUnknownNodes() { void executeCommandOnAllNodesShouldExecuteCommandOnEveryKnownClusterNode() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); executor.executeCommandOnAllNodes(COMMAND_CALLBACK); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldCompleteAndCollectErrorsOfAllNodes() { - when(con1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); - when(con2.theWheelWeavesAsTheWheelWills()).thenThrow(new IllegalStateException("(error) mat lost the dagger...")); - when(con3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); + when(connection1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); + when(connection2.theWheelWeavesAsTheWheelWills()).thenThrow(new IllegalStateException("(error) mat lost the dagger...")); + when(connection3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); try { executor.executeCommandOnAllNodes(COMMAND_CALLBACK); - } catch (ClusterCommandExecutionFailureException e) { + } catch (ClusterCommandExecutionFailureException cause) { - assertThat(e.getSuppressed()).hasSize(1); - assertThat(e.getSuppressed()[0]).isInstanceOf(DataAccessException.class); + assertThat(cause.getSuppressed()).hasSize(1); + assertThat(cause.getSuppressed()[0]).isInstanceOf(DataAccessException.class); } - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldCollectResultsCorrectly() { - when(con1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); - when(con2.theWheelWeavesAsTheWheelWills()).thenReturn("mat"); - when(con3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); + when(connection1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); + when(connection2.theWheelWeavesAsTheWheelWills()).thenReturn("mat"); + when(connection3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); MultiNodeResult result = executor.executeCommandOnAllNodes(COMMAND_CALLBACK); @@ -261,10 +311,10 @@ void executeMultikeyCommandShouldRunCommandAcrossCluster() { // key-1 and key-9 map both to node1 ArgumentCaptor captor = ArgumentCaptor.forClass(byte[].class); - when(con1.bloodAndAshes(captor.capture())).thenReturn("rand").thenReturn("egwene"); - when(con2.bloodAndAshes(any(byte[].class))).thenReturn("mat"); - when(con3.bloodAndAshes(any(byte[].class))).thenReturn("perrin"); + when(connection1.bloodAndAshes(captor.capture())).thenReturn("rand").thenReturn("egwene"); + when(connection2.bloodAndAshes(any(byte[].class))).thenReturn("mat"); + when(connection3.bloodAndAshes(any(byte[].class))).thenReturn("perrin"); MultiNodeResult result = executor.executeMultiKeyCommand(MULTIKEY_CALLBACK, new HashSet<>( @@ -279,21 +329,21 @@ void executeMultikeyCommandShouldRunCommandAcrossCluster() { @Test // DATAREDIS-315 void executeCommandOnSingleNodeAndFollowRedirect() { - when(con1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); + when(connection1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); executor.executeCommandOnSingleNode(COMMAND_CALLBACK, CLUSTER_NODE_1); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, never()).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandOnSingleNodeAndFollowRedirectButStopsAfterMaxRedirects() { - when(con1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); - when(con3.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); - when(con2.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT)); + when(connection1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); + when(connection3.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); + when(connection2.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT)); try { executor.setMaxRedirects(4); @@ -302,9 +352,9 @@ void executeCommandOnSingleNodeAndFollowRedirectButStopsAfterMaxRedirects() { assertThat(e).isInstanceOf(TooManyClusterRedirectionsException.class); } - verify(con1, times(2)).theWheelWeavesAsTheWheelWills(); - verify(con3, times(2)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(2)).theWheelWeavesAsTheWheelWills(); + verify(connection3, times(2)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -312,53 +362,249 @@ void executeCommandOnArbitraryNodeShouldPickARandomNode() { executor.executeCommandOnArbitraryNode(COMMAND_CALLBACK); - verifyInvocationsAcross("theWheelWeavesAsTheWheelWills", times(1), con1, con2, con3); + verifyInvocationsAcross("theWheelWeavesAsTheWheelWills", times(1), connection1, connection2, connection3); + } + + @Test // GH-2518 + void collectResultsCompletesSuccessfully() { + + Instant done = Instant.now().plusMillis(5); + + Predicate>> isDone = future -> Instant.now().isAfter(done); + + Map>> futures = new HashMap<>(); + + NodeResult nodeOneA = newNodeResult(CLUSTER_NODE_1, "A"); + NodeResult nodeTwoB = newNodeResult(CLUSTER_NODE_2, "B"); + NodeResult nodeThreeC = newNodeResult(CLUSTER_NODE_3, "C"); + + futures.put(newNodeExecution(CLUSTER_NODE_1), mockFutureAndIsDone(nodeOneA, isDone)); + futures.put(newNodeExecution(CLUSTER_NODE_2), mockFutureAndIsDone(nodeTwoB, isDone)); + futures.put(newNodeExecution(CLUSTER_NODE_3), mockFutureAndIsDone(nodeThreeC, isDone)); + + MultiNodeResult results = this.executor.collectResults(futures); + + assertThat(results).isNotNull(); + assertThat(results.getResults()).containsExactlyInAnyOrder(nodeOneA, nodeTwoB, nodeThreeC); + + futures.values().forEach(future -> + runsSafely(() -> verify(future, times(1)).get(anyLong(), any(TimeUnit.class)))); + } + + @Test // GH-2518 + void collectResultsCompletesSuccessfullyEvenWithTimeouts() throws Exception { + + Map>> futures = new HashMap<>(); + + NodeResult nodeOneA = newNodeResult(CLUSTER_NODE_1, "A"); + NodeResult nodeTwoB = newNodeResult(CLUSTER_NODE_2, "B"); + NodeResult nodeThreeC = newNodeResult(CLUSTER_NODE_3, "C"); + + Future> nodeOneFutureResult = mockFutureThrowingTimeoutException(nodeOneA, 4); + Future> nodeTwoFutureResult = mockFutureThrowingTimeoutException(nodeTwoB, 1); + Future> nodeThreeFutureResult = mockFutureThrowingTimeoutException(nodeThreeC, 2); + + futures.put(newNodeExecution(CLUSTER_NODE_1), nodeOneFutureResult); + futures.put(newNodeExecution(CLUSTER_NODE_2), nodeTwoFutureResult); + futures.put(newNodeExecution(CLUSTER_NODE_3), nodeThreeFutureResult); + + MultiNodeResult results = this.executor.collectResults(futures); + + assertThat(results).isNotNull(); + assertThat(results.getResults()).containsExactlyInAnyOrder(nodeOneA, nodeTwoB, nodeThreeC); + + verify(nodeOneFutureResult, times(4)).get(anyLong(), any(TimeUnit.class)); + verify(nodeTwoFutureResult, times(1)).get(anyLong(), any(TimeUnit.class)); + verify(nodeThreeFutureResult, times(2)).get(anyLong(), any(TimeUnit.class)); + verifyNoMoreInteractions(nodeOneFutureResult, nodeTwoFutureResult, nodeThreeFutureResult); + } + + @Test // GH-2518 + void collectResultsFailsWithExecutionException() { + + Map>> futures = new HashMap<>(); + + NodeResult nodeOneA = newNodeResult(CLUSTER_NODE_1, "A"); + + futures.put(newNodeExecution(CLUSTER_NODE_1), mockFutureAndIsDone(nodeOneA, future -> true)); + futures.put(newNodeExecution(CLUSTER_NODE_2), mockFutureThrowingExecutionException( + new ExecutionException("TestError", new IllegalArgumentException("MockError")))); + + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> this.executor.collectResults(futures)) + .withMessage("MockError") + .withCauseInstanceOf(InvalidDataAccessApiUsageException.class) + .extracting(Throwable::getCause) + .extracting(Throwable::getCause) + .isInstanceOf(IllegalArgumentException.class) + .extracting(Throwable::getMessage) + .isEqualTo("MockError"); } - class MockClusterNodeProvider implements ClusterTopologyProvider { + @Test // GH-2518 + void collectResultsFailsWithInterruptedException() throws Throwable { + TestFramework.runOnce(new CollectResultsInterruptedMultithreadedTestCase(this.executor)); + } + + // Future.get() for X will get called twice if at least one other Future is not done and Future.get() for X + // threw an ExecutionException in the previous iteration, thereby marking it as done! + @Test // GH-2518 + @SuppressWarnings("all") + void collectResultsCallsFutureGetOnlyOnce() throws Exception { + + AtomicInteger count = new AtomicInteger(0); + Map>> futures = new HashMap<>(); + + Future> clusterNodeOneFutureResult = mockFutureAndIsDone(null, future -> + count.incrementAndGet() % 2 == 0); + + Future> clusterNodeTwoFutureResult = mockFutureThrowingExecutionException( + new ExecutionException("TestError", new IllegalArgumentException("MockError"))); + + futures.put(newNodeExecution(CLUSTER_NODE_1), clusterNodeOneFutureResult); + futures.put(newNodeExecution(CLUSTER_NODE_2), clusterNodeTwoFutureResult); + + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> this.executor.collectResults(futures)); + + verify(clusterNodeOneFutureResult, times(1)).get(anyLong(), any()); + verify(clusterNodeTwoFutureResult, times(1)).get(anyLong(), any()); + } + + // Covers the case where Future.get() is mistakenly called multiple times, or if the Future.isDone() implementation + // does not properly take into account Future.get() throwing an ExecutionException during computation subsequently + // returning false instead of true. + // This should be properly handled by the "safeguard" (see collectResultsCallsFutureGetOnlyOnce()), but... + // just in case! The ExecutionException handler now stores the [DataAccess]Exception with Map.putIfAbsent(..). + @Test // GH-2518 + @SuppressWarnings("all") + void collectResultsCapturesFirstExecutionExceptionOnly() { + + AtomicInteger count = new AtomicInteger(0); + AtomicInteger exceptionCount = new AtomicInteger(0); + + Map>> futures = new HashMap<>(); + + futures.put(newNodeExecution(CLUSTER_NODE_1), + mockFutureAndIsDone(null, future -> count.incrementAndGet() % 2 == 0)); + + futures.put(newNodeExecution(CLUSTER_NODE_2), mockFutureThrowingExecutionException(() -> + new ExecutionException("TestError", new IllegalStateException("MockError" + exceptionCount.getAndIncrement())))); + + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> this.executor.collectResults(futures)) + .withMessage("MockError0") + .withCauseInstanceOf(InvalidDataAccessApiUsageException.class) + .extracting(Throwable::getCause) + .extracting(Throwable::getCause) + .isInstanceOf(IllegalStateException.class) + .extracting(Throwable::getMessage) + .isEqualTo("MockError0"); + } + + private Future mockFutureAndIsDone(T result, Predicate> isDone) { + + return MockitoUtils.mockFuture(result, future -> { + doAnswer(invocation -> isDone.test(future)).when(future).isDone(); + return future; + }); + } + + private Future mockFutureThrowingExecutionException(ExecutionException exception) { + return mockFutureThrowingExecutionException(() -> exception); + } + + private Future mockFutureThrowingExecutionException(Supplier exceptionSupplier) { + + Answer getAnswer = invocationOnMock -> { throw exceptionSupplier.get(); }; + + return MockitoUtils.mockFuture(null, future -> { + doReturn(true).when(future).isDone(); + doAnswer(getAnswer).when(future).get(); + doAnswer(getAnswer).when(future).get(anyLong(), any()); + return future; + }); + } + + @SuppressWarnings("unchecked") + private Future mockFutureThrowingTimeoutException(T result, int timeoutCount) { + + AtomicInteger counter = new AtomicInteger(timeoutCount); + + Answer getAnswer = invocationOnMock -> { + + if (counter.decrementAndGet() > 0) { + throw new TimeoutException("TIMES UP"); + } + + doReturn(true).when((Future>) invocationOnMock.getMock()).isDone(); + + return result; + }; + + return MockitoUtils.mockFuture(result, future -> { + + doAnswer(getAnswer).when(future).get(); + doAnswer(getAnswer).when(future).get(anyLong(), any()); + + return future; + }); + } + + private NodeExecution newNodeExecution(RedisClusterNode clusterNode) { + return new NodeExecution(clusterNode); + } + + private NodeResult newNodeResult(RedisClusterNode clusterNode, T value) { + return new NodeResult<>(clusterNode, value); + } + + private void runsSafely(ThrowableOperation operation) { + + try { + operation.run(); + } catch (Throwable ignore) { } + } + + static class MockClusterNodeProvider implements ClusterTopologyProvider { @Override public ClusterTopology getTopology() { - return new ClusterTopology( - new LinkedHashSet<>(Arrays.asList(CLUSTER_NODE_1, CLUSTER_NODE_2, CLUSTER_NODE_3))); + return new ClusterTopology(Set.of(CLUSTER_NODE_1, CLUSTER_NODE_2, CLUSTER_NODE_3)); } - } - class MockClusterResourceProvider implements ClusterNodeResourceProvider { + class MockClusterNodeResourceProvider implements ClusterNodeResourceProvider { @Override - public Connection getResourceForSpecificNode(RedisClusterNode node) { - - if (CLUSTER_NODE_1.equals(node)) { - return con1; - } - if (CLUSTER_NODE_2.equals(node)) { - return con2; - } - if (CLUSTER_NODE_3.equals(node)) { - return con3; - } + @SuppressWarnings("all") + public Connection getResourceForSpecificNode(RedisClusterNode clusterNode) { - return null; + return CLUSTER_NODE_1.equals(clusterNode) ? connection1 + : CLUSTER_NODE_2.equals(clusterNode) ? connection2 + : CLUSTER_NODE_3.equals(clusterNode) ? connection3 + : null; } @Override public void returnResourceForSpecificNode(RedisClusterNode node, Object resource) { - // TODO Auto-generated method stub } - } - static interface ConnectionCommandCallback extends ClusterCommandCallback { + interface ConnectionCommandCallback extends ClusterCommandCallback { } - static interface MultiKeyConnectionCommandCallback extends MultiKeyClusterCommandCallback { + interface MultiKeyConnectionCommandCallback extends MultiKeyClusterCommandCallback { } - static interface Connection { + @FunctionalInterface + interface ThrowableOperation { + void run() throws Throwable; + } + + interface Connection { String theWheelWeavesAsTheWheelWills(); @@ -374,19 +620,20 @@ static class MovedException extends RuntimeException { this.host = host; this.port = port; } - } static class ImmediateExecutor implements AsyncTaskExecutor { @Override - public void execute(Runnable runnable, long l) { + public void execute(Runnable runnable) { runnable.run(); } @Override public Future submit(Runnable runnable) { + return submit(() -> { + runnable.run(); return null; @@ -395,19 +642,103 @@ public Future submit(Runnable runnable) { @Override public Future submit(Callable callable) { + try { return CompletableFuture.completedFuture(callable.call()); - } catch (Exception e) { + } catch (Exception cause) { CompletableFuture future = new CompletableFuture<>(); - future.completeExceptionally(e); + + future.completeExceptionally(cause); + return future; } } + } + + @SuppressWarnings("all") + static class CollectResultsInterruptedMultithreadedTestCase extends MultithreadedTestCase { + + private static final CountDownLatch latch = new CountDownLatch(1); + + private static final Comparator NODE_COMPARATOR = + Comparator.comparing(nodeExecution -> nodeExecution.getNode().getName()); + + private final ClusterCommandExecutor clusterCommandExecutor; + + private final Map>> futureNodeResults; + + private Future> mockNodeOneFutureResult; + private Future> mockNodeTwoFutureResult; + + private volatile Thread collectResultsThread; + + private CollectResultsInterruptedMultithreadedTestCase(ClusterCommandExecutor clusterCommandExecutor) { + this.clusterCommandExecutor = clusterCommandExecutor; + this.futureNodeResults = new ConcurrentSkipListMap<>(NODE_COMPARATOR); + } @Override - public void execute(Runnable runnable) { - runnable.run(); + public void initialize() { + + super.initialize(); + + this.mockNodeOneFutureResult = this.futureNodeResults.computeIfAbsent(new NodeExecution(CLUSTER_NODE_1), + nodeExecution -> MockitoUtils.mockFuture(null, mockFuture -> { + doReturn(false).when(mockFuture).isDone(); + return mockFuture; + })); + + this.mockNodeTwoFutureResult = this.futureNodeResults.computeIfAbsent(new NodeExecution(CLUSTER_NODE_2), + nodeExecution -> MockitoUtils.mockFuture(null, mockFuture -> { + + doReturn(true).when(mockFuture).isDone(); + + doAnswer(invocation -> { + latch.await(); + return null; + }).when(mockFuture).get(anyLong(), any()); + + return mockFuture; + })); + } + + public void thread1() { + + assertTick(0); + + this.collectResultsThread = Thread.currentThread(); + this.collectResultsThread.setName("CollectResults Thread"); + + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> this.clusterCommandExecutor.collectResults(this.futureNodeResults)); + + assertThat(this.collectResultsThread.isInterrupted()).isTrue(); + } + + public void thread2() { + + assertTick(0); + + Thread.currentThread().setName("Interrupting Thread"); + + waitForTick(1); + + assertThat(this.collectResultsThread).isNotNull(); + assertThat(this.collectResultsThread.getName()).isEqualTo("CollectResults Thread"); + + this.collectResultsThread.interrupt(); + } + + @Override + public void finish() { + + try { + verify(this.mockNodeOneFutureResult, never()).get(); + verify(this.mockNodeTwoFutureResult, times(1)).get(anyLong(), any()); + } catch (ExecutionException | InterruptedException | TimeoutException cause) { + throw new RuntimeException(cause); + } } } } diff --git a/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java b/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java index e42f0749d8..318d47ec85 100644 --- a/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java +++ b/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java @@ -15,34 +15,131 @@ */ package org.springframework.data.redis.test.util; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.anyBoolean; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockingDetails; +import static org.mockito.Mockito.withSettings; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; import org.mockito.internal.invocation.InvocationMatcher; import org.mockito.internal.verification.api.VerificationData; import org.mockito.invocation.Invocation; import org.mockito.invocation.MatchableInvocation; +import org.mockito.quality.Strictness; +import org.mockito.stubbing.Answer; import org.mockito.verification.VerificationMode; + +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; /** + * Utilities for using {@literal Mockito} and creating {@link Object mock objects} in {@literal unit tests}. + * * @author Christoph Strobl + * @author John Blum + * @see org.mockito.Mockito * @since 1.7 */ -public class MockitoUtils { +@SuppressWarnings("unused") +public abstract class MockitoUtils { + + /** + * Creates a mock {@link Future} returning the given {@link Object result}. + * + * @param {@link Class type} of {@link Object result} returned by the mock {@link Future}. + * @param result {@link Object value} returned as the {@literal result} of the mock {@link Future}. + * @return a new mock {@link Future}. + * @see java.util.concurrent.Future + */ + @SuppressWarnings("unchecked") + public static @NonNull Future mockFuture(@Nullable T result) { + + try { + + AtomicBoolean cancelled = new AtomicBoolean(false); + AtomicBoolean done = new AtomicBoolean(false); + + Future mockFuture = mock(Future.class, withSettings().strictness(Strictness.LENIENT)); + + // A Future can only be cancelled if not done, it was not already cancelled, and no error occurred. + // The cancel(..) logic is not Thread-safe due to compound actions involving multiple variables. + // However, the cancel(..) logic does not necessarily need to be Thread-safe given the task execution + // of a Future is asynchronous and cancellation is driven by Thread interrupt from another Thread. + Answer cancelAnswer = invocation -> !done.get() + && cancelled.compareAndSet(done.get(), true) + && done.compareAndSet(done.get(), true); + + Answer getAnswer = invocation -> { + + // The Future is done no matter if it returns the result or was cancelled/interrupted. + done.set(true); + + if (Thread.currentThread().isInterrupted()) { + throw new InterruptedException("Thread was interrupted"); + } + + if (cancelled.get()) { + throw new CancellationException("Task was cancelled"); + } + + return result; + }; + + doAnswer(invocation -> cancelled.get()).when(mockFuture).isCancelled(); + doAnswer(invocation -> done.get()).when(mockFuture).isDone(); + doAnswer(cancelAnswer).when(mockFuture).cancel(anyBoolean()); + doAnswer(getAnswer).when(mockFuture).get(); + doAnswer(getAnswer).when(mockFuture).get(anyLong(), isA(TimeUnit.class)); + + return mockFuture; + } + catch (Exception cause) { + String message = String.format("Failed to create a mock of Future having result [%s]", result); + throw new IllegalStateException(message, cause); + } + } + + /** + * Creates a mock {@link Future} returning the given {@link Object result}, customized with the given, + * required {@link Function}. + * + * @param {@link Class type} of {@link Object result} returned by the mock {@link Future}. + * @param result {@link Object value} returned as the {@literal result} of the mock {@link Future}. + * @param futureFunction {@link Function} used to customize the mock {@link Future} on creation; + * must not be {@literal null}. + * @return a new mock {@link Future}. + * @see java.util.concurrent.Future + * @see java.util.function.Function + * @see #mockFuture(Object) + */ + public static @NonNull Future mockFuture(@Nullable T result, + @NonNull ThrowableFunction, Future> futureFunction) { + + Future mockFuture = mockFuture(result); + + return futureFunction.apply(mockFuture); + } /** * Verifies a given method is called a total number of times across all given mocks. * - * @param method - * @param mode - * @param mocks + * @param method {@link String name} of a {@link java.lang.reflect.Method} on the {@link Object mock object}. + * @param mode mode of verification used by {@literal Mockito} to verify invocations on {@link Object mock objects}. + * @param mocks array of {@link Object mock objects} to verify. */ - @SuppressWarnings({ "rawtypes", "serial" }) - public static void verifyInvocationsAcross(final String method, final VerificationMode mode, Object... mocks) { + public static void verifyInvocationsAcross(String method, VerificationMode mode, Object... mocks) { mode.verify(new VerificationDataImpl(getInvocations(method, mocks), new InvocationMatcher(null, Collections .singletonList(org.mockito.internal.matchers.Any.ANY)) { @@ -56,17 +153,15 @@ public boolean matches(Invocation actual) { public String toString() { return String.format("%s for method: %s", mode, method); } - })); } private static List getInvocations(String method, Object... mocks) { List invocations = new ArrayList<>(); - for (Object mock : mocks) { + for (Object mock : mocks) { if (StringUtils.hasText(method)) { - for (Invocation invocation : mockingDetails(mock).getInvocations()) { if (invocation.getMethod().getName().equals(method)) { invocations.add(invocation); @@ -76,6 +171,7 @@ private static List getInvocations(String method, Object... mocks) { invocations.addAll(mockingDetails(mock).getInvocations()); } } + return invocations; } @@ -98,7 +194,31 @@ public List getAllInvocations() { public MatchableInvocation getTarget() { return wanted; } - } + @FunctionalInterface + public interface ThrowableFunction extends Function { + + @Override + default R apply(T target) { + + try { + return applyThrowingException(target); + } + catch (Throwable cause) { + String message = String.format("Failed to apply Function [%s] to target [%s]", this, target); + throw new IllegalStateException(message, cause); + } + } + + R applyThrowingException(T target) throws Throwable; + + @SuppressWarnings("unchecked") + default @NonNull ThrowableFunction andThen( + @Nullable ThrowableFunction after) { + + return after == null ? (ThrowableFunction) this + : target -> after.apply(this.apply(target)); + } + } }