Skip to content

Commit

Permalink
Support multi with re-auth
Browse files Browse the repository at this point in the history
Defer the re-auth operation in case there is on-going multi
Tx in lettuce need to be externally synchronised when used in multithreaded env. Since re-auth happens from different thread we need to make sure it does not happen while there is ongoing transaction.
  • Loading branch information
ggivo committed Dec 11, 2024
1 parent 6f46022 commit 7185e22
Show file tree
Hide file tree
Showing 11 changed files with 187 additions and 145 deletions.
11 changes: 0 additions & 11 deletions src/main/java/io/lettuce/core/ConnectionBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -113,17 +113,6 @@ public void apply(RedisURI redisURI) {
bootstrap.attr(REDIS_URI, redisURI.toString());
}

public void registerAuthenticationHandler(RedisCredentialsProvider credentialsProvider, ConnectionState state,
Boolean isPubSubConnection) {
LettuceAssert.assertState(endpoint != null, "Endpoint must be set");
LettuceAssert.assertState(connection != null, "Connection must be set");
LettuceAssert.assertState(clientResources != null, "ClientResources must be set");

RedisAuthenticationHandler authenticationHandler = new RedisAuthenticationHandler(connection, credentialsProvider,
state, clientResources.eventBus(), isPubSubConnection);
endpoint.registerAuthenticationHandler(authenticationHandler);
}

protected List<ChannelHandler> buildHandlers() {

LettuceAssert.assertState(channelGroup != null, "ChannelGroup must be set");
Expand Down
58 changes: 7 additions & 51 deletions src/main/java/io/lettuce/core/RedisAuthenticationHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,26 +46,18 @@ public class RedisAuthenticationHandler {

private static final InternalLogger log = InternalLoggerFactory.getInstance(RedisAuthenticationHandler.class);

private final RedisChannelHandler<?, ?> connection;

private final ConnectionState state;

private final RedisCommandBuilder<String, String> commandBuilder = new RedisCommandBuilder<>(StringCodec.UTF8);
private final StatefulRedisConnectionImpl<?, ?> connection;

private final RedisCredentialsProvider credentialsProvider;

private final AtomicReference<Disposable> credentialsSubscription = new AtomicReference<>();

private final EventBus eventBus;

private final Boolean isPubSubConnection;

public RedisAuthenticationHandler(RedisChannelHandler<?, ?> connection, RedisCredentialsProvider credentialsProvider,
ConnectionState state, EventBus eventBus, Boolean isPubSubConnection) {
public RedisAuthenticationHandler(StatefulRedisConnectionImpl<?, ?> connection,
RedisCredentialsProvider credentialsProvider, Boolean isPubSubConnection) {
this.connection = connection;
this.state = state;
this.credentialsProvider = credentialsProvider;
this.eventBus = eventBus;
this.isPubSubConnection = isPubSubConnection;
}

Expand Down Expand Up @@ -125,55 +117,19 @@ protected void onError(Throwable e) {
* @param credentials the new credentials
*/
protected void reauthenticate(RedisCredentials credentials) {
CharSequence password = CharBuffer.wrap(credentials.getPassword());

AsyncCommand<String, String, String> authCmd;
if (credentials.hasUsername()) {
authCmd = new AsyncCommand<>(commandBuilder.auth(credentials.getUsername(), password));
} else {
authCmd = new AsyncCommand<>(commandBuilder.auth(password));
}

dispatchAuth(authCmd).thenRun(() -> {
publishReauthEvent();
log.info("Re-authentication succeeded for endpoint {}.", getEpid());
}).exceptionally(throwable -> {
publishReauthFailedEvent(throwable);
log.error("Re-authentication failed for endpoint {}.", getEpid(), throwable);
return null;
});
}

private AsyncCommand<?, ?, ?> dispatchAuth(RedisCommand<?, ?, ?> authCommand) {
AsyncCommand asyncCommand = new AsyncCommand<>(authCommand);
RedisCommand<?, ?, ?> dispatched = connection.dispatch(asyncCommand);
if (dispatched instanceof AsyncCommand) {
return (AsyncCommand<?, ?, ?>) dispatched;
}
return asyncCommand;
}

private void publishReauthEvent() {
eventBus.publish(new ReauthenticateEvent(getEpid()));
}

private void publishReauthFailedEvent(Throwable throwable) {
eventBus.publish(new ReauthenticateFailedEvent(getEpid(), throwable));
connection.setCredentials(credentials);
}

protected boolean isSupportedConnection() {
if (isPubSubConnection && ProtocolVersion.RESP2 == state.getNegotiatedProtocolVersion()) {
if (isPubSubConnection && ProtocolVersion.RESP2 == connection.getConnectionState().getNegotiatedProtocolVersion()) {
log.warn("Renewable credentials are not supported with RESP2 protocol on a pub/sub connection.");
return false;
}
return true;
}

private String getEpid() {
if (connection.getChannelWriter() instanceof Endpoint) {
return ((Endpoint) connection.getChannelWriter()).getId();
}
return "unknown";
private void publishReauthFailedEvent(Throwable throwable) {
connection.getResources().eventBus().publish(new ReauthenticateFailedEvent(throwable));
}

public static boolean isSupported(ClientOptions clientOptions) {
Expand Down
10 changes: 4 additions & 6 deletions src/main/java/io/lettuce/core/RedisClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,10 @@ private <K, V, S> ConnectionFuture<S> connectStatefulAsync(StatefulRedisConnecti
ConnectionState state = connection.getConnectionState();
state.apply(redisURI);
state.setDb(redisURI.getDatabase());

if (RedisAuthenticationHandler.isSupported(getOptions())) {
connection.setAuthenticationHandler(
new RedisAuthenticationHandler(connection, redisURI.getCredentialsProvider(), isPubSub));
}
connectionBuilder.connection(connection);
connectionBuilder.clientOptions(getOptions());
connectionBuilder.clientResources(getResources());
Expand All @@ -326,11 +329,6 @@ private <K, V, S> ConnectionFuture<S> connectStatefulAsync(StatefulRedisConnecti
connectionBuilder(getSocketAddressSupplier(redisURI), connectionBuilder, connection.getConnectionEvents(), redisURI);
connectionBuilder.connectionInitializer(createHandshake(state));

if (RedisAuthenticationHandler.isSupported(getOptions())) {
connectionBuilder.registerAuthenticationHandler(redisURI.getCredentialsProvider(), connection.getConnectionState(),
isPubSub);
}

ConnectionFuture<RedisChannelHandler<K, V>> future = initializeChannelAsync(connectionBuilder);

return future.thenApply(channelHandler -> (S) connection);
Expand Down
113 changes: 109 additions & 4 deletions src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@
import static io.lettuce.core.ClientOptions.DEFAULT_JSON_PARSER;
import static io.lettuce.core.protocol.CommandType.*;

import java.nio.CharBuffer;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.stream.Collectors;

Expand All @@ -37,10 +41,14 @@
import io.lettuce.core.cluster.api.sync.RedisClusterCommands;
import io.lettuce.core.codec.RedisCodec;
import io.lettuce.core.codec.StringCodec;
import io.lettuce.core.event.connection.ReauthenticateEvent;
import io.lettuce.core.event.connection.ReauthenticateFailedEvent;
import io.lettuce.core.json.JsonParser;
import io.lettuce.core.output.MultiOutput;
import io.lettuce.core.output.StatusOutput;
import io.lettuce.core.protocol.*;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import reactor.core.publisher.Mono;

/**
Expand All @@ -55,6 +63,8 @@
*/
public class StatefulRedisConnectionImpl<K, V> extends RedisChannelHandler<K, V> implements StatefulRedisConnection<K, V> {

private static final InternalLogger logger = InternalLoggerFactory.getInstance(StatefulRedisConnectionImpl.class);

protected final RedisCodec<K, V> codec;

protected final RedisCommands<K, V> sync;
Expand All @@ -71,6 +81,14 @@ public class StatefulRedisConnectionImpl<K, V> extends RedisChannelHandler<K, V>

protected MultiOutput<K, V> multi;

private RedisAuthenticationHandler authHandler;

private AtomicReference<RedisCredentials> credentialsRef = new AtomicReference<>();

private final ReentrantLock reAuthSafety = new ReentrantLock();

private AtomicBoolean inTransaction = new AtomicBoolean(false);

/**
* Initialize a new connection.
*
Expand Down Expand Up @@ -181,20 +199,38 @@ public boolean isMulti() {
public <T> RedisCommand<K, V, T> dispatch(RedisCommand<K, V, T> command) {

RedisCommand<K, V, T> toSend = preProcessCommand(command);
return super.dispatch(toSend);
RedisCommand<K, V, T> result = super.dispatch(toSend);
if (toSend.getType() == EXEC || toSend.getType() == DISCARD) {
inTransaction.set(false);
setCredentials(credentialsRef.getAndSet(null));
}

return result;
}

@Override
public Collection<RedisCommand<K, V, ?>> dispatch(Collection<? extends RedisCommand<K, V, ?>> commands) {

List<RedisCommand<K, V, ?>> sentCommands = new ArrayList<>(commands.size());

commands.forEach(o -> {
boolean transactionComplete = false;
for (RedisCommand<K, V, ?> o : commands) {
RedisCommand<K, V, ?> command = preProcessCommand(o);
sentCommands.add(command);
});
if (command.getType() == EXEC) {
transactionComplete = true;
}
if (command.getType() == MULTI || command.getType() == DISCARD) {
transactionComplete = false;
}
}

return super.dispatch(sentCommands);
Collection<RedisCommand<K, V, ?>> result = super.dispatch(sentCommands);
if (transactionComplete) {
inTransaction.set(false);
setCredentials(credentialsRef.getAndSet(null));
}
return result;
}

// TODO [tihomir.mateev] Refactor to include as part of the Command interface
Expand Down Expand Up @@ -273,12 +309,20 @@ protected <T> RedisCommand<K, V, T> preProcessCommand(RedisCommand<K, V, T> comm

if (commandType.equals(MULTI.name())) {

reAuthSafety.lock();
try {
inTransaction.set(true);
} finally {
reAuthSafety.unlock();
}
multi = (multi == null ? new MultiOutput<>(codec) : multi);

if (command instanceof CompleteableCommand) {
((CompleteableCommand<?>) command).onComplete((ignored, e) -> {
if (e != null) {
multi = null;
inTransaction.set(false);
setCredentials(credentialsRef.getAndSet(null));
}
});
}
Expand Down Expand Up @@ -318,11 +362,72 @@ public ConnectionState getConnectionState() {
@Override
public void activated() {
super.activated();
if (authHandler != null) {
authHandler.subscribe();
}
}

@Override
public void deactivated() {
if (authHandler != null) {
authHandler.unsubscribe();
}
super.deactivated();
}

public void setAuthenticationHandler(RedisAuthenticationHandler handler) {
authHandler = handler;
}

public void setCredentials(RedisCredentials credentials) {
if (credentials == null) {
return;
}
reAuthSafety.lock();
try {
credentialsRef.set(credentials);
if (!inTransaction.get()) {
dispatchAuthCommand(credentialsRef.getAndSet(null));
}
} finally {
reAuthSafety.unlock();
}
}

private void dispatchAuthCommand(RedisCredentials credentials) {
if (credentials == null) {
return;
}

RedisFuture<String> auth;
if (credentials.getUsername() != null) {
auth = async().auth(credentials.getUsername(), CharBuffer.wrap(credentials.getPassword()));
} else {
auth = async().auth(CharBuffer.wrap(credentials.getPassword()));
}
auth.thenRun(() -> {
publishReauthEvent();
logger.info("Re-authentication succeeded {}.", getEpid());
}).exceptionally(throwable -> {
publishReauthFailedEvent(throwable);
logger.error("Re-authentication failed {}.", getEpid(), throwable);
return null;
});
}

private void publishReauthEvent() {
getResources().eventBus().publish(new ReauthenticateEvent(getEpid()));
}

private void publishReauthFailedEvent(Throwable throwable) {
getResources().eventBus().publish(new ReauthenticateFailedEvent(getEpid(), throwable));
}

private String getEpid() {
if (getChannelWriter() instanceof Endpoint) {
return ((Endpoint) getChannelWriter()).getId();
}
return "";
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.lettuce.core.api;

import io.lettuce.core.RedisCredentials;
import io.lettuce.core.api.async.RedisAsyncCommands;
import io.lettuce.core.api.push.PushListener;
import io.lettuce.core.api.reactive.RedisReactiveCommands;
Expand Down
Loading

0 comments on commit 7185e22

Please sign in to comment.