From 2307011be22d4c1233686383531520bfe72c3af5 Mon Sep 17 00:00:00 2001 From: Gor Stepanyan Date: Thu, 19 Jan 2023 16:09:41 +0100 Subject: [PATCH] Add RandomTwoChoice load balancing policy The RandomTwoChoice policy is a wrapper load balancing policy that adds the "Power of 2 Choice" algorithm to its child policy. It will compare the first two hosts returned from the child policy query plan, and will first return the host with the target shard having fewer inflight requests. The rest of the child query plan will be left intact. It is intended to be used with TokenAwarePolicy and RoundRobinPolicy, to send queries to the replica nodes by always avoiding the worst option (the replica with the target shard having the most inflight requests will not be chosen). --- .../com/datastax/driver/core/Metrics.java | 32 ++++ .../core/policies/RandomTwoChoicePolicy.java | 157 ++++++++++++++++++ .../policies/RandomTwoChoicePolicyTest.java | 110 ++++++++++++ 3 files changed, 299 insertions(+) create mode 100644 driver-core/src/main/java/com/datastax/driver/core/policies/RandomTwoChoicePolicy.java create mode 100644 driver-core/src/test/java/com/datastax/driver/core/policies/RandomTwoChoicePolicyTest.java diff --git a/driver-core/src/main/java/com/datastax/driver/core/Metrics.java b/driver-core/src/main/java/com/datastax/driver/core/Metrics.java index 6eacfdf0e28..7c9437319c4 100644 --- a/driver-core/src/main/java/com/datastax/driver/core/Metrics.java +++ b/driver-core/src/main/java/com/datastax/driver/core/Metrics.java @@ -130,7 +130,35 @@ public Integer getValue() { return value; } }); + private final Gauge>> perShardInflightRequestInfo = + registry.register( + "per-shard-inflight-request-info", + new Gauge>>() { + @Override + public Map> getValue() { + Map> result = new HashMap>(); + for (SessionManager session : manager.sessions) { + for (Map.Entry poolEntry : session.pools.entrySet()) { + HostConnectionPool hostConnectionPool = poolEntry.getValue(); + Map perShardInflightRequests = new HashMap(); + + for (int shardId = 0; + shardId < hostConnectionPool.connections.length; + shardId++) { + int shardInflightRequests = 0; + for (Connection connection : hostConnectionPool.connections[shardId]) { + shardInflightRequests += connection.inFlight.get(); + } + perShardInflightRequests.put(shardId, shardInflightRequests); + } + + result.put(poolEntry.getKey(), perShardInflightRequests); + } + } + return result; + } + }); private final Gauge executorQueueDepth; private final Gauge blockingExecutorQueueDepth; private final Gauge reconnectionSchedulerQueueSize; @@ -374,6 +402,10 @@ public Gauge> getShardAwarenessInfo() { return shardAwarenessInfo; } + public Gauge>> getPerShardInflightRequestInfo() { + return perShardInflightRequestInfo; + } + /** * Returns the number of bytes sent so far. * diff --git a/driver-core/src/main/java/com/datastax/driver/core/policies/RandomTwoChoicePolicy.java b/driver-core/src/main/java/com/datastax/driver/core/policies/RandomTwoChoicePolicy.java new file mode 100644 index 00000000000..a9830244be3 --- /dev/null +++ b/driver-core/src/main/java/com/datastax/driver/core/policies/RandomTwoChoicePolicy.java @@ -0,0 +1,157 @@ +package com.datastax.driver.core.policies; + +import com.datastax.driver.core.*; +import com.google.common.collect.AbstractIterator; +import java.nio.ByteBuffer; +import java.util.*; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * A wrapper load balancing policy that adds "Power of 2 Choice" algorithm to a child policy. + * + *

This policy encapsulates another policy. The resulting policy works in the following way: + * + *

    + *
  • the {@code distance} method is inherited from the child policy. + *
  • the {@code newQueryPlan} method will compare first two hosts (by number of inflight + * requests) returned from the {@code newQueryPlan} method of the child policy, and the host + * with fewer number of inflight requests will be returned the first. It will allow to always + * avoid the worst option (comparing by number of inflight requests). + *
  • besides the first two hosts returned by the child policy's {@code newQueryPlan} method, the + * ordering of the rest of the hosts will remain the same. + *
+ * + *

If you wrap the {@code RandomTwoChoicePolicy} policy with {@code TokenAwarePolicy}, it will + * compare the first two replicas by the number of inflight requests, and the worse option will + * always be avoided. In that case, it is recommended to use the TokenAwarePolicy with {@code + * ReplicaOrdering.RANDOM strategy}, which will return the replicas in a shuffled order and thus + * will make the "Power of 2 Choice" algorithm more efficient. + */ +public class RandomTwoChoicePolicy implements ChainableLoadBalancingPolicy { + private final LoadBalancingPolicy childPolicy; + private volatile Metrics metrics; + private volatile Metadata clusterMetadata; + private volatile ProtocolVersion protocolVersion; + private volatile CodecRegistry codecRegistry; + + /** + * Creates a new {@code RandomTwoChoicePolicy}. + * + * @param childPolicy the load balancing policy to wrap with "Power of 2 Choice" algorithm. + */ + public RandomTwoChoicePolicy(LoadBalancingPolicy childPolicy) { + this.childPolicy = childPolicy; + } + + @Override + public LoadBalancingPolicy getChildPolicy() { + return childPolicy; + } + + @Override + public void init(Cluster cluster, Collection hosts) { + this.metrics = cluster.getMetrics(); + this.clusterMetadata = cluster.getMetadata(); + this.protocolVersion = cluster.getConfiguration().getProtocolOptions().getProtocolVersion(); + this.codecRegistry = cluster.getConfiguration().getCodecRegistry(); + childPolicy.init(cluster, hosts); + } + + /** + * {@inheritDoc} + * + *

This implementation always returns distances as reported by the wrapped policy. + */ + @Override + public HostDistance distance(Host host) { + return childPolicy.distance(host); + } + + /** + * {@inheritDoc} + * + *

The returned plan will compare (by the number of inflight requests) the first 2 hosts + * returned by the child policy's {@code newQueryPlan} method, and the host with fewer inflight + * requests will be returned the first. The rest of the child policy's query plan will be left + * intact. + */ + @Override + public Iterator newQueryPlan(String loggedKeyspace, Statement statement) { + String keyspace = statement.getKeyspace(); + if (keyspace == null) keyspace = loggedKeyspace; + + ByteBuffer routingKey = statement.getRoutingKey(protocolVersion, codecRegistry); + if (routingKey == null || keyspace == null) { + return childPolicy.newQueryPlan(loggedKeyspace, statement); + } + + final Token t = clusterMetadata.newToken(statement.getPartitioner(), routingKey); + final Iterator childIterator = childPolicy.newQueryPlan(keyspace, statement); + + final Host host1 = childIterator.hasNext() ? childIterator.next() : null; + final Host host2 = childIterator.hasNext() ? childIterator.next() : null; + + final AtomicInteger host1ShardInflightRequests = new AtomicInteger(0); + final AtomicInteger host2ShardInflightRequests = new AtomicInteger(0); + + if (host1 != null) { + final int host1ShardId = host1.getShardingInfo().shardId(t); + host1ShardInflightRequests.set( + metrics.getPerShardInflightRequestInfo().getValue().get(host1).get(host1ShardId)); + } + + if (host2 != null) { + final int host2ShardId = host2.getShardingInfo().shardId(t); + host2ShardInflightRequests.set( + metrics.getPerShardInflightRequestInfo().getValue().get(host2).get(host2ShardId)); + } + + return new AbstractIterator() { + private final Host firstChosenHost = + host1ShardInflightRequests.get() < host2ShardInflightRequests.get() ? host1 : host2; + private final Host secondChosenHost = + host1ShardInflightRequests.get() < host2ShardInflightRequests.get() ? host2 : host1; + private int index = 0; + + @Override + protected Host computeNext() { + if (index == 0) { + index++; + return firstChosenHost; + } else if (index == 1) { + index++; + return secondChosenHost; + } else if (childIterator.hasNext()) { + return childIterator.next(); + } + + return endOfData(); + } + }; + } + + @Override + public void onAdd(Host host) { + childPolicy.onAdd(host); + } + + @Override + public void onUp(Host host) { + childPolicy.onUp(host); + } + + @Override + public void onDown(Host host) { + childPolicy.onDown(host); + } + + @Override + public void onRemove(Host host) { + childPolicy.onRemove(host); + } + + @Override + public void close() { + childPolicy.close(); + } +} diff --git a/driver-core/src/test/java/com/datastax/driver/core/policies/RandomTwoChoicePolicyTest.java b/driver-core/src/test/java/com/datastax/driver/core/policies/RandomTwoChoicePolicyTest.java new file mode 100644 index 00000000000..c895c36ef28 --- /dev/null +++ b/driver-core/src/test/java/com/datastax/driver/core/policies/RandomTwoChoicePolicyTest.java @@ -0,0 +1,110 @@ +package com.datastax.driver.core.policies; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.codahale.metrics.Gauge; +import com.datastax.driver.core.*; +import java.nio.ByteBuffer; +import java.util.*; +import org.assertj.core.util.Sets; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +public class RandomTwoChoicePolicyTest { + private final ByteBuffer routingKey = ByteBuffer.wrap(new byte[] {1, 2, 3, 4}); + private final RegularStatement statement = + new SimpleStatement("irrelevant").setRoutingKey(routingKey).setKeyspace("keyspace"); + private final Host host1 = mock(Host.class); + private final Host host2 = mock(Host.class); + private final Host host3 = mock(Host.class); + private Cluster cluster; + + @SuppressWarnings("unchecked") + private final Gauge>> gauge = + mock((Class>>>) (Object) Gauge.class); + + @BeforeMethod(groups = "unit") + public void initMocks() { + CodecRegistry codecRegistry = new CodecRegistry(); + cluster = mock(Cluster.class); + Configuration configuration = mock(Configuration.class); + ProtocolOptions protocolOptions = mock(ProtocolOptions.class); + Metadata metadata = mock(Metadata.class); + Metrics metrics = mock(Metrics.class); + Token t = mock(Token.class); + ShardingInfo shardingInfo = mock(ShardingInfo.class); + + when(metrics.getPerShardInflightRequestInfo()).thenReturn(gauge); + when(cluster.getConfiguration()).thenReturn(configuration); + when(configuration.getCodecRegistry()).thenReturn(codecRegistry); + when(configuration.getProtocolOptions()).thenReturn(protocolOptions); + when(protocolOptions.getProtocolVersion()).thenReturn(ProtocolVersion.DEFAULT); + when(cluster.getMetadata()).thenReturn(metadata); + when(cluster.getMetrics()).thenReturn(metrics); + when(metadata.getReplicas(Metadata.quote("keyspace"), null, routingKey)) + .thenReturn(Sets.newLinkedHashSet(host1, host2, host3)); + when(metadata.newToken(null, routingKey)).thenReturn(t); + when(host1.getShardingInfo()).thenReturn(shardingInfo); + when(host2.getShardingInfo()).thenReturn(shardingInfo); + when(host3.getShardingInfo()).thenReturn(shardingInfo); + when(shardingInfo.shardId(t)).thenReturn(0); + when(host1.isUp()).thenReturn(true); + when(host2.isUp()).thenReturn(true); + when(host3.isUp()).thenReturn(true); + } + + @Test(groups = "unit") + public void should_prefer_host_with_less_inflight_requests() { + // given + Map> perHostInflightRequests = + new HashMap>() { + { + put( + host1, + new HashMap() { + { + put(0, 6); + } + }); + put( + host2, + new HashMap() { + { + put(0, 2); + } + }); + put( + host3, + new HashMap() { + { + put(0, 4); + } + }); + } + }; + RandomTwoChoicePolicy policy = + new RandomTwoChoicePolicy( + new TokenAwarePolicy( + new RoundRobinPolicy(), TokenAwarePolicy.ReplicaOrdering.TOPOLOGICAL)); + policy.init( + cluster, + new ArrayList() { + + { + add(host1); + add(host2); + add(host3); + } + }); + when(gauge.getValue()).thenReturn(perHostInflightRequests); + + Iterator queryPlan = policy.newQueryPlan("keyspace", statement); + // host2 should appear first in the query plan with fewer inflight requests than host1 + + assertThat(queryPlan.next()).isEqualTo(host2); + assertThat(queryPlan.next()).isEqualTo(host1); + assertThat(queryPlan.next()).isEqualTo(host3); + } +}