diff --git a/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java b/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java index f016dfc33f..d495ade564 100644 --- a/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java +++ b/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java @@ -15,10 +15,27 @@ */ package org.springframework.data.redis.connection; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.Map.Entry; +import java.util.Objects; +import java.util.Random; +import java.util.Set; +import java.util.TreeMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.function.BiConsumer; import java.util.function.Function; import java.util.stream.Collectors; @@ -43,6 +60,7 @@ * * @author Christoph Strobl * @author Mark Paluch + * @author John Blum * @since 1.7 */ public class ClusterCommandExecutor implements DisposableBean { @@ -58,7 +76,7 @@ public class ClusterCommandExecutor implements DisposableBean { private final ExceptionTranslationStrategy exceptionTranslationStrategy; /** - * Create a new instance of {@link ClusterCommandExecutor}. + * Create a new {@link ClusterCommandExecutor}. * * @param topologyProvider must not be {@literal null}. * @param resourceProvider must not be {@literal null}. @@ -92,40 +110,47 @@ public ClusterCommandExecutor(ClusterTopologyProvider topologyProvider, ClusterN /** * Run {@link ClusterCommandCallback} on a random node. * - * @param commandCallback must not be {@literal null}. + * @param clusterCommand must not be {@literal null}. * @return never {@literal null}. */ - public NodeResult executeCommandOnArbitraryNode(ClusterCommandCallback commandCallback) { + public NodeResult executeCommandOnArbitraryNode(ClusterCommandCallback clusterCommand) { - Assert.notNull(commandCallback, "ClusterCommandCallback must not be null"); + Assert.notNull(clusterCommand, "ClusterCommandCallback must not be null"); List nodes = new ArrayList<>(getClusterTopology().getActiveNodes()); - return executeCommandOnSingleNode(commandCallback, nodes.get(new Random().nextInt(nodes.size()))); + RedisClusterNode arbitraryNode = nodes.get(new Random().nextInt(nodes.size())); + + return executeCommandOnSingleNode(clusterCommand, arbitraryNode); } /** * Run {@link ClusterCommandCallback} on given {@link RedisClusterNode}. * - * @param cmd must not be {@literal null}. + * @param clusterCommand must not be {@literal null}. * @param node must not be {@literal null}. * @return the {@link NodeResult} from the single, targeted {@link RedisClusterNode}. * @throws IllegalArgumentException in case no resource can be acquired for given node. */ - public NodeResult executeCommandOnSingleNode(ClusterCommandCallback cmd, RedisClusterNode node) { - return executeCommandOnSingleNode(cmd, node, 0); + public NodeResult executeCommandOnSingleNode(ClusterCommandCallback clusterCommand, + RedisClusterNode node) { + + return executeCommandOnSingleNode(clusterCommand, node, 0); } - private NodeResult executeCommandOnSingleNode(ClusterCommandCallback cmd, RedisClusterNode node, - int redirectCount) { + private NodeResult executeCommandOnSingleNode(ClusterCommandCallback clusterCommand, + RedisClusterNode node, int redirectCount) { - Assert.notNull(cmd, "ClusterCommandCallback must not be null"); + Assert.notNull(clusterCommand, "ClusterCommandCallback must not be null"); Assert.notNull(node, "RedisClusterNode must not be null"); - if (redirectCount > maxRedirects) { - throw new TooManyClusterRedirectionsException(String.format( - "Cannot follow Cluster Redirects over more than %s legs; Please consider increasing the number of redirects to follow; Current value is: %s.", - redirectCount, maxRedirects)); + if (redirectCount > this.maxRedirects) { + + String message = String.format("Cannot follow Cluster Redirects over more than %s legs;" + + " Please consider increasing the number of redirects to follow; Current value is: %s.", + redirectCount, this.maxRedirects); + + throw new TooManyClusterRedirectionsException(message); } RedisClusterNode nodeToUse = lookupNode(node); @@ -135,14 +160,15 @@ private NodeResult executeCommandOnSingleNode(ClusterCommandCallback(node, cmd.doInCluster(client)); + return new NodeResult<>(node, clusterCommand.doInCluster(client)); } catch (RuntimeException cause) { RuntimeException translatedException = convertToDataAccessException(cause); if (translatedException instanceof ClusterRedirectException clusterRedirectException) { - return executeCommandOnSingleNode(cmd, topologyProvider.getTopology().lookup( - clusterRedirectException.getTargetHost(), clusterRedirectException.getTargetPort()), redirectCount + 1); + return executeCommandOnSingleNode(clusterCommand, topologyProvider.getTopology() + .lookup(clusterRedirectException.getTargetHost(), clusterRedirectException.getTargetPort()), + redirectCount + 1); } else { throw translatedException != null ? translatedException : cause; } @@ -152,10 +178,10 @@ private NodeResult executeCommandOnSingleNode(ClusterCommandCallback MultiNodeResult executeCommandOnAllNodes(final ClusterCommandCallback cmd) { - return executeCommandAsyncOnNodes(cmd, getClusterTopology().getActiveMasterNodes()); + public MultiNodeResult executeCommandOnAllNodes(ClusterCommandCallback clusterCommand) { + return executeCommandAsyncOnNodes(clusterCommand, getClusterTopology().getActiveMasterNodes()); } /** - * @param callback must not be {@literal null}. + * @param clusterCommand must not be {@literal null}. * @param nodes must not be {@literal null}. * @return never {@literal null}. * @throws ClusterCommandExecutionFailureException if a failure occurs while executing the given * {@link ClusterCommandCallback command} on any given {@link RedisClusterNode node}. * @throws IllegalArgumentException in case the node could not be resolved to a topology-known node */ - public MultiNodeResult executeCommandAsyncOnNodes(ClusterCommandCallback callback, + public MultiNodeResult executeCommandAsyncOnNodes(ClusterCommandCallback clusterCommand, Iterable nodes) { - Assert.notNull(callback, "Callback must not be null"); + Assert.notNull(clusterCommand, "Callback must not be null"); Assert.notNull(nodes, "Nodes must not be null"); + ClusterTopology topology = this.topologyProvider.getTopology(); List resolvedRedisClusterNodes = new ArrayList<>(); - ClusterTopology topology = topologyProvider.getTopology(); for (RedisClusterNode node : nodes) { try { resolvedRedisClusterNodes.add(topology.lookup(node)); - } catch (ClusterStateFailureException e) { - throw new IllegalArgumentException(String.format("Node %s is unknown to cluster", node), e); + } catch (ClusterStateFailureException cause) { + throw new IllegalArgumentException(String.format("Node %s is unknown to cluster", node), cause); } } Map>> futures = new LinkedHashMap<>(); for (RedisClusterNode node : resolvedRedisClusterNodes) { - futures.put(new NodeExecution(node), executor.submit(() -> executeCommandOnSingleNode(callback, node))); + futures.put(new NodeExecution(node), executor.submit(() -> executeCommandOnSingleNode(clusterCommand, node))); } return collectResults(futures); } - private MultiNodeResult collectResults(Map>> futures) { - - boolean done = false; + MultiNodeResult collectResults(Map>> futures) { Map exceptions = new HashMap<>(); MultiNodeResult result = new MultiNodeResult<>(); - Set saveGuard = new HashSet<>(); - - while (!done) { - - done = true; - - for (Map.Entry>> entry : futures.entrySet()) { + Set safeguard = new HashSet<>(); - if (!entry.getValue().isDone() && !entry.getValue().isCancelled()) { - done = false; - } else { - - NodeExecution execution = entry.getKey(); + BiConsumer exceptionHandler = (execution, throwable) -> { - try { + RuntimeException dataAccessException = convertToDataAccessException((Exception) throwable); - String futureId = ObjectUtils.getIdentityHexString(entry.getValue()); + exceptions.putIfAbsent(execution.getNode(), dataAccessException != null ? dataAccessException : throwable); + }; - if (!saveGuard.contains(futureId)) { + boolean done = false; - if (execution.isPositional()) { - result.add(execution.getPositionalKey(), entry.getValue().get()); - } else { - result.add(entry.getValue().get()); - } + while (!done) { - saveGuard.add(futureId); - } - } catch (ExecutionException cause) { + done = true; - RuntimeException exception = convertToDataAccessException((Exception) cause.getCause()); + for (Map.Entry>> entry : futures.entrySet()) { - exceptions.put(execution.getNode(), exception != null ? exception : cause.getCause()); - } catch (InterruptedException cause) { + NodeExecution nodeExecution = entry.getKey(); + Future> futureNodeResult = entry.getValue(); + String futureId = ObjectUtils.getIdentityHexString(futureNodeResult); - Thread.currentThread().interrupt(); + try { + if (!safeguard.contains(futureId)) { - RuntimeException exception = convertToDataAccessException((Exception) cause.getCause()); + NodeResult nodeResult = futureNodeResult.get(10L, TimeUnit.MICROSECONDS); - exceptions.put(execution.getNode(), exception != null ? exception : cause.getCause()); + if (nodeExecution.isPositional()) { + result.add(nodeExecution.getPositionalKey(), nodeResult); + } else { + result.add(nodeResult); + } - break; + safeguard.add(futureId); } + } catch (ExecutionException exception) { + safeguard.add(futureId); + exceptionHandler.accept(nodeExecution, exception.getCause()); + } catch (TimeoutException ignore) { + done = false; + } catch (InterruptedException cause) { + Thread.currentThread().interrupt(); + exceptionHandler.accept(nodeExecution, cause); + break; } } - - try { - Thread.sleep(10); - } catch (InterruptedException e) { - done = true; - Thread.currentThread().interrupt(); - } } if (!exceptions.isEmpty()) { @@ -306,8 +323,8 @@ public MultiNodeResult executeMultiKeyCommand(MultiKeyClusterCommandCa if (entry.getKey().isMaster()) { for (PositionalKey key : entry.getValue()) { - futures.put(new NodeExecution(entry.getKey(), key), this.executor - .submit(() -> executeMultiKeyCommandOnSingleNode(commandCallback, entry.getKey(), key.getBytes()))); + futures.put(new NodeExecution(entry.getKey(), key), this.executor.submit(() -> + executeMultiKeyCommandOnSingleNode(commandCallback, entry.getKey(), key.getBytes()))); } } } @@ -328,10 +345,11 @@ private NodeResult executeMultiKeyCommandOnSingleNode(MultiKeyClusterC try { return new NodeResult<>(node, commandCallback.doInCluster(client, key), key); - } catch (RuntimeException ex) { + } catch (RuntimeException cause) { + + RuntimeException translatedException = convertToDataAccessException(cause); - RuntimeException translatedException = convertToDataAccessException(ex); - throw translatedException != null ? translatedException : ex; + throw translatedException != null ? translatedException : cause; } finally { this.resourceProvider.returnResourceForSpecificNode(node, client); } @@ -343,7 +361,7 @@ private ClusterTopology getClusterTopology() { @Nullable private DataAccessException convertToDataAccessException(Exception cause) { - return exceptionTranslationStrategy.translate(cause); + return this.exceptionTranslationStrategy.translate(cause); } /** @@ -395,7 +413,7 @@ public interface MultiKeyClusterCommandCallback { * @author Mark Paluch * @since 1.7 */ - private static class NodeExecution { + static class NodeExecution { private final RedisClusterNode node; private final @Nullable PositionalKey positionalKey; @@ -414,7 +432,7 @@ private static class NodeExecution { * Get the {@link RedisClusterNode} the execution happens on. */ RedisClusterNode getNode() { - return node; + return this.node; } /** @@ -423,30 +441,31 @@ RedisClusterNode getNode() { * @since 2.0.3 */ PositionalKey getPositionalKey() { - return positionalKey; + return this.positionalKey; } boolean isPositional() { - return positionalKey != null; + return this.positionalKey != null; } } /** - * {@link NodeResult} encapsulates the actual value returned by a {@link ClusterCommandCallback} on a given - * {@link RedisClusterNode}. + * {@link NodeResult} encapsulates the actual {@link T value} returned by a {@link ClusterCommandCallback} + * on a given {@link RedisClusterNode}. * + * @param {@link Class Type} of the {@link Object value} returned in the result. * @author Christoph Strobl - * @param + * @author John Blum * @since 1.7 */ public static class NodeResult { private RedisClusterNode node; - private @Nullable T value; private ByteArrayWrapper key; + private @Nullable T value; /** - * Create new {@link NodeResult}. + * Create a new {@link NodeResult}. * * @param node must not be {@literal null}. * @param value can be {@literal null}. @@ -456,7 +475,7 @@ public NodeResult(RedisClusterNode node, @Nullable T value) { } /** - * Create new {@link NodeResult}. + * Create a new {@link NodeResult}. * * @param node must not be {@literal null}. * @param value can be {@literal null}. @@ -465,37 +484,36 @@ public NodeResult(RedisClusterNode node, @Nullable T value) { public NodeResult(RedisClusterNode node, @Nullable T value, byte[] key) { this.node = node; - this.value = value; - this.key = new ByteArrayWrapper(key); + this.value = value; } /** - * Get the actual value of the command execution. + * Get the {@link RedisClusterNode} the command was executed on. * - * @return can be {@literal null}. + * @return never {@literal null}. */ - @Nullable - public T getValue() { - return value; + public RedisClusterNode getNode() { + return this.node; } /** - * Get the {@link RedisClusterNode} the command was executed on. + * Return the {@link byte[] key} mapped to the value stored in Redis. * - * @return never {@literal null}. + * @return a {@link byte[] byte array} of the key mapped to the value stored in Redis. */ - public RedisClusterNode getNode() { - return node; + public byte[] getKey() { + return this.key.getArray(); } /** - * Returns the key as an array of bytes. + * Get the actual value of the command execution. * - * @return the key as an array of bytes. + * @return can be {@literal null}. */ - public byte[] getKey() { - return key.getArray(); + @Nullable + public T getValue() { + return this.value; } /** @@ -513,6 +531,34 @@ public U mapValue(Function mapper) { return mapper.apply(getValue()); } + + @Override + public boolean equals(@Nullable Object obj) { + + if (obj == this) { + return true; + } + + if (!(obj instanceof NodeResult that)) { + return false; + } + + return ObjectUtils.nullSafeEquals(this.getNode(), that.getNode()) + && Objects.equals(this.key, that.key) + && Objects.equals(this.getValue(), that.getValue()); + } + + @Override + public int hashCode() { + + int hashValue = 17; + + hashValue = 37 * hashValue + ObjectUtils.nullSafeHashCode(getNode()); + hashValue = 37 * hashValue + ObjectUtils.nullSafeHashCode(this.key); + hashValue = 37 * hashValue + ObjectUtils.nullSafeHashCode(getValue()); + + return hashValue; + } } /** @@ -566,12 +612,14 @@ public List resultsAsListSortBy(byte[]... keys) { if (positionalResults.isEmpty()) { List> clone = new ArrayList<>(nodeResults); + clone.sort(new ResultByReferenceKeyPositionComparator(keys)); return toList(clone); } Map> result = new TreeMap<>(new ResultByKeyPositionComparator(keys)); + result.putAll(positionalResults); return result.values().stream().map(tNodeResult -> tNodeResult.value).collect(Collectors.toList()); @@ -605,9 +653,11 @@ public T getFirstNonNullNotEmptyOrDefault(@Nullable T returnValue) { private List toList(Collection> source) { ArrayList result = new ArrayList<>(); + for (NodeResult nodeResult : source) { result.add(nodeResult.getValue()); } + return result; } @@ -678,7 +728,7 @@ static PositionalKey of(byte[] key, int index) { * @return binary key. */ byte[] getBytes() { - return key.getArray(); + return getKey().getArray(); } public ByteArrayWrapper getKey() { @@ -690,23 +740,23 @@ public int getPosition() { } @Override - public boolean equals(@Nullable Object o) { - if (this == o) - return true; - if (o == null || getClass() != o.getClass()) - return false; + public boolean equals(@Nullable Object obj) { - PositionalKey that = (PositionalKey) o; + if (this == obj) { + return true; + } - if (position != that.position) + if (!(obj instanceof PositionalKey that)) return false; - return ObjectUtils.nullSafeEquals(key, that.key); + + return this.getPosition() == that.getPosition() + && ObjectUtils.nullSafeEquals(this.getKey(), that.getKey()); } @Override public int hashCode() { - int result = ObjectUtils.nullSafeHashCode(key); - result = 31 * result + position; + int result = ObjectUtils.nullSafeHashCode(getKey()); + result = 31 * result + ObjectUtils.nullSafeHashCode(getPosition()); return result; } } @@ -753,6 +803,7 @@ static PositionalKeys of(byte[]... keys) { static PositionalKeys of(PositionalKey... keys) { PositionalKeys result = PositionalKeys.empty(); + result.append(keys); return result; @@ -769,12 +820,12 @@ void append(PositionalKey... keys) { * @return index of the {@link PositionalKey}. */ int indexOf(PositionalKey key) { - return keys.indexOf(key); + return this.keys.indexOf(key); } @Override public Iterator iterator() { - return keys.iterator(); + return this.keys.iterator(); } } } diff --git a/src/main/java/org/springframework/data/redis/connection/ClusterTopology.java b/src/main/java/org/springframework/data/redis/connection/ClusterTopology.java index f49236b188..8c4f6b1427 100644 --- a/src/main/java/org/springframework/data/redis/connection/ClusterTopology.java +++ b/src/main/java/org/springframework/data/redis/connection/ClusterTopology.java @@ -17,8 +17,10 @@ import java.util.Arrays; import java.util.Collections; -import java.util.LinkedHashSet; +import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; +import java.util.stream.Collectors; import org.springframework.data.redis.ClusterStateFailureException; import org.springframework.lang.Nullable; @@ -26,20 +28,27 @@ import org.springframework.util.StringUtils; /** - * {@link ClusterTopology} holds snapshot like information about {@link RedisClusterNode}s. + * Holder of snapshot-like information about {@link RedisClusterNode nodes} in the Redis cluster. * * @author Christoph Strobl * @author Mark Paluch + * @author John Blum + * @see org.springframework.data.redis.connection.RedisClusterNode * @since 1.7 */ public class ClusterTopology { + protected static final Predicate ACTIVE_NODE_PREDICATE = + node -> node.isConnected() && !node.isMarkedAsFail(); + + protected static final Predicate MASTER_NODE_PREDICATE = RedisNode::isMaster; + private final Set nodes; /** - * Creates new instance of {@link ClusterTopology}. + * Creates a new {@link ClusterTopology} from the given {@link Set} of {@link RedisClusterNode}s. * - * @param nodes can be {@literal null}. + * @param nodes {@link RedisClusterNode nodes} forming the topology of the Redis cluster; can be {@literal null}. */ public ClusterTopology(@Nullable Set nodes) { this.nodes = nodes != null ? nodes : Collections.emptySet(); @@ -61,14 +70,7 @@ public Set getNodes() { * @return never {@literal null}. */ public Set getActiveNodes() { - - Set activeNodes = new LinkedHashSet<>(nodes.size()); - for (RedisClusterNode node : nodes) { - if (node.isConnected() && !node.isMarkedAsFail()) { - activeNodes.add(node); - } - } - return activeNodes; + return getNodes().stream().filter(ACTIVE_NODE_PREDICATE).collect(Collectors.toSet()); } /** @@ -79,13 +81,9 @@ public Set getActiveNodes() { */ public Set getActiveMasterNodes() { - Set activeMasterNodes = new LinkedHashSet<>(nodes.size()); - for (RedisClusterNode node : nodes) { - if (node.isMaster() && node.isConnected() && !node.isMarkedAsFail()) { - activeMasterNodes.add(node); - } - } - return activeMasterNodes; + return getNodes().stream() + .filter(MASTER_NODE_PREDICATE.and(ACTIVE_NODE_PREDICATE)) + .collect(Collectors.toSet()); } /** @@ -94,14 +92,7 @@ public Set getActiveMasterNodes() { * @return never {@literal null}. */ public Set getMasterNodes() { - - Set masterNodes = new LinkedHashSet<>(nodes.size()); - for (RedisClusterNode node : nodes) { - if (node.isMaster()) { - masterNodes.add(node); - } - } - return masterNodes; + return getNodes().stream().filter(MASTER_NODE_PREDICATE).collect(Collectors.toSet()); } /** @@ -111,14 +102,7 @@ public Set getMasterNodes() { * @return never {@literal null}. */ public Set getSlotServingNodes(int slot) { - - Set slotServingNodes = new LinkedHashSet<>(nodes.size()); - for (RedisClusterNode node : nodes) { - if (node.servesSlot(slot)) { - slotServingNodes.add(node); - } - } - return slotServingNodes; + return getNodes().stream().filter(node -> node.servesSlot(slot)).collect(Collectors.toSet()); } /** @@ -134,14 +118,12 @@ public RedisClusterNode getKeyServingMasterNode(byte[] key) { int slot = ClusterSlotHashUtil.calculateSlot(key); - for (RedisClusterNode node : nodes) { - if (node.isMaster() && node.servesSlot(slot)) { - return node; - } - } - - throw new ClusterStateFailureException( - String.format("Could not find master node serving slot %s for key '%s',", slot, Arrays.toString(key))); + return getNodes().stream() + .filter(MASTER_NODE_PREDICATE.and(node -> node.servesSlot(slot))) + .findFirst() + .orElseThrow(() -> new ClusterStateFailureException( + String.format("Could not find master node serving slot %s for key '%s',", + slot, Arrays.toString(key)))); } /** @@ -154,14 +136,20 @@ public RedisClusterNode getKeyServingMasterNode(byte[] key) { */ public RedisClusterNode lookup(String host, int port) { - for (RedisClusterNode node : nodes) { - if (host.equals(node.getHost()) && (node.getPort() != null && port == node.getPort())) { - return node; - } - } + return getNodes().stream() + .filter(isNodeMatch(host, port)) + .findFirst() + .orElseThrow(() -> new ClusterStateFailureException( + String.format("Could not find node at %s:%s; Is your cluster info up to date", host, port))); + } - throw new ClusterStateFailureException( - String.format("Could not find node at %s:%s; Is your cluster info up to date", host, port)); + private Predicate isNodeMatch(String host, int port) { + return node -> Objects.equals(node.getHost(), host) && resolvePort(node) == port; + } + + private int resolvePort(RedisClusterNode node) { + Integer port = node.getPort(); + return port != null ? port : -1; } /** @@ -175,14 +163,11 @@ public RedisClusterNode lookup(String nodeId) { Assert.notNull(nodeId, "NodeId must not be null"); - for (RedisClusterNode node : nodes) { - if (nodeId.equals(node.getId())) { - return node; - } - } - - throw new ClusterStateFailureException( - String.format("Could not find node at %s; Is your cluster info up to date", nodeId)); + return getNodes().stream() + .filter(node -> nodeId.equals(node.getId())) + .findFirst() + .orElseThrow(() -> new ClusterStateFailureException( + String.format("Could not find node at %s; Is your cluster info up to date", nodeId))); } /** @@ -219,7 +204,8 @@ public RedisClusterNode lookup(RedisClusterNode node) { */ public Set getKeyServingNodes(byte[] key) { - Assert.notNull(key, "Key must not be null for Cluster Node lookup."); + Assert.notNull(key, "Key must not be null for Cluster Node lookup"); + return getSlotServingNodes(ClusterSlotHashUtil.calculateSlot(key)); } } diff --git a/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java b/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java index f99e786ce4..25e3d47562 100644 --- a/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java +++ b/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java @@ -15,16 +15,38 @@ */ package org.springframework.data.redis.connection; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; -import static org.springframework.data.redis.test.util.MockitoUtils.*; - +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; +import static org.springframework.data.redis.test.util.MockitoUtils.verifyInvocationsAcross; + +import java.time.Instant; import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; import java.util.HashSet; -import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Predicate; +import java.util.function.Supplier; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -33,6 +55,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.stubbing.Answer; import org.springframework.core.convert.converter.Converter; import org.springframework.core.task.AsyncTaskExecutor; @@ -45,14 +68,24 @@ import org.springframework.data.redis.connection.ClusterCommandExecutor.ClusterCommandCallback; import org.springframework.data.redis.connection.ClusterCommandExecutor.MultiKeyClusterCommandCallback; import org.springframework.data.redis.connection.ClusterCommandExecutor.MultiNodeResult; +import org.springframework.data.redis.connection.ClusterCommandExecutor.NodeExecution; +import org.springframework.data.redis.connection.ClusterCommandExecutor.NodeResult; import org.springframework.data.redis.connection.RedisClusterNode.LinkState; import org.springframework.data.redis.connection.RedisClusterNode.SlotRange; import org.springframework.data.redis.connection.RedisNode.NodeType; +import org.springframework.data.redis.test.util.MockitoUtils; import org.springframework.scheduling.concurrent.ConcurrentTaskExecutor; +import edu.umd.cs.mtc.MultithreadedTestCase; +import edu.umd.cs.mtc.TestFramework; + /** + * Unit Tests for {@link ClusterCommandExecutor}. + * * @author Christoph Strobl * @author Mark Paluch + * @author John Blum + * @since 1.7 */ @ExtendWith(MockitoExtension.class) class ClusterCommandExecutorUnitTests { @@ -66,17 +99,32 @@ class ClusterCommandExecutorUnitTests { private static final int CLUSTER_NODE_3_PORT = 7381; private static final RedisClusterNode CLUSTER_NODE_1 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT).serving(new SlotRange(0, 5460)) - .withId("ef570f86c7b1a953846668debc177a3a16733420").promotedAs(NodeType.MASTER).linkState(LinkState.CONNECTED) + .listeningAt(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT) + .serving(new SlotRange(0, 5460)) + .withId("ef570f86c7b1a953846668debc177a3a16733420") + .promotedAs(NodeType.MASTER) + .linkState(LinkState.CONNECTED) + .withName("ClusterNodeX") .build(); + private static final RedisClusterNode CLUSTER_NODE_2 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT).serving(new SlotRange(5461, 10922)) - .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84").promotedAs(NodeType.MASTER).linkState(LinkState.CONNECTED) + .listeningAt(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT) + .serving(new SlotRange(5461, 10922)) + .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84") + .promotedAs(NodeType.MASTER) + .linkState(LinkState.CONNECTED) + .withName("ClusterNodeY") .build(); + private static final RedisClusterNode CLUSTER_NODE_3 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT).serving(new SlotRange(10923, 16383)) - .withId("3b9b8192a874fa8f1f09dbc0ee20afab5738eee7").promotedAs(NodeType.MASTER).linkState(LinkState.CONNECTED) + .listeningAt(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT) + .serving(new SlotRange(10923, 16383)) + .withId("3b9b8192a874fa8f1f09dbc0ee20afab5738eee7") + .promotedAs(NodeType.MASTER) + .linkState(LinkState.CONNECTED) + .withName("ClusterNodeZ") .build(); + private static final RedisClusterNode CLUSTER_NODE_2_LOOKUP = RedisClusterNode.newRedisClusterNode() .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84").build(); @@ -88,8 +136,8 @@ class ClusterCommandExecutorUnitTests { private static final Converter exceptionConverter = source -> { - if (source instanceof MovedException) { - return new ClusterRedirectException(1000, ((MovedException) source).host, ((MovedException) source).port, source); + if (source instanceof MovedException movedException) { + return new ClusterRedirectException(1000, movedException.host, movedException.port, source); } return new InvalidDataAccessApiUsageException(source.getMessage(), source); @@ -97,14 +145,14 @@ class ClusterCommandExecutorUnitTests { private static final MultiKeyConnectionCommandCallback MULTIKEY_CALLBACK = Connection::bloodAndAshes; - @Mock Connection con1; - @Mock Connection con2; - @Mock Connection con3; + @Mock Connection connection1; + @Mock Connection connection2; + @Mock Connection connection3; @BeforeEach void setUp() { - this.executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), new MockClusterResourceProvider(), + this.executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ImmediateExecutor()); } @@ -118,7 +166,7 @@ void executeCommandOnSingleNodeShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, CLUSTER_NODE_2); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -127,7 +175,7 @@ void executeCommandOnSingleNodeByHostAndPortShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, new RedisClusterNode(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -135,15 +183,17 @@ void executeCommandOnSingleNodeByNodeIdShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, new RedisClusterNode(CLUSTER_NODE_2.id)); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 + @SuppressWarnings("all") void executeCommandOnSingleNodeShouldThrowExceptionWhenNodeIsNull() { assertThatIllegalArgumentException().isThrownBy(() -> executor.executeCommandOnSingleNode(COMMAND_CALLBACK, null)); } @Test // DATAREDIS-315 + @SuppressWarnings("all") void executeCommandOnSingleNodeShouldThrowExceptionWhenCommandCallbackIsNull() { assertThatIllegalArgumentException().isThrownBy(() -> executor.executeCommandOnSingleNode(null, CLUSTER_NODE_1)); } @@ -158,52 +208,52 @@ void executeCommandOnSingleNodeShouldThrowExceptionWhenNodeIsUnknown() { void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodes() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(CLUSTER_NODE_1, CLUSTER_NODE_2)); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, never()).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodesByHostAndPort() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(new RedisClusterNode(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT), new RedisClusterNode(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT))); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, never()).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodesByNodeId() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(new RedisClusterNode(CLUSTER_NODE_1.id), CLUSTER_NODE_2_LOOKUP)); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, never()).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldFailOnGivenUnknownNodes() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); assertThatIllegalArgumentException().isThrownBy(() -> executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, @@ -214,42 +264,42 @@ void executeCommandAsyncOnNodesShouldFailOnGivenUnknownNodes() { void executeCommandOnAllNodesShouldExecuteCommandOnEveryKnownClusterNode() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); executor.executeCommandOnAllNodes(COMMAND_CALLBACK); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldCompleteAndCollectErrorsOfAllNodes() { - when(con1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); - when(con2.theWheelWeavesAsTheWheelWills()).thenThrow(new IllegalStateException("(error) mat lost the dagger...")); - when(con3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); + when(connection1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); + when(connection2.theWheelWeavesAsTheWheelWills()).thenThrow(new IllegalStateException("(error) mat lost the dagger...")); + when(connection3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); try { executor.executeCommandOnAllNodes(COMMAND_CALLBACK); - } catch (ClusterCommandExecutionFailureException e) { + } catch (ClusterCommandExecutionFailureException cause) { - assertThat(e.getSuppressed()).hasSize(1); - assertThat(e.getSuppressed()[0]).isInstanceOf(DataAccessException.class); + assertThat(cause.getSuppressed()).hasSize(1); + assertThat(cause.getSuppressed()[0]).isInstanceOf(DataAccessException.class); } - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldCollectResultsCorrectly() { - when(con1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); - when(con2.theWheelWeavesAsTheWheelWills()).thenReturn("mat"); - when(con3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); + when(connection1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); + when(connection2.theWheelWeavesAsTheWheelWills()).thenReturn("mat"); + when(connection3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); MultiNodeResult result = executor.executeCommandOnAllNodes(COMMAND_CALLBACK); @@ -261,10 +311,10 @@ void executeMultikeyCommandShouldRunCommandAcrossCluster() { // key-1 and key-9 map both to node1 ArgumentCaptor captor = ArgumentCaptor.forClass(byte[].class); - when(con1.bloodAndAshes(captor.capture())).thenReturn("rand").thenReturn("egwene"); - when(con2.bloodAndAshes(any(byte[].class))).thenReturn("mat"); - when(con3.bloodAndAshes(any(byte[].class))).thenReturn("perrin"); + when(connection1.bloodAndAshes(captor.capture())).thenReturn("rand").thenReturn("egwene"); + when(connection2.bloodAndAshes(any(byte[].class))).thenReturn("mat"); + when(connection3.bloodAndAshes(any(byte[].class))).thenReturn("perrin"); MultiNodeResult result = executor.executeMultiKeyCommand(MULTIKEY_CALLBACK, new HashSet<>( @@ -279,21 +329,21 @@ void executeMultikeyCommandShouldRunCommandAcrossCluster() { @Test // DATAREDIS-315 void executeCommandOnSingleNodeAndFollowRedirect() { - when(con1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); + when(connection1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); executor.executeCommandOnSingleNode(COMMAND_CALLBACK, CLUSTER_NODE_1); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, never()).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandOnSingleNodeAndFollowRedirectButStopsAfterMaxRedirects() { - when(con1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); - when(con3.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); - when(con2.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT)); + when(connection1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); + when(connection3.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); + when(connection2.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT)); try { executor.setMaxRedirects(4); @@ -302,9 +352,9 @@ void executeCommandOnSingleNodeAndFollowRedirectButStopsAfterMaxRedirects() { assertThat(e).isInstanceOf(TooManyClusterRedirectionsException.class); } - verify(con1, times(2)).theWheelWeavesAsTheWheelWills(); - verify(con3, times(2)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(2)).theWheelWeavesAsTheWheelWills(); + verify(connection3, times(2)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -312,53 +362,249 @@ void executeCommandOnArbitraryNodeShouldPickARandomNode() { executor.executeCommandOnArbitraryNode(COMMAND_CALLBACK); - verifyInvocationsAcross("theWheelWeavesAsTheWheelWills", times(1), con1, con2, con3); + verifyInvocationsAcross("theWheelWeavesAsTheWheelWills", times(1), connection1, connection2, connection3); + } + + @Test // GH-2518 + void collectResultsCompletesSuccessfully() { + + Instant done = Instant.now().plusMillis(5); + + Predicate>> isDone = future -> Instant.now().isAfter(done); + + Map>> futures = new HashMap<>(); + + NodeResult nodeOneA = newNodeResult(CLUSTER_NODE_1, "A"); + NodeResult nodeTwoB = newNodeResult(CLUSTER_NODE_2, "B"); + NodeResult nodeThreeC = newNodeResult(CLUSTER_NODE_3, "C"); + + futures.put(newNodeExecution(CLUSTER_NODE_1), mockFutureAndIsDone(nodeOneA, isDone)); + futures.put(newNodeExecution(CLUSTER_NODE_2), mockFutureAndIsDone(nodeTwoB, isDone)); + futures.put(newNodeExecution(CLUSTER_NODE_3), mockFutureAndIsDone(nodeThreeC, isDone)); + + MultiNodeResult results = this.executor.collectResults(futures); + + assertThat(results).isNotNull(); + assertThat(results.getResults()).containsExactlyInAnyOrder(nodeOneA, nodeTwoB, nodeThreeC); + + futures.values().forEach(future -> + runsSafely(() -> verify(future, times(1)).get(anyLong(), any(TimeUnit.class)))); + } + + @Test // GH-2518 + void collectResultsCompletesSuccessfullyEvenWithTimeouts() throws Exception { + + Map>> futures = new HashMap<>(); + + NodeResult nodeOneA = newNodeResult(CLUSTER_NODE_1, "A"); + NodeResult nodeTwoB = newNodeResult(CLUSTER_NODE_2, "B"); + NodeResult nodeThreeC = newNodeResult(CLUSTER_NODE_3, "C"); + + Future> nodeOneFutureResult = mockFutureThrowingTimeoutException(nodeOneA, 4); + Future> nodeTwoFutureResult = mockFutureThrowingTimeoutException(nodeTwoB, 1); + Future> nodeThreeFutureResult = mockFutureThrowingTimeoutException(nodeThreeC, 2); + + futures.put(newNodeExecution(CLUSTER_NODE_1), nodeOneFutureResult); + futures.put(newNodeExecution(CLUSTER_NODE_2), nodeTwoFutureResult); + futures.put(newNodeExecution(CLUSTER_NODE_3), nodeThreeFutureResult); + + MultiNodeResult results = this.executor.collectResults(futures); + + assertThat(results).isNotNull(); + assertThat(results.getResults()).containsExactlyInAnyOrder(nodeOneA, nodeTwoB, nodeThreeC); + + verify(nodeOneFutureResult, times(4)).get(anyLong(), any(TimeUnit.class)); + verify(nodeTwoFutureResult, times(1)).get(anyLong(), any(TimeUnit.class)); + verify(nodeThreeFutureResult, times(2)).get(anyLong(), any(TimeUnit.class)); + verifyNoMoreInteractions(nodeOneFutureResult, nodeTwoFutureResult, nodeThreeFutureResult); + } + + @Test // GH-2518 + void collectResultsFailsWithExecutionException() { + + Map>> futures = new HashMap<>(); + + NodeResult nodeOneA = newNodeResult(CLUSTER_NODE_1, "A"); + + futures.put(newNodeExecution(CLUSTER_NODE_1), mockFutureAndIsDone(nodeOneA, future -> true)); + futures.put(newNodeExecution(CLUSTER_NODE_2), mockFutureThrowingExecutionException( + new ExecutionException("TestError", new IllegalArgumentException("MockError")))); + + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> this.executor.collectResults(futures)) + .withMessage("MockError") + .withCauseInstanceOf(InvalidDataAccessApiUsageException.class) + .extracting(Throwable::getCause) + .extracting(Throwable::getCause) + .isInstanceOf(IllegalArgumentException.class) + .extracting(Throwable::getMessage) + .isEqualTo("MockError"); } - class MockClusterNodeProvider implements ClusterTopologyProvider { + @Test // GH-2518 + void collectResultsFailsWithInterruptedException() throws Throwable { + TestFramework.runOnce(new CollectResultsInterruptedMultithreadedTestCase(this.executor)); + } + + // Future.get() for X will get called twice if at least one other Future is not done and Future.get() for X + // threw an ExecutionException in the previous iteration, thereby marking it as done! + @Test // GH-2518 + @SuppressWarnings("all") + void collectResultsCallsFutureGetOnlyOnce() throws Exception { + + AtomicInteger count = new AtomicInteger(0); + Map>> futures = new HashMap<>(); + + Future> clusterNodeOneFutureResult = mockFutureAndIsDone(null, future -> + count.incrementAndGet() % 2 == 0); + + Future> clusterNodeTwoFutureResult = mockFutureThrowingExecutionException( + new ExecutionException("TestError", new IllegalArgumentException("MockError"))); + + futures.put(newNodeExecution(CLUSTER_NODE_1), clusterNodeOneFutureResult); + futures.put(newNodeExecution(CLUSTER_NODE_2), clusterNodeTwoFutureResult); + + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> this.executor.collectResults(futures)); + + verify(clusterNodeOneFutureResult, times(1)).get(anyLong(), any()); + verify(clusterNodeTwoFutureResult, times(1)).get(anyLong(), any()); + } + + // Covers the case where Future.get() is mistakenly called multiple times, or if the Future.isDone() implementation + // does not properly take into account Future.get() throwing an ExecutionException during computation subsequently + // returning false instead of true. + // This should be properly handled by the "safeguard" (see collectResultsCallsFutureGetOnlyOnce()), but... + // just in case! The ExecutionException handler now stores the [DataAccess]Exception with Map.putIfAbsent(..). + @Test // GH-2518 + @SuppressWarnings("all") + void collectResultsCapturesFirstExecutionExceptionOnly() { + + AtomicInteger count = new AtomicInteger(0); + AtomicInteger exceptionCount = new AtomicInteger(0); + + Map>> futures = new HashMap<>(); + + futures.put(newNodeExecution(CLUSTER_NODE_1), + mockFutureAndIsDone(null, future -> count.incrementAndGet() % 2 == 0)); + + futures.put(newNodeExecution(CLUSTER_NODE_2), mockFutureThrowingExecutionException(() -> + new ExecutionException("TestError", new IllegalStateException("MockError" + exceptionCount.getAndIncrement())))); + + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> this.executor.collectResults(futures)) + .withMessage("MockError0") + .withCauseInstanceOf(InvalidDataAccessApiUsageException.class) + .extracting(Throwable::getCause) + .extracting(Throwable::getCause) + .isInstanceOf(IllegalStateException.class) + .extracting(Throwable::getMessage) + .isEqualTo("MockError0"); + } + + private Future mockFutureAndIsDone(T result, Predicate> isDone) { + + return MockitoUtils.mockFuture(result, future -> { + doAnswer(invocation -> isDone.test(future)).when(future).isDone(); + return future; + }); + } + + private Future mockFutureThrowingExecutionException(ExecutionException exception) { + return mockFutureThrowingExecutionException(() -> exception); + } + + private Future mockFutureThrowingExecutionException(Supplier exceptionSupplier) { + + Answer getAnswer = invocationOnMock -> { throw exceptionSupplier.get(); }; + + return MockitoUtils.mockFuture(null, future -> { + doReturn(true).when(future).isDone(); + doAnswer(getAnswer).when(future).get(); + doAnswer(getAnswer).when(future).get(anyLong(), any()); + return future; + }); + } + + @SuppressWarnings("unchecked") + private Future mockFutureThrowingTimeoutException(T result, int timeoutCount) { + + AtomicInteger counter = new AtomicInteger(timeoutCount); + + Answer getAnswer = invocationOnMock -> { + + if (counter.decrementAndGet() > 0) { + throw new TimeoutException("TIMES UP"); + } + + doReturn(true).when((Future>) invocationOnMock.getMock()).isDone(); + + return result; + }; + + return MockitoUtils.mockFuture(result, future -> { + + doAnswer(getAnswer).when(future).get(); + doAnswer(getAnswer).when(future).get(anyLong(), any()); + + return future; + }); + } + + private NodeExecution newNodeExecution(RedisClusterNode clusterNode) { + return new NodeExecution(clusterNode); + } + + private NodeResult newNodeResult(RedisClusterNode clusterNode, T value) { + return new NodeResult<>(clusterNode, value); + } + + private void runsSafely(ThrowableOperation operation) { + + try { + operation.run(); + } catch (Throwable ignore) { } + } + + static class MockClusterNodeProvider implements ClusterTopologyProvider { @Override public ClusterTopology getTopology() { - return new ClusterTopology( - new LinkedHashSet<>(Arrays.asList(CLUSTER_NODE_1, CLUSTER_NODE_2, CLUSTER_NODE_3))); + return new ClusterTopology(Set.of(CLUSTER_NODE_1, CLUSTER_NODE_2, CLUSTER_NODE_3)); } - } - class MockClusterResourceProvider implements ClusterNodeResourceProvider { + class MockClusterNodeResourceProvider implements ClusterNodeResourceProvider { @Override - public Connection getResourceForSpecificNode(RedisClusterNode node) { - - if (CLUSTER_NODE_1.equals(node)) { - return con1; - } - if (CLUSTER_NODE_2.equals(node)) { - return con2; - } - if (CLUSTER_NODE_3.equals(node)) { - return con3; - } + @SuppressWarnings("all") + public Connection getResourceForSpecificNode(RedisClusterNode clusterNode) { - return null; + return CLUSTER_NODE_1.equals(clusterNode) ? connection1 + : CLUSTER_NODE_2.equals(clusterNode) ? connection2 + : CLUSTER_NODE_3.equals(clusterNode) ? connection3 + : null; } @Override public void returnResourceForSpecificNode(RedisClusterNode node, Object resource) { - // TODO Auto-generated method stub } - } - static interface ConnectionCommandCallback extends ClusterCommandCallback { + interface ConnectionCommandCallback extends ClusterCommandCallback { } - static interface MultiKeyConnectionCommandCallback extends MultiKeyClusterCommandCallback { + interface MultiKeyConnectionCommandCallback extends MultiKeyClusterCommandCallback { } - static interface Connection { + @FunctionalInterface + interface ThrowableOperation { + void run() throws Throwable; + } + + interface Connection { String theWheelWeavesAsTheWheelWills(); @@ -374,19 +620,20 @@ static class MovedException extends RuntimeException { this.host = host; this.port = port; } - } static class ImmediateExecutor implements AsyncTaskExecutor { @Override - public void execute(Runnable runnable, long l) { + public void execute(Runnable runnable) { runnable.run(); } @Override public Future submit(Runnable runnable) { + return submit(() -> { + runnable.run(); return null; @@ -395,19 +642,103 @@ public Future submit(Runnable runnable) { @Override public Future submit(Callable callable) { + try { return CompletableFuture.completedFuture(callable.call()); - } catch (Exception e) { + } catch (Exception cause) { CompletableFuture future = new CompletableFuture<>(); - future.completeExceptionally(e); + + future.completeExceptionally(cause); + return future; } } + } + + @SuppressWarnings("all") + static class CollectResultsInterruptedMultithreadedTestCase extends MultithreadedTestCase { + + private static final CountDownLatch latch = new CountDownLatch(1); + + private static final Comparator NODE_COMPARATOR = + Comparator.comparing(nodeExecution -> nodeExecution.getNode().getName()); + + private final ClusterCommandExecutor clusterCommandExecutor; + + private final Map>> futureNodeResults; + + private Future> mockNodeOneFutureResult; + private Future> mockNodeTwoFutureResult; + + private volatile Thread collectResultsThread; + + private CollectResultsInterruptedMultithreadedTestCase(ClusterCommandExecutor clusterCommandExecutor) { + this.clusterCommandExecutor = clusterCommandExecutor; + this.futureNodeResults = new ConcurrentSkipListMap<>(NODE_COMPARATOR); + } @Override - public void execute(Runnable runnable) { - runnable.run(); + public void initialize() { + + super.initialize(); + + this.mockNodeOneFutureResult = this.futureNodeResults.computeIfAbsent(new NodeExecution(CLUSTER_NODE_1), + nodeExecution -> MockitoUtils.mockFuture(null, mockFuture -> { + doReturn(false).when(mockFuture).isDone(); + return mockFuture; + })); + + this.mockNodeTwoFutureResult = this.futureNodeResults.computeIfAbsent(new NodeExecution(CLUSTER_NODE_2), + nodeExecution -> MockitoUtils.mockFuture(null, mockFuture -> { + + doReturn(true).when(mockFuture).isDone(); + + doAnswer(invocation -> { + latch.await(); + return null; + }).when(mockFuture).get(anyLong(), any()); + + return mockFuture; + })); + } + + public void thread1() { + + assertTick(0); + + this.collectResultsThread = Thread.currentThread(); + this.collectResultsThread.setName("CollectResults Thread"); + + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> this.clusterCommandExecutor.collectResults(this.futureNodeResults)); + + assertThat(this.collectResultsThread.isInterrupted()).isTrue(); + } + + public void thread2() { + + assertTick(0); + + Thread.currentThread().setName("Interrupting Thread"); + + waitForTick(1); + + assertThat(this.collectResultsThread).isNotNull(); + assertThat(this.collectResultsThread.getName()).isEqualTo("CollectResults Thread"); + + this.collectResultsThread.interrupt(); + } + + @Override + public void finish() { + + try { + verify(this.mockNodeOneFutureResult, never()).get(); + verify(this.mockNodeTwoFutureResult, times(1)).get(anyLong(), any()); + } catch (ExecutionException | InterruptedException | TimeoutException cause) { + throw new RuntimeException(cause); + } } } } diff --git a/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java b/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java index e42f0749d8..318d47ec85 100644 --- a/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java +++ b/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java @@ -15,34 +15,131 @@ */ package org.springframework.data.redis.test.util; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.anyBoolean; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockingDetails; +import static org.mockito.Mockito.withSettings; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; import org.mockito.internal.invocation.InvocationMatcher; import org.mockito.internal.verification.api.VerificationData; import org.mockito.invocation.Invocation; import org.mockito.invocation.MatchableInvocation; +import org.mockito.quality.Strictness; +import org.mockito.stubbing.Answer; import org.mockito.verification.VerificationMode; + +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; /** + * Utilities for using {@literal Mockito} and creating {@link Object mock objects} in {@literal unit tests}. + * * @author Christoph Strobl + * @author John Blum + * @see org.mockito.Mockito * @since 1.7 */ -public class MockitoUtils { +@SuppressWarnings("unused") +public abstract class MockitoUtils { + + /** + * Creates a mock {@link Future} returning the given {@link Object result}. + * + * @param {@link Class type} of {@link Object result} returned by the mock {@link Future}. + * @param result {@link Object value} returned as the {@literal result} of the mock {@link Future}. + * @return a new mock {@link Future}. + * @see java.util.concurrent.Future + */ + @SuppressWarnings("unchecked") + public static @NonNull Future mockFuture(@Nullable T result) { + + try { + + AtomicBoolean cancelled = new AtomicBoolean(false); + AtomicBoolean done = new AtomicBoolean(false); + + Future mockFuture = mock(Future.class, withSettings().strictness(Strictness.LENIENT)); + + // A Future can only be cancelled if not done, it was not already cancelled, and no error occurred. + // The cancel(..) logic is not Thread-safe due to compound actions involving multiple variables. + // However, the cancel(..) logic does not necessarily need to be Thread-safe given the task execution + // of a Future is asynchronous and cancellation is driven by Thread interrupt from another Thread. + Answer cancelAnswer = invocation -> !done.get() + && cancelled.compareAndSet(done.get(), true) + && done.compareAndSet(done.get(), true); + + Answer getAnswer = invocation -> { + + // The Future is done no matter if it returns the result or was cancelled/interrupted. + done.set(true); + + if (Thread.currentThread().isInterrupted()) { + throw new InterruptedException("Thread was interrupted"); + } + + if (cancelled.get()) { + throw new CancellationException("Task was cancelled"); + } + + return result; + }; + + doAnswer(invocation -> cancelled.get()).when(mockFuture).isCancelled(); + doAnswer(invocation -> done.get()).when(mockFuture).isDone(); + doAnswer(cancelAnswer).when(mockFuture).cancel(anyBoolean()); + doAnswer(getAnswer).when(mockFuture).get(); + doAnswer(getAnswer).when(mockFuture).get(anyLong(), isA(TimeUnit.class)); + + return mockFuture; + } + catch (Exception cause) { + String message = String.format("Failed to create a mock of Future having result [%s]", result); + throw new IllegalStateException(message, cause); + } + } + + /** + * Creates a mock {@link Future} returning the given {@link Object result}, customized with the given, + * required {@link Function}. + * + * @param {@link Class type} of {@link Object result} returned by the mock {@link Future}. + * @param result {@link Object value} returned as the {@literal result} of the mock {@link Future}. + * @param futureFunction {@link Function} used to customize the mock {@link Future} on creation; + * must not be {@literal null}. + * @return a new mock {@link Future}. + * @see java.util.concurrent.Future + * @see java.util.function.Function + * @see #mockFuture(Object) + */ + public static @NonNull Future mockFuture(@Nullable T result, + @NonNull ThrowableFunction, Future> futureFunction) { + + Future mockFuture = mockFuture(result); + + return futureFunction.apply(mockFuture); + } /** * Verifies a given method is called a total number of times across all given mocks. * - * @param method - * @param mode - * @param mocks + * @param method {@link String name} of a {@link java.lang.reflect.Method} on the {@link Object mock object}. + * @param mode mode of verification used by {@literal Mockito} to verify invocations on {@link Object mock objects}. + * @param mocks array of {@link Object mock objects} to verify. */ - @SuppressWarnings({ "rawtypes", "serial" }) - public static void verifyInvocationsAcross(final String method, final VerificationMode mode, Object... mocks) { + public static void verifyInvocationsAcross(String method, VerificationMode mode, Object... mocks) { mode.verify(new VerificationDataImpl(getInvocations(method, mocks), new InvocationMatcher(null, Collections .singletonList(org.mockito.internal.matchers.Any.ANY)) { @@ -56,17 +153,15 @@ public boolean matches(Invocation actual) { public String toString() { return String.format("%s for method: %s", mode, method); } - })); } private static List getInvocations(String method, Object... mocks) { List invocations = new ArrayList<>(); - for (Object mock : mocks) { + for (Object mock : mocks) { if (StringUtils.hasText(method)) { - for (Invocation invocation : mockingDetails(mock).getInvocations()) { if (invocation.getMethod().getName().equals(method)) { invocations.add(invocation); @@ -76,6 +171,7 @@ private static List getInvocations(String method, Object... mocks) { invocations.addAll(mockingDetails(mock).getInvocations()); } } + return invocations; } @@ -98,7 +194,31 @@ public List getAllInvocations() { public MatchableInvocation getTarget() { return wanted; } - } + @FunctionalInterface + public interface ThrowableFunction extends Function { + + @Override + default R apply(T target) { + + try { + return applyThrowingException(target); + } + catch (Throwable cause) { + String message = String.format("Failed to apply Function [%s] to target [%s]", this, target); + throw new IllegalStateException(message, cause); + } + } + + R applyThrowingException(T target) throws Throwable; + + @SuppressWarnings("unchecked") + default @NonNull ThrowableFunction andThen( + @Nullable ThrowableFunction after) { + + return after == null ? (ThrowableFunction) this + : target -> after.apply(this.apply(target)); + } + } }