Skip to content

Commit

Permalink
TBA auth support for cluster connections
Browse files Browse the repository at this point in the history
  • Loading branch information
ggivo committed Nov 29, 2024
1 parent ec791b2 commit 24b901a
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 134 deletions.
105 changes: 7 additions & 98 deletions src/main/java/io/lettuce/core/RedisAuthenticationHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,125 +19,34 @@
*/
package io.lettuce.core;

import io.lettuce.core.codec.StringCodec;
import io.lettuce.core.protocol.AsyncCommand;
import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection;
import io.lettuce.core.protocol.ProtocolVersion;
import io.lettuce.core.protocol.RedisCommand;
import io.lettuce.core.pubsub.StatefulRedisPubSubConnectionImpl;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import reactor.core.Disposable;
import reactor.core.publisher.Flux;

import java.nio.CharBuffer;
import java.util.concurrent.atomic.AtomicReference;

/**
* Handles reauthentication of a connection each time a new authentication token is provided by
* `RenewableRedisCredentialsProvider`.
*
* <p>
* This class is part of the internal API.
*
* @author Ivo Gaydajiev
*/
class RedisAuthenticationHandler {

private static final InternalLogger log = InternalLoggerFactory.getInstance(RedisAuthenticationHandler.class);

private final RedisCommandBuilder<String, String> commandBuilder = new RedisCommandBuilder<>(StringCodec.UTF8);
class RedisAuthenticationHandler extends BaseRedisAuthenticationHandler<StatefulRedisConnectionImpl<?, ?>> {

private final StatefulRedisConnectionImpl<?, ?> connection;

private final AtomicReference<Disposable> credentialsSubscription = new AtomicReference<>();
private static final InternalLogger logger = InternalLoggerFactory.getInstance(RedisChannelHandler.class);

public RedisAuthenticationHandler(StatefulRedisConnectionImpl<?, ?> connection) {
this.connection = connection;
}

/**
* Subscribes to the provided `Flux` of credentials if the given `RedisCredentialsProvider` supports streaming credentials.
*
* Each time new credentials are received, the client is reauthenticated.
*
* @param credentialsProvider the credentials provider to subscribe to
*/
public void subscribe(RedisCredentialsProvider credentialsProvider) {
if (credentialsProvider instanceof RenewableRedisCredentialsProvider) {
if (!isSupportedConnection()) {
return;
}

Flux<RedisCredentials> credentialsFlux = ((RenewableRedisCredentialsProvider) credentialsProvider)
.credentialsStream();

Disposable subscription = credentialsFlux.subscribe(this::onNext, this::onError, this::complete);

Disposable oldSubscription = credentialsSubscription.getAndSet(subscription);
if (oldSubscription != null && !oldSubscription.isDisposed()) {
oldSubscription.dispose();
}
}
}

/**
* Unsubscribes from the current credentials stream.
*/
public void unsubscribe() {
Disposable subscription = credentialsSubscription.getAndSet(null);
if (subscription != null && !subscription.isDisposed()) {
subscription.dispose();
}
}

private void complete() {
log.debug("Credentials stream completed");
}

public void onNext(RedisCredentials credentials) {
reauthenticate(credentials);
}

public void onError(Throwable e) {
log.error("Credentials renew failed.", e);
}

/**
* Performs re-authentication with the provided credentials.
*
* @param credentials the new credentials
*/
private void reauthenticate(RedisCredentials credentials) {
CharSequence password = CharBuffer.wrap(credentials.getPassword());
if (credentials.hasUsername()) {
dispatchAuth(new AsyncCommand<>(commandBuilder.auth(credentials.getUsername(), password)))
.exceptionally(throwable -> {
log.error("Re-authentication with username failed.", throwable);
return null;
});
} else {
dispatchAuth(new AsyncCommand<>(commandBuilder.auth(password))).exceptionally(throwable -> {
log.error("Re-authentication without username failed.", throwable);
return null;
});
}
super(connection);
}

protected boolean isSupportedConnection() {
if (connection instanceof StatefulRedisPubSubConnectionImpl
if (connection instanceof StatefulRedisClusterPubSubConnection
&& ProtocolVersion.RESP2 == connection.getConnectionState().getNegotiatedProtocolVersion()) {
log.warn("Renewable credentials are not supported with RESP2 protocol on a pub/sub connection.");
logger.warn("Renewable credentials are not supported with RESP2 protocol on a pub/sub connection.");
return false;
}
return true;
}

public AsyncCommand<String, String, String> dispatchAuth(RedisCommand<String, String, String> cmd) {
AsyncCommand<String, String, String> asyncCommand = new AsyncCommand<>(cmd);
RedisCommand<String, String, String> dispatched = connection.getChannelWriter().write(asyncCommand);
if (dispatched instanceof AsyncCommand) {
return (AsyncCommand<String, String, String>) dispatched;
}
return asyncCommand;
}

}

This file was deleted.

15 changes: 15 additions & 0 deletions src/main/java/io/lettuce/core/StreamingCredentialsProvider.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package io.lettuce.core;

import reactor.core.publisher.Flux;

public interface StreamingCredentialsProvider extends RedisCredentialsProvider {

/**
* Returns a {@link Flux} emitting {@link RedisCredentials} that can be used to authorize a Redis connection. This
* credential provider supports streaming credentials, meaning that it can emit multiple credentials over time.
*
* @return
*/
Flux<RedisCredentials> credentials();

}
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package io.lettuce.core;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

Expand All @@ -11,7 +9,7 @@
import redis.clients.authentication.core.TokenListener;
import redis.clients.authentication.core.TokenManager;

public class TokenBasedRedisCredentialsProvider implements RenewableRedisCredentialsProvider {
public class TokenBasedRedisCredentialsProvider implements StreamingCredentialsProvider {

private final TokenManager tokenManager;

Expand Down Expand Up @@ -66,7 +64,7 @@ public Mono<RedisCredentials> resolveCredentials() {
* Expose the Flux for all credential updates.
*/
@Override
public Flux<RedisCredentials> credentialsStream() {
public Flux<RedisCredentials> credentials() {

return credentialsSink.asFlux().onBackpressureLatest(); // Provide a continuous stream of credentials
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright 2019-Present, Redis Ltd. and Contributors
* All rights reserved.
*
* Licensed under the MIT License.
*
* This file contains contributions from third-party contributors
* licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.lettuce.core.cluster;

import io.lettuce.core.*;
import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection;

import io.lettuce.core.protocol.*;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

class RedisClusterAuthenticationHandler extends BaseRedisAuthenticationHandler<StatefulRedisClusterConnectionImpl<?, ?>> {

private static final InternalLogger logger = InternalLoggerFactory.getInstance(RedisChannelHandler.class);

public RedisClusterAuthenticationHandler(StatefulRedisClusterConnectionImpl<?, ?> connection) {
super(connection);
}

protected boolean isSupportedConnection() {
if (connection instanceof StatefulRedisClusterPubSubConnection
&& ProtocolVersion.RESP2 == connection.getConnectionState().getNegotiatedProtocolVersion()) {
logger.warn("Renewable credentials are not supported with RESP2 protocol on a pub/sub connection.");
return false;
}
return true;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,13 @@
*/
package io.lettuce.core.cluster;

import static io.lettuce.core.protocol.CommandType.*;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import io.lettuce.core.AbstractRedisClient;
import io.lettuce.core.ClientOptions;
import io.lettuce.core.ConnectionState;
import io.lettuce.core.ReadFrom;
import io.lettuce.core.RedisChannelHandler;
import io.lettuce.core.RedisChannelWriter;
import io.lettuce.core.RedisCredentialsProvider;
import io.lettuce.core.RedisException;
import io.lettuce.core.RedisURI;
import io.lettuce.core.api.StatefulRedisConnection;
Expand All @@ -60,9 +49,23 @@
import io.lettuce.core.protocol.RedisCommand;
import reactor.core.publisher.Mono;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import static io.lettuce.core.protocol.CommandType.AUTH;
import static io.lettuce.core.protocol.CommandType.READONLY;
import static io.lettuce.core.protocol.CommandType.READWRITE;

/**
* A thread-safe connection to a Redis Cluster. Multiple threads may share one {@link StatefulRedisClusterConnectionImpl}
*
* <p>
* A {@link ConnectionWatchdog} monitors each connection and reconnects automatically until {@link #close} is called. All
* pending commands will be (re)sent after successful reconnection.
*
Expand All @@ -88,6 +91,8 @@ public class StatefulRedisClusterConnectionImpl<K, V> extends RedisChannelHandle

private volatile Partitions partitions;

private final RedisClusterAuthenticationHandler authHandler;

/**
* Initialize a new connection.
*
Expand All @@ -107,6 +112,8 @@ public StatefulRedisClusterConnectionImpl(RedisChannelWriter writer, ClusterPush
this.async = newRedisAdvancedClusterAsyncCommandsImpl();
this.sync = newRedisAdvancedClusterCommandsImpl();
this.reactive = newRedisAdvancedClusterReactiveCommandsImpl();

this.authHandler = new RedisClusterAuthenticationHandler(this);
}

protected RedisAdvancedClusterReactiveCommandsImpl<K, V> newRedisAdvancedClusterReactiveCommandsImpl() {
Expand Down Expand Up @@ -187,8 +194,7 @@ public CompletableFuture<StatefulRedisConnection<K, V>> getConnectionAsync(Strin
throw new RedisException("NodeId " + nodeId + " does not belong to the cluster");
}

AsyncClusterConnectionProvider provider = (AsyncClusterConnectionProvider) getClusterDistributionChannelWriter()
.getClusterConnectionProvider();
AsyncClusterConnectionProvider provider = (AsyncClusterConnectionProvider) getClusterDistributionChannelWriter().getClusterConnectionProvider();

return provider.getConnectionAsync(connectionIntent, nodeId);
}
Expand All @@ -203,8 +209,7 @@ public StatefulRedisConnection<K, V> getConnection(String host, int port, Connec
public CompletableFuture<StatefulRedisConnection<K, V>> getConnectionAsync(String host, int port,
ConnectionIntent connectionIntent) {

AsyncClusterConnectionProvider provider = (AsyncClusterConnectionProvider) getClusterDistributionChannelWriter()
.getClusterConnectionProvider();
AsyncClusterConnectionProvider provider = (AsyncClusterConnectionProvider) getClusterDistributionChannelWriter().getClusterConnectionProvider();

return provider.getConnectionAsync(connectionIntent, host, port);
}
Expand All @@ -213,6 +218,17 @@ public CompletableFuture<StatefulRedisConnection<K, V>> getConnectionAsync(Strin
public void activated() {
super.activated();
async.clusterMyId().thenAccept(connectionState::setNodeId);
RedisCredentialsProvider credentialsProvider = connectionState.getCredentialsProvider();
if (credentialsProvider != null && authHandler != null) {
authHandler.subscribe(credentialsProvider);
}
}

@Override
public void deactivated() {
if (authHandler != null) {
authHandler.unsubscribe();
}
}

ClusterDistributionChannelWriter getClusterDistributionChannelWriter() {
Expand Down Expand Up @@ -249,8 +265,8 @@ private <T> RedisCommand<K, V, T> preProcessCommand(RedisCommand<K, V, T> comman
} else {

List<String> stringArgs = CommandArgsAccessor.getStringArguments(command.getArgs());
this.connectionState
.setUserNamePassword(stringArgs.stream().map(String::toCharArray).collect(Collectors.toList()));
this.connectionState.setUserNamePassword(
stringArgs.stream().map(String::toCharArray).collect(Collectors.toList()));
}
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ public void shouldWaitForAndReturnTokenWhenEmittedLater() {

@Test
public void shouldCompleteAllSubscribersOnStop() {
Flux<RedisCredentials> credentialsFlux1 = credentialsProvider.credentialsStream();
Flux<RedisCredentials> credentialsFlux2 = credentialsProvider.credentialsStream();
Flux<RedisCredentials> credentialsFlux1 = credentialsProvider.credentials();
Flux<RedisCredentials> credentialsFlux2 = credentialsProvider.credentials();

Disposable subscription1 = credentialsFlux1.subscribe();
Disposable subscription2 = credentialsFlux2.subscribe();
Expand Down Expand Up @@ -116,7 +116,7 @@ public void shouldCompleteAllSubscribersOnStop() {
@Test
public void shouldPropagateMultipleTokensOnStream() {

Flux<RedisCredentials> result = credentialsProvider.credentialsStream();
Flux<RedisCredentials> result = credentialsProvider.credentials();
StepVerifier.create(result).then(() -> tokenManager.emitToken(testToken("test-user", "token1")))
.then(() -> tokenManager.emitToken(testToken("test-user", "token2")))
.assertNext(credentials -> assertThat(String.valueOf(credentials.getPassword())).isEqualTo("token1"))
Expand All @@ -129,7 +129,7 @@ public void shouldHandleTokenRequestErrorGracefully() {
Exception simulatedError = new RuntimeException("Token request failed");
tokenManager.emitError(simulatedError);

Flux<RedisCredentials> result = credentialsProvider.credentialsStream();
Flux<RedisCredentials> result = credentialsProvider.credentials();

StepVerifier.create(result).expectErrorMatches(
throwable -> throwable instanceof RuntimeException && "Token request failed".equals(throwable.getMessage()))
Expand Down

0 comments on commit 24b901a

Please sign in to comment.