Skip to content

Commit

Permalink
Add RandomTwoChoice load balancing policy
Browse files Browse the repository at this point in the history
The RandomTwoChoice policy is a wrapper load balancing policy that
adds the "Power of 2 Choice" algorithm to its child policy. It will
compare the first two hosts returned from the child policy query plan,
and will first return the host with the target shard having fewer
inflight requests. The rest of the child query plan will be left intact.

It is intended to be used with TokenAwarePolicy and RoundRobinPolicy,
to send queries to the replica nodes by always avoiding the worst option
(the replica with the target shard having the most inflight requests will not be chosen).
  • Loading branch information
Gor027 committed Mar 9, 2023
1 parent b3f3eba commit 2307011
Show file tree
Hide file tree
Showing 3 changed files with 299 additions and 0 deletions.
32 changes: 32 additions & 0 deletions driver-core/src/main/java/com/datastax/driver/core/Metrics.java
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,35 @@ public Integer getValue() {
return value;
}
});
private final Gauge<Map<Host, Map<Integer, Integer>>> perShardInflightRequestInfo =
registry.register(
"per-shard-inflight-request-info",
new Gauge<Map<Host, Map<Integer, Integer>>>() {
@Override
public Map<Host, Map<Integer, Integer>> getValue() {
Map<Host, Map<Integer, Integer>> result = new HashMap<Host, Map<Integer, Integer>>();
for (SessionManager session : manager.sessions) {
for (Map.Entry<Host, HostConnectionPool> poolEntry : session.pools.entrySet()) {
HostConnectionPool hostConnectionPool = poolEntry.getValue();
Map<Integer, Integer> perShardInflightRequests = new HashMap<Integer, Integer>();

for (int shardId = 0;
shardId < hostConnectionPool.connections.length;
shardId++) {
int shardInflightRequests = 0;
for (Connection connection : hostConnectionPool.connections[shardId]) {
shardInflightRequests += connection.inFlight.get();
}
perShardInflightRequests.put(shardId, shardInflightRequests);
}

result.put(poolEntry.getKey(), perShardInflightRequests);
}
}

return result;
}
});
private final Gauge<Integer> executorQueueDepth;
private final Gauge<Integer> blockingExecutorQueueDepth;
private final Gauge<Integer> reconnectionSchedulerQueueSize;
Expand Down Expand Up @@ -374,6 +402,10 @@ public Gauge<Map<Host, Integer>> getShardAwarenessInfo() {
return shardAwarenessInfo;
}

public Gauge<Map<Host, Map<Integer, Integer>>> getPerShardInflightRequestInfo() {
return perShardInflightRequestInfo;
}

/**
* Returns the number of bytes sent so far.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package com.datastax.driver.core.policies;

import com.datastax.driver.core.*;
import com.google.common.collect.AbstractIterator;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
* A wrapper load balancing policy that adds "Power of 2 Choice" algorithm to a child policy.
*
* <p>This policy encapsulates another policy. The resulting policy works in the following way:
*
* <ul>
* <li>the {@code distance} method is inherited from the child policy.
* <li>the {@code newQueryPlan} method will compare first two hosts (by number of inflight
* requests) returned from the {@code newQueryPlan} method of the child policy, and the host
* with fewer number of inflight requests will be returned the first. It will allow to always
* avoid the worst option (comparing by number of inflight requests).
* <li>besides the first two hosts returned by the child policy's {@code newQueryPlan} method, the
* ordering of the rest of the hosts will remain the same.
* </ul>
*
* <p>If you wrap the {@code RandomTwoChoicePolicy} policy with {@code TokenAwarePolicy}, it will
* compare the first two replicas by the number of inflight requests, and the worse option will
* always be avoided. In that case, it is recommended to use the TokenAwarePolicy with {@code
* ReplicaOrdering.RANDOM strategy}, which will return the replicas in a shuffled order and thus
* will make the "Power of 2 Choice" algorithm more efficient.
*/
public class RandomTwoChoicePolicy implements ChainableLoadBalancingPolicy {
private final LoadBalancingPolicy childPolicy;
private volatile Metrics metrics;
private volatile Metadata clusterMetadata;
private volatile ProtocolVersion protocolVersion;
private volatile CodecRegistry codecRegistry;

/**
* Creates a new {@code RandomTwoChoicePolicy}.
*
* @param childPolicy the load balancing policy to wrap with "Power of 2 Choice" algorithm.
*/
public RandomTwoChoicePolicy(LoadBalancingPolicy childPolicy) {
this.childPolicy = childPolicy;
}

@Override
public LoadBalancingPolicy getChildPolicy() {
return childPolicy;
}

@Override
public void init(Cluster cluster, Collection<Host> hosts) {
this.metrics = cluster.getMetrics();
this.clusterMetadata = cluster.getMetadata();
this.protocolVersion = cluster.getConfiguration().getProtocolOptions().getProtocolVersion();
this.codecRegistry = cluster.getConfiguration().getCodecRegistry();
childPolicy.init(cluster, hosts);
}

/**
* {@inheritDoc}
*
* <p>This implementation always returns distances as reported by the wrapped policy.
*/
@Override
public HostDistance distance(Host host) {
return childPolicy.distance(host);
}

/**
* {@inheritDoc}
*
* <p>The returned plan will compare (by the number of inflight requests) the first 2 hosts
* returned by the child policy's {@code newQueryPlan} method, and the host with fewer inflight
* requests will be returned the first. The rest of the child policy's query plan will be left
* intact.
*/
@Override
public Iterator<Host> newQueryPlan(String loggedKeyspace, Statement statement) {
String keyspace = statement.getKeyspace();
if (keyspace == null) keyspace = loggedKeyspace;

ByteBuffer routingKey = statement.getRoutingKey(protocolVersion, codecRegistry);
if (routingKey == null || keyspace == null) {
return childPolicy.newQueryPlan(loggedKeyspace, statement);
}

final Token t = clusterMetadata.newToken(statement.getPartitioner(), routingKey);
final Iterator<Host> childIterator = childPolicy.newQueryPlan(keyspace, statement);

final Host host1 = childIterator.hasNext() ? childIterator.next() : null;
final Host host2 = childIterator.hasNext() ? childIterator.next() : null;

final AtomicInteger host1ShardInflightRequests = new AtomicInteger(0);
final AtomicInteger host2ShardInflightRequests = new AtomicInteger(0);

if (host1 != null) {
final int host1ShardId = host1.getShardingInfo().shardId(t);
host1ShardInflightRequests.set(
metrics.getPerShardInflightRequestInfo().getValue().get(host1).get(host1ShardId));
}

if (host2 != null) {
final int host2ShardId = host2.getShardingInfo().shardId(t);
host2ShardInflightRequests.set(
metrics.getPerShardInflightRequestInfo().getValue().get(host2).get(host2ShardId));
}

return new AbstractIterator<Host>() {
private final Host firstChosenHost =
host1ShardInflightRequests.get() < host2ShardInflightRequests.get() ? host1 : host2;
private final Host secondChosenHost =
host1ShardInflightRequests.get() < host2ShardInflightRequests.get() ? host2 : host1;
private int index = 0;

@Override
protected Host computeNext() {
if (index == 0) {
index++;
return firstChosenHost;
} else if (index == 1) {
index++;
return secondChosenHost;
} else if (childIterator.hasNext()) {
return childIterator.next();
}

return endOfData();
}
};
}

@Override
public void onAdd(Host host) {
childPolicy.onAdd(host);
}

@Override
public void onUp(Host host) {
childPolicy.onUp(host);
}

@Override
public void onDown(Host host) {
childPolicy.onDown(host);
}

@Override
public void onRemove(Host host) {
childPolicy.onRemove(host);
}

@Override
public void close() {
childPolicy.close();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package com.datastax.driver.core.policies;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import com.codahale.metrics.Gauge;
import com.datastax.driver.core.*;
import java.nio.ByteBuffer;
import java.util.*;
import org.assertj.core.util.Sets;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

public class RandomTwoChoicePolicyTest {
private final ByteBuffer routingKey = ByteBuffer.wrap(new byte[] {1, 2, 3, 4});
private final RegularStatement statement =
new SimpleStatement("irrelevant").setRoutingKey(routingKey).setKeyspace("keyspace");
private final Host host1 = mock(Host.class);
private final Host host2 = mock(Host.class);
private final Host host3 = mock(Host.class);
private Cluster cluster;

@SuppressWarnings("unchecked")
private final Gauge<Map<Host, Map<Integer, Integer>>> gauge =
mock((Class<Gauge<Map<Host, Map<Integer, Integer>>>>) (Object) Gauge.class);

@BeforeMethod(groups = "unit")
public void initMocks() {
CodecRegistry codecRegistry = new CodecRegistry();
cluster = mock(Cluster.class);
Configuration configuration = mock(Configuration.class);
ProtocolOptions protocolOptions = mock(ProtocolOptions.class);
Metadata metadata = mock(Metadata.class);
Metrics metrics = mock(Metrics.class);
Token t = mock(Token.class);
ShardingInfo shardingInfo = mock(ShardingInfo.class);

when(metrics.getPerShardInflightRequestInfo()).thenReturn(gauge);
when(cluster.getConfiguration()).thenReturn(configuration);
when(configuration.getCodecRegistry()).thenReturn(codecRegistry);
when(configuration.getProtocolOptions()).thenReturn(protocolOptions);
when(protocolOptions.getProtocolVersion()).thenReturn(ProtocolVersion.DEFAULT);
when(cluster.getMetadata()).thenReturn(metadata);
when(cluster.getMetrics()).thenReturn(metrics);
when(metadata.getReplicas(Metadata.quote("keyspace"), null, routingKey))
.thenReturn(Sets.newLinkedHashSet(host1, host2, host3));
when(metadata.newToken(null, routingKey)).thenReturn(t);
when(host1.getShardingInfo()).thenReturn(shardingInfo);
when(host2.getShardingInfo()).thenReturn(shardingInfo);
when(host3.getShardingInfo()).thenReturn(shardingInfo);
when(shardingInfo.shardId(t)).thenReturn(0);
when(host1.isUp()).thenReturn(true);
when(host2.isUp()).thenReturn(true);
when(host3.isUp()).thenReturn(true);
}

@Test(groups = "unit")
public void should_prefer_host_with_less_inflight_requests() {
// given
Map<Host, Map<Integer, Integer>> perHostInflightRequests =
new HashMap<Host, Map<Integer, Integer>>() {
{
put(
host1,
new HashMap<Integer, Integer>() {
{
put(0, 6);
}
});
put(
host2,
new HashMap<Integer, Integer>() {
{
put(0, 2);
}
});
put(
host3,
new HashMap<Integer, Integer>() {
{
put(0, 4);
}
});
}
};
RandomTwoChoicePolicy policy =
new RandomTwoChoicePolicy(
new TokenAwarePolicy(
new RoundRobinPolicy(), TokenAwarePolicy.ReplicaOrdering.TOPOLOGICAL));
policy.init(
cluster,
new ArrayList<Host>() {

{
add(host1);
add(host2);
add(host3);
}
});
when(gauge.getValue()).thenReturn(perHostInflightRequests);

Iterator<Host> queryPlan = policy.newQueryPlan("keyspace", statement);
// host2 should appear first in the query plan with fewer inflight requests than host1

assertThat(queryPlan.next()).isEqualTo(host2);
assertThat(queryPlan.next()).isEqualTo(host1);
assertThat(queryPlan.next()).isEqualTo(host3);
}
}

0 comments on commit 2307011

Please sign in to comment.