Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3.x: Add RandomTwoChoice load balancing policy #198

Open
wants to merge 1 commit into
base: scylla-3.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions driver-core/src/main/java/com/datastax/driver/core/Metrics.java
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,35 @@ public Integer getValue() {
return value;
}
});
private final Gauge<Map<Host, Map<Integer, Integer>>> perShardInflightRequestInfo =
registry.register(
"per-shard-inflight-request-info",
new Gauge<Map<Host, Map<Integer, Integer>>>() {
@Override
public Map<Host, Map<Integer, Integer>> getValue() {
Map<Host, Map<Integer, Integer>> result = new HashMap<Host, Map<Integer, Integer>>();
for (SessionManager session : manager.sessions) {
for (Map.Entry<Host, HostConnectionPool> poolEntry : session.pools.entrySet()) {
HostConnectionPool hostConnectionPool = poolEntry.getValue();
Map<Integer, Integer> perShardInflightRequests = new HashMap<Integer, Integer>();

for (int shardId = 0;
shardId < hostConnectionPool.connections.length;
shardId++) {
int shardInflightRequests = 0;
for (Connection connection : hostConnectionPool.connections[shardId]) {
shardInflightRequests += connection.inFlight.get();
}
perShardInflightRequests.put(shardId, shardInflightRequests);
}

result.put(poolEntry.getKey(), perShardInflightRequests);
}
}

return result;
}
});
private final Gauge<Integer> executorQueueDepth;
private final Gauge<Integer> blockingExecutorQueueDepth;
private final Gauge<Integer> reconnectionSchedulerQueueSize;
Expand Down Expand Up @@ -374,6 +402,10 @@ public Gauge<Map<Host, Integer>> getShardAwarenessInfo() {
return shardAwarenessInfo;
}

public Gauge<Map<Host, Map<Integer, Integer>>> getPerShardInflightRequestInfo() {
return perShardInflightRequestInfo;
}

/**
* Returns the number of bytes sent so far.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package com.datastax.driver.core.policies;

import com.datastax.driver.core.*;
import com.google.common.collect.AbstractIterator;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
* A wrapper load balancing policy that adds "Power of 2 Choice" algorithm to a child policy.
*
* <p>This policy encapsulates another policy. The resulting policy works in the following way:
*
* <ul>
* <li>the {@code distance} method is inherited from the child policy.
* <li>the {@code newQueryPlan} method will compare first two hosts (by number of inflight
* requests) returned from the {@code newQueryPlan} method of the child policy, and the host
* with fewer number of inflight requests will be returned the first. It will allow to always
* avoid the worst option (comparing by number of inflight requests).
* <li>besides the first two hosts returned by the child policy's {@code newQueryPlan} method, the
* ordering of the rest of the hosts will remain the same.
* </ul>
*
* <p>If you wrap the {@code RandomTwoChoicePolicy} policy with {@code TokenAwarePolicy}, it will
* compare the first two replicas by the number of inflight requests, and the worse option will
* always be avoided. In that case, it is recommended to use the TokenAwarePolicy with {@code
* ReplicaOrdering.RANDOM strategy}, which will return the replicas in a shuffled order and thus
* will make the "Power of 2 Choice" algorithm more efficient.
*/
public class RandomTwoChoicePolicy implements ChainableLoadBalancingPolicy {
private final LoadBalancingPolicy childPolicy;
private volatile Metrics metrics;
private volatile Metadata clusterMetadata;
private volatile ProtocolVersion protocolVersion;
private volatile CodecRegistry codecRegistry;

/**
* Creates a new {@code RandomTwoChoicePolicy}.
*
* @param childPolicy the load balancing policy to wrap with "Power of 2 Choice" algorithm.
*/
public RandomTwoChoicePolicy(LoadBalancingPolicy childPolicy) {
this.childPolicy = childPolicy;
}

@Override
public LoadBalancingPolicy getChildPolicy() {
return childPolicy;
}

@Override
public void init(Cluster cluster, Collection<Host> hosts) {
this.metrics = cluster.getMetrics();
this.clusterMetadata = cluster.getMetadata();
this.protocolVersion = cluster.getConfiguration().getProtocolOptions().getProtocolVersion();
this.codecRegistry = cluster.getConfiguration().getCodecRegistry();
childPolicy.init(cluster, hosts);
}

/**
* {@inheritDoc}
*
* <p>This implementation always returns distances as reported by the wrapped policy.
*/
@Override
public HostDistance distance(Host host) {
return childPolicy.distance(host);
}

/**
* {@inheritDoc}
*
* <p>The returned plan will compare (by the number of inflight requests) the first 2 hosts
* returned by the child policy's {@code newQueryPlan} method, and the host with fewer inflight
* requests will be returned the first. The rest of the child policy's query plan will be left
* intact.
*/
@Override
public Iterator<Host> newQueryPlan(String loggedKeyspace, Statement statement) {
String keyspace = statement.getKeyspace();
if (keyspace == null) keyspace = loggedKeyspace;

ByteBuffer routingKey = statement.getRoutingKey(protocolVersion, codecRegistry);
if (routingKey == null || keyspace == null) {
return childPolicy.newQueryPlan(loggedKeyspace, statement);
}

final Token t = clusterMetadata.newToken(statement.getPartitioner(), routingKey);
final Iterator<Host> childIterator = childPolicy.newQueryPlan(keyspace, statement);

final Host host1 = childIterator.hasNext() ? childIterator.next() : null;
final Host host2 = childIterator.hasNext() ? childIterator.next() : null;

final AtomicInteger host1ShardInflightRequests = new AtomicInteger(0);
final AtomicInteger host2ShardInflightRequests = new AtomicInteger(0);

if (host1 != null) {
final int host1ShardId = host1.getShardingInfo().shardId(t);
host1ShardInflightRequests.set(
metrics.getPerShardInflightRequestInfo().getValue().get(host1).get(host1ShardId));
}

if (host2 != null) {
final int host2ShardId = host2.getShardingInfo().shardId(t);
host2ShardInflightRequests.set(
metrics.getPerShardInflightRequestInfo().getValue().get(host2).get(host2ShardId));
}

return new AbstractIterator<Host>() {
private final Host firstChosenHost =
host1ShardInflightRequests.get() < host2ShardInflightRequests.get() ? host1 : host2;
private final Host secondChosenHost =
host1ShardInflightRequests.get() < host2ShardInflightRequests.get() ? host2 : host1;
private int index = 0;

@Override
protected Host computeNext() {
if (index == 0) {
index++;
return firstChosenHost;
} else if (index == 1) {
index++;
return secondChosenHost;
} else if (childIterator.hasNext()) {
return childIterator.next();
}

return endOfData();
}
};
}

@Override
public void onAdd(Host host) {
childPolicy.onAdd(host);
}

@Override
public void onUp(Host host) {
childPolicy.onUp(host);
}

@Override
public void onDown(Host host) {
childPolicy.onDown(host);
}

@Override
public void onRemove(Host host) {
childPolicy.onRemove(host);
}

@Override
public void close() {
childPolicy.close();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package com.datastax.driver.core.policies;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import com.codahale.metrics.Gauge;
import com.datastax.driver.core.*;
import java.nio.ByteBuffer;
import java.util.*;
import org.assertj.core.util.Sets;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

public class RandomTwoChoicePolicyTest {
private final ByteBuffer routingKey = ByteBuffer.wrap(new byte[] {1, 2, 3, 4});
private final RegularStatement statement =
new SimpleStatement("irrelevant").setRoutingKey(routingKey).setKeyspace("keyspace");
private final Host host1 = mock(Host.class);
private final Host host2 = mock(Host.class);
private final Host host3 = mock(Host.class);
private Cluster cluster;

@SuppressWarnings("unchecked")
private final Gauge<Map<Host, Map<Integer, Integer>>> gauge =
mock((Class<Gauge<Map<Host, Map<Integer, Integer>>>>) (Object) Gauge.class);

@BeforeMethod(groups = "unit")
public void initMocks() {
CodecRegistry codecRegistry = new CodecRegistry();
cluster = mock(Cluster.class);
Configuration configuration = mock(Configuration.class);
ProtocolOptions protocolOptions = mock(ProtocolOptions.class);
Metadata metadata = mock(Metadata.class);
Metrics metrics = mock(Metrics.class);
Token t = mock(Token.class);
ShardingInfo shardingInfo = mock(ShardingInfo.class);

when(metrics.getPerShardInflightRequestInfo()).thenReturn(gauge);
when(cluster.getConfiguration()).thenReturn(configuration);
when(configuration.getCodecRegistry()).thenReturn(codecRegistry);
when(configuration.getProtocolOptions()).thenReturn(protocolOptions);
when(protocolOptions.getProtocolVersion()).thenReturn(ProtocolVersion.DEFAULT);
when(cluster.getMetadata()).thenReturn(metadata);
when(cluster.getMetrics()).thenReturn(metrics);
when(metadata.getReplicas(Metadata.quote("keyspace"), null, routingKey))
.thenReturn(Sets.newLinkedHashSet(host1, host2, host3));
when(metadata.newToken(null, routingKey)).thenReturn(t);
when(host1.getShardingInfo()).thenReturn(shardingInfo);
when(host2.getShardingInfo()).thenReturn(shardingInfo);
when(host3.getShardingInfo()).thenReturn(shardingInfo);
when(shardingInfo.shardId(t)).thenReturn(0);
when(host1.isUp()).thenReturn(true);
when(host2.isUp()).thenReturn(true);
when(host3.isUp()).thenReturn(true);
}

@Test(groups = "unit")
public void should_prefer_host_with_less_inflight_requests() {
// given
Map<Host, Map<Integer, Integer>> perHostInflightRequests =
new HashMap<Host, Map<Integer, Integer>>() {
{
put(
host1,
new HashMap<Integer, Integer>() {
{
put(0, 6);
}
});
put(
host2,
new HashMap<Integer, Integer>() {
{
put(0, 2);
}
});
put(
host3,
new HashMap<Integer, Integer>() {
{
put(0, 4);
}
});
}
};
RandomTwoChoicePolicy policy =
new RandomTwoChoicePolicy(
new TokenAwarePolicy(
new RoundRobinPolicy(), TokenAwarePolicy.ReplicaOrdering.TOPOLOGICAL));
policy.init(
cluster,
new ArrayList<Host>() {

{
add(host1);
add(host2);
add(host3);
}
});
when(gauge.getValue()).thenReturn(perHostInflightRequests);

Iterator<Host> queryPlan = policy.newQueryPlan("keyspace", statement);
// host2 should appear first in the query plan with fewer inflight requests than host1

assertThat(queryPlan.next()).isEqualTo(host2);
assertThat(queryPlan.next()).isEqualTo(host1);
assertThat(queryPlan.next()).isEqualTo(host3);
}
}