diff --git a/src/main/java/io/lettuce/core/ConnectionBuilder.java b/src/main/java/io/lettuce/core/ConnectionBuilder.java index 2a397b6e1..4bb3127b8 100644 --- a/src/main/java/io/lettuce/core/ConnectionBuilder.java +++ b/src/main/java/io/lettuce/core/ConnectionBuilder.java @@ -113,17 +113,6 @@ public void apply(RedisURI redisURI) { bootstrap.attr(REDIS_URI, redisURI.toString()); } - public void registerAuthenticationHandler(RedisCredentialsProvider credentialsProvider, ConnectionState state, - Boolean isPubSubConnection) { - LettuceAssert.assertState(endpoint != null, "Endpoint must be set"); - LettuceAssert.assertState(connection != null, "Connection must be set"); - LettuceAssert.assertState(clientResources != null, "ClientResources must be set"); - - RedisAuthenticationHandler authenticationHandler = new RedisAuthenticationHandler(connection, credentialsProvider, - state, clientResources.eventBus(), isPubSubConnection); - endpoint.registerAuthenticationHandler(authenticationHandler); - } - protected List buildHandlers() { LettuceAssert.assertState(channelGroup != null, "ChannelGroup must be set"); diff --git a/src/main/java/io/lettuce/core/RedisAuthenticationHandler.java b/src/main/java/io/lettuce/core/RedisAuthenticationHandler.java index cd3cf6239..1bf329e86 100644 --- a/src/main/java/io/lettuce/core/RedisAuthenticationHandler.java +++ b/src/main/java/io/lettuce/core/RedisAuthenticationHandler.java @@ -46,26 +46,18 @@ public class RedisAuthenticationHandler { private static final InternalLogger log = InternalLoggerFactory.getInstance(RedisAuthenticationHandler.class); - private final RedisChannelHandler connection; - - private final ConnectionState state; - - private final RedisCommandBuilder commandBuilder = new RedisCommandBuilder<>(StringCodec.UTF8); + private final StatefulRedisConnectionImpl connection; private final RedisCredentialsProvider credentialsProvider; private final AtomicReference credentialsSubscription = new AtomicReference<>(); - private final EventBus eventBus; - private final Boolean isPubSubConnection; - public RedisAuthenticationHandler(RedisChannelHandler connection, RedisCredentialsProvider credentialsProvider, - ConnectionState state, EventBus eventBus, Boolean isPubSubConnection) { + public RedisAuthenticationHandler(StatefulRedisConnectionImpl connection, + RedisCredentialsProvider credentialsProvider, Boolean isPubSubConnection) { this.connection = connection; - this.state = state; this.credentialsProvider = credentialsProvider; - this.eventBus = eventBus; this.isPubSubConnection = isPubSubConnection; } @@ -125,55 +117,19 @@ protected void onError(Throwable e) { * @param credentials the new credentials */ protected void reauthenticate(RedisCredentials credentials) { - CharSequence password = CharBuffer.wrap(credentials.getPassword()); - - AsyncCommand authCmd; - if (credentials.hasUsername()) { - authCmd = new AsyncCommand<>(commandBuilder.auth(credentials.getUsername(), password)); - } else { - authCmd = new AsyncCommand<>(commandBuilder.auth(password)); - } - - dispatchAuth(authCmd).thenRun(() -> { - publishReauthEvent(); - log.info("Re-authentication succeeded for endpoint {}.", getEpid()); - }).exceptionally(throwable -> { - publishReauthFailedEvent(throwable); - log.error("Re-authentication failed for endpoint {}.", getEpid(), throwable); - return null; - }); - } - - private AsyncCommand dispatchAuth(RedisCommand authCommand) { - AsyncCommand asyncCommand = new AsyncCommand<>(authCommand); - RedisCommand dispatched = connection.dispatch(asyncCommand); - if (dispatched instanceof AsyncCommand) { - return (AsyncCommand) dispatched; - } - return asyncCommand; - } - - private void publishReauthEvent() { - eventBus.publish(new ReauthenticateEvent(getEpid())); - } - - private void publishReauthFailedEvent(Throwable throwable) { - eventBus.publish(new ReauthenticateFailedEvent(getEpid(), throwable)); + connection.setCredentials(credentials); } protected boolean isSupportedConnection() { - if (isPubSubConnection && ProtocolVersion.RESP2 == state.getNegotiatedProtocolVersion()) { + if (isPubSubConnection && ProtocolVersion.RESP2 == connection.getConnectionState().getNegotiatedProtocolVersion()) { log.warn("Renewable credentials are not supported with RESP2 protocol on a pub/sub connection."); return false; } return true; } - private String getEpid() { - if (connection.getChannelWriter() instanceof Endpoint) { - return ((Endpoint) connection.getChannelWriter()).getId(); - } - return "unknown"; + private void publishReauthFailedEvent(Throwable throwable) { + connection.getResources().eventBus().publish(new ReauthenticateFailedEvent(throwable)); } public static boolean isSupported(ClientOptions clientOptions) { diff --git a/src/main/java/io/lettuce/core/RedisClient.java b/src/main/java/io/lettuce/core/RedisClient.java index 26801d949..d2d04e307 100644 --- a/src/main/java/io/lettuce/core/RedisClient.java +++ b/src/main/java/io/lettuce/core/RedisClient.java @@ -317,7 +317,10 @@ private ConnectionFuture connectStatefulAsync(StatefulRedisConnecti ConnectionState state = connection.getConnectionState(); state.apply(redisURI); state.setDb(redisURI.getDatabase()); - + if (RedisAuthenticationHandler.isSupported(getOptions())) { + connection.setAuthenticationHandler( + new RedisAuthenticationHandler(connection, redisURI.getCredentialsProvider(), isPubSub)); + } connectionBuilder.connection(connection); connectionBuilder.clientOptions(getOptions()); connectionBuilder.clientResources(getResources()); @@ -326,11 +329,6 @@ private ConnectionFuture connectStatefulAsync(StatefulRedisConnecti connectionBuilder(getSocketAddressSupplier(redisURI), connectionBuilder, connection.getConnectionEvents(), redisURI); connectionBuilder.connectionInitializer(createHandshake(state)); - if (RedisAuthenticationHandler.isSupported(getOptions())) { - connectionBuilder.registerAuthenticationHandler(redisURI.getCredentialsProvider(), connection.getConnectionState(), - isPubSub); - } - ConnectionFuture> future = initializeChannelAsync(connectionBuilder); return future.thenApply(channelHandler -> (S) connection); diff --git a/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java b/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java index 853f9ee24..de78c26d0 100644 --- a/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java +++ b/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java @@ -22,10 +22,14 @@ import static io.lettuce.core.ClientOptions.DEFAULT_JSON_PARSER; import static io.lettuce.core.protocol.CommandType.*; +import java.nio.CharBuffer; import java.time.Duration; import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; import java.util.function.Consumer; import java.util.stream.Collectors; @@ -37,10 +41,14 @@ import io.lettuce.core.cluster.api.sync.RedisClusterCommands; import io.lettuce.core.codec.RedisCodec; import io.lettuce.core.codec.StringCodec; +import io.lettuce.core.event.connection.ReauthenticateEvent; +import io.lettuce.core.event.connection.ReauthenticateFailedEvent; import io.lettuce.core.json.JsonParser; import io.lettuce.core.output.MultiOutput; import io.lettuce.core.output.StatusOutput; import io.lettuce.core.protocol.*; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; import reactor.core.publisher.Mono; /** @@ -55,6 +63,8 @@ */ public class StatefulRedisConnectionImpl extends RedisChannelHandler implements StatefulRedisConnection { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(StatefulRedisConnectionImpl.class); + protected final RedisCodec codec; protected final RedisCommands sync; @@ -71,6 +81,14 @@ public class StatefulRedisConnectionImpl extends RedisChannelHandler protected MultiOutput multi; + private RedisAuthenticationHandler authHandler; + + private AtomicReference credentialsRef = new AtomicReference<>(); + + private final ReentrantLock reAuthSafety = new ReentrantLock(); + + private AtomicBoolean inTransaction = new AtomicBoolean(false); + /** * Initialize a new connection. * @@ -181,7 +199,13 @@ public boolean isMulti() { public RedisCommand dispatch(RedisCommand command) { RedisCommand toSend = preProcessCommand(command); - return super.dispatch(toSend); + RedisCommand result = super.dispatch(toSend); + if (toSend.getType() == EXEC || toSend.getType() == DISCARD) { + inTransaction.set(false); + setCredentials(credentialsRef.getAndSet(null)); + } + + return result; } @Override @@ -189,12 +213,24 @@ public RedisCommand dispatch(RedisCommand command) { List> sentCommands = new ArrayList<>(commands.size()); - commands.forEach(o -> { + boolean transactionComplete = false; + for (RedisCommand o : commands) { RedisCommand command = preProcessCommand(o); sentCommands.add(command); - }); + if (command.getType() == EXEC) { + transactionComplete = true; + } + if (command.getType() == MULTI || command.getType() == DISCARD) { + transactionComplete = false; + } + } - return super.dispatch(sentCommands); + Collection> result = super.dispatch(sentCommands); + if (transactionComplete) { + inTransaction.set(false); + setCredentials(credentialsRef.getAndSet(null)); + } + return result; } // TODO [tihomir.mateev] Refactor to include as part of the Command interface @@ -273,12 +309,20 @@ protected RedisCommand preProcessCommand(RedisCommand comm if (commandType.equals(MULTI.name())) { + reAuthSafety.lock(); + try { + inTransaction.set(true); + } finally { + reAuthSafety.unlock(); + } multi = (multi == null ? new MultiOutput<>(codec) : multi); if (command instanceof CompleteableCommand) { ((CompleteableCommand) command).onComplete((ignored, e) -> { if (e != null) { multi = null; + inTransaction.set(false); + setCredentials(credentialsRef.getAndSet(null)); } }); } @@ -318,11 +362,72 @@ public ConnectionState getConnectionState() { @Override public void activated() { super.activated(); + if (authHandler != null) { + authHandler.subscribe(); + } } @Override public void deactivated() { + if (authHandler != null) { + authHandler.unsubscribe(); + } super.deactivated(); } + public void setAuthenticationHandler(RedisAuthenticationHandler handler) { + authHandler = handler; + } + + public void setCredentials(RedisCredentials credentials) { + if (credentials == null) { + return; + } + reAuthSafety.lock(); + try { + credentialsRef.set(credentials); + if (!inTransaction.get()) { + dispatchAuthCommand(credentialsRef.getAndSet(null)); + } + } finally { + reAuthSafety.unlock(); + } + } + + private void dispatchAuthCommand(RedisCredentials credentials) { + if (credentials == null) { + return; + } + + RedisFuture auth; + if (credentials.getUsername() != null) { + auth = async().auth(credentials.getUsername(), CharBuffer.wrap(credentials.getPassword())); + } else { + auth = async().auth(CharBuffer.wrap(credentials.getPassword())); + } + auth.thenRun(() -> { + publishReauthEvent(); + logger.info("Re-authentication succeeded {}.", getEpid()); + }).exceptionally(throwable -> { + publishReauthFailedEvent(throwable); + logger.error("Re-authentication failed {}.", getEpid(), throwable); + return null; + }); + } + + private void publishReauthEvent() { + getResources().eventBus().publish(new ReauthenticateEvent(getEpid())); + } + + private void publishReauthFailedEvent(Throwable throwable) { + getResources().eventBus().publish(new ReauthenticateFailedEvent(getEpid(), throwable)); + } + + private String getEpid() { + if (getChannelWriter() instanceof Endpoint) { + return ((Endpoint) getChannelWriter()).getId(); + } + return ""; + } + } diff --git a/src/main/java/io/lettuce/core/api/StatefulRedisConnection.java b/src/main/java/io/lettuce/core/api/StatefulRedisConnection.java index 1be962ba3..7cac57887 100644 --- a/src/main/java/io/lettuce/core/api/StatefulRedisConnection.java +++ b/src/main/java/io/lettuce/core/api/StatefulRedisConnection.java @@ -1,5 +1,6 @@ package io.lettuce.core.api; +import io.lettuce.core.RedisCredentials; import io.lettuce.core.api.async.RedisAsyncCommands; import io.lettuce.core.api.push.PushListener; import io.lettuce.core.api.reactive.RedisReactiveCommands; diff --git a/src/main/java/io/lettuce/core/cluster/RedisClusterClient.java b/src/main/java/io/lettuce/core/cluster/RedisClusterClient.java index f384cbda0..422dcbb15 100644 --- a/src/main/java/io/lettuce/core/cluster/RedisClusterClient.java +++ b/src/main/java/io/lettuce/core/cluster/RedisClusterClient.java @@ -558,7 +558,12 @@ ConnectionFuture> connectToNodeAsync(RedisC ConnectionFuture> connectionFuture = connectStatefulAsync(connection, endpoint, getFirstUri(), socketAddressSupplier, - () -> new CommandHandler(getClusterClientOptions(), getResources(), endpoint), false); + () -> new CommandHandler(getClusterClientOptions(), getResources(), endpoint)); + + if (RedisAuthenticationHandler.isSupported(getOptions())) { + connection.setAuthenticationHandler( + new RedisAuthenticationHandler(connection, getFirstUri().getCredentialsProvider(), false)); + } return connectionFuture.whenComplete((conn, throwable) -> { if (throwable != null) { @@ -623,7 +628,13 @@ ConnectionFuture> connectPubSubToNode ConnectionFuture> connectionFuture = connectStatefulAsync(connection, endpoint, getFirstUri(), socketAddressSupplier, - () -> new PubSubCommandHandler<>(getClusterClientOptions(), getResources(), codec, endpoint), true); + () -> new PubSubCommandHandler<>(getClusterClientOptions(), getResources(), codec, endpoint)); + + if (RedisAuthenticationHandler.isSupported(getOptions())) { + connection.setAuthenticationHandler( + new RedisAuthenticationHandler(connection, getFirstUri().getCredentialsProvider(), true)); + } + return connectionFuture.whenComplete((conn, throwable) -> { if (throwable != null) { connection.closeAsync(); @@ -679,11 +690,11 @@ private CompletableFuture> connectCl Mono socketAddressSupplier = getSocketAddressSupplier(connection::getPartitions, TopologyComparators::sortByClientCount); Mono> connectionMono = Mono - .defer(() -> connect(socketAddressSupplier, endpoint, connection, commandHandlerSupplier, false)); + .defer(() -> connect(socketAddressSupplier, endpoint, connection, commandHandlerSupplier)); for (int i = 1; i < getConnectionAttempts(); i++) { connectionMono = connectionMono - .onErrorResume(t -> connect(socketAddressSupplier, endpoint, connection, commandHandlerSupplier, false)); + .onErrorResume(t -> connect(socketAddressSupplier, endpoint, connection, commandHandlerSupplier)); } return connectionMono @@ -713,20 +724,19 @@ protected StatefulRedisClusterConnectionImpl newStatefulRedisCluste } private Mono connect(Mono socketAddressSupplier, DefaultEndpoint endpoint, - StatefulRedisClusterConnectionImpl connection, Supplier commandHandlerSupplier, - Boolean isPubSub) { + StatefulRedisClusterConnectionImpl connection, Supplier commandHandlerSupplier) { ConnectionFuture future = connectStatefulAsync(connection, endpoint, getFirstUri(), socketAddressSupplier, - commandHandlerSupplier, isPubSub); + commandHandlerSupplier); return Mono.fromCompletionStage(future).doOnError(t -> logger.warn(t.getMessage())); } private Mono connect(Mono socketAddressSupplier, DefaultEndpoint endpoint, - StatefulRedisConnectionImpl connection, Supplier commandHandlerSupplier, Boolean isPubSub) { + StatefulRedisConnectionImpl connection, Supplier commandHandlerSupplier) { ConnectionFuture future = connectStatefulAsync(connection, endpoint, getFirstUri(), socketAddressSupplier, - commandHandlerSupplier, isPubSub); + commandHandlerSupplier); return Mono.fromCompletionStage(future).doOnError(t -> logger.warn(t.getMessage())); } @@ -779,11 +789,11 @@ private CompletableFuture> con Mono socketAddressSupplier = getSocketAddressSupplier(connection::getPartitions, TopologyComparators::sortByClientCount); Mono> connectionMono = Mono - .defer(() -> connect(socketAddressSupplier, endpoint, connection, commandHandlerSupplier, true)); + .defer(() -> connect(socketAddressSupplier, endpoint, connection, commandHandlerSupplier)); for (int i = 1; i < getConnectionAttempts(); i++) { connectionMono = connectionMono - .onErrorResume(t -> connect(socketAddressSupplier, endpoint, connection, commandHandlerSupplier, true)); + .onErrorResume(t -> connect(socketAddressSupplier, endpoint, connection, commandHandlerSupplier)); } return connectionMono @@ -803,10 +813,10 @@ private int getConnectionAttempts() { @SuppressWarnings("unchecked") private , S> ConnectionFuture connectStatefulAsync(T connection, DefaultEndpoint endpoint, RedisURI connectionSettings, Mono socketAddressSupplier, - Supplier commandHandlerSupplier, Boolean isPubSub) { + Supplier commandHandlerSupplier) { ConnectionBuilder connectionBuilder = createConnectionBuilder(connection, connection.getConnectionState(), endpoint, - connectionSettings, socketAddressSupplier, commandHandlerSupplier, isPubSub); + connectionSettings, socketAddressSupplier, commandHandlerSupplier); ConnectionFuture> future = initializeChannelAsync(connectionBuilder); @@ -820,10 +830,10 @@ private , S> Connection @SuppressWarnings("unchecked") private , S> ConnectionFuture connectStatefulAsync(T connection, DefaultEndpoint endpoint, RedisURI connectionSettings, Mono socketAddressSupplier, - Supplier commandHandlerSupplier, Boolean isPubSub) { + Supplier commandHandlerSupplier) { ConnectionBuilder connectionBuilder = createConnectionBuilder(connection, connection.getConnectionState(), endpoint, - connectionSettings, socketAddressSupplier, commandHandlerSupplier, isPubSub); + connectionSettings, socketAddressSupplier, commandHandlerSupplier); ConnectionFuture> future = initializeChannelAsync(connectionBuilder); @@ -832,7 +842,7 @@ private , S> ConnectionFuture< private ConnectionBuilder createConnectionBuilder(RedisChannelHandler connection, ConnectionState state, DefaultEndpoint endpoint, RedisURI connectionSettings, Mono socketAddressSupplier, - Supplier commandHandlerSupplier, Boolean isPubSub) { + Supplier commandHandlerSupplier) { ConnectionBuilder connectionBuilder; if (connectionSettings.isSsl()) { @@ -854,10 +864,6 @@ private ConnectionBuilder createConnectionBuilder(RedisChannelHandler channelHandler; + private StatefulRedisConnectionImpl connection; + + ClientResources resources; EventBus eventBus; @@ -34,8 +38,13 @@ public class RedisAuthenticationHandlerTest { @BeforeEach void setUp() { eventBus = new DefaultEventBus(Schedulers.immediate()); - channelHandler = mock(RedisChannelHandler.class); + connection = mock(StatefulRedisConnectionImpl.class); + resources = mock(ClientResources.class); + when(resources.eventBus()).thenReturn(eventBus); + connectionState = mock(ConnectionState.class); + when(connection.getResources()).thenReturn(resources); + when(connection.getConnectionState()).thenReturn(connectionState); } @SuppressWarnings("unchecked") @@ -43,20 +52,18 @@ void setUp() { void subscribeWithStreamingCredentialsProviderInvokesReauth() { MyStreamingRedisCredentialsProvider credentialsProvider = new MyStreamingRedisCredentialsProvider(); - RedisAuthenticationHandler handler = new RedisAuthenticationHandler(channelHandler, credentialsProvider, - connectionState, eventBus, false); + RedisAuthenticationHandler handler = new RedisAuthenticationHandler(connection, credentialsProvider, false); // Subscribe to the provider handler.subscribe(); credentialsProvider.emitCredentials("newuser", "newpassword".toCharArray()); - ArgumentCaptor> captor = ArgumentCaptor.forClass(RedisCommand.class); - verify(channelHandler).dispatch(captor.capture()); + ArgumentCaptor captor = ArgumentCaptor.forClass(RedisCredentials.class); + verify(connection).setCredentials(captor.capture()); - RedisCommand capturedCommand = captor.getValue(); - assertThat(capturedCommand.getType()).isEqualTo(CommandType.AUTH); - assertThat(capturedCommand.getArgs().toCommandString()).contains("newuser"); - assertThat(capturedCommand.getArgs().toCommandString()).contains("newpassword"); + RedisCredentials credentials = captor.getValue(); + assertThat(credentials.getUsername()).isEqualTo("newuser"); + assertThat(credentials.getPassword()).isEqualTo("newpassword".toCharArray()); credentialsProvider.shutdown(); } @@ -65,10 +72,9 @@ void subscribeWithStreamingCredentialsProviderInvokesReauth() { void shouldHandleErrorInCredentialsStream() { MyStreamingRedisCredentialsProvider credentialsProvider = new MyStreamingRedisCredentialsProvider(); - RedisAuthenticationHandler handler = new RedisAuthenticationHandler(channelHandler, credentialsProvider, - connectionState, eventBus, false); + RedisAuthenticationHandler handler = new RedisAuthenticationHandler(connection, credentialsProvider, false); - verify(channelHandler, times(0)).dispatch(any(RedisCommand.class)); // No command should be sent + verify(connection, times(0)).dispatch(any(RedisCommand.class)); // No command should be sent // Verify the event was published StepVerifier.create(eventBus.get()).then(() -> { @@ -82,10 +88,9 @@ void shouldHandleErrorInCredentialsStream() { @Test void shouldNotSubscribeIfConnectionIsNotSupported() { StreamingCredentialsProvider credentialsProvider = mock(StreamingCredentialsProvider.class); - when(connectionState.getNegotiatedProtocolVersion()).thenReturn(ProtocolVersion.RESP2); - RedisAuthenticationHandler handler = new RedisAuthenticationHandler(channelHandler, credentialsProvider, - connectionState, eventBus, true); + + RedisAuthenticationHandler handler = new RedisAuthenticationHandler(connection, credentialsProvider, true); // Subscribe to the provider (it should not subscribe due to unsupported connection) handler.subscribe(); @@ -96,25 +101,22 @@ void shouldNotSubscribeIfConnectionIsNotSupported() { @Test void testIsSupportedConnectionWithRESP2ProtocolOnPubSubConnection() { - RedisChannelHandler connection = mock(RedisChannelHandler.class); - ConnectionState connectionState = mock(ConnectionState.class); when(connectionState.getNegotiatedProtocolVersion()).thenReturn(ProtocolVersion.RESP2); RedisAuthenticationHandler handler = new RedisAuthenticationHandler(connection, mock(RedisCredentialsProvider.class), - connectionState, mock(EventBus.class), true); + true); assertFalse(handler.isSupportedConnection()); } @Test void testIsSupportedConnectionWithNonPubSubConnection() { - RedisChannelHandler connection = mock(RedisChannelHandler.class); - ConnectionState connectionState = mock(ConnectionState.class); + when(connectionState.getNegotiatedProtocolVersion()).thenReturn(ProtocolVersion.RESP2); RedisAuthenticationHandler handler = new RedisAuthenticationHandler(connection, mock(RedisCredentialsProvider.class), - connectionState, mock(EventBus.class), false); + false); assertTrue(handler.isSupportedConnection()); } @@ -122,12 +124,10 @@ void testIsSupportedConnectionWithNonPubSubConnection() { @Test void testIsSupportedConnectionWithRESP3ProtocolOnPubSubConnection() { - RedisChannelHandler connection = mock(RedisChannelHandler.class); - ConnectionState connectionState = mock(ConnectionState.class); when(connectionState.getNegotiatedProtocolVersion()).thenReturn(ProtocolVersion.RESP3); RedisAuthenticationHandler handler = new RedisAuthenticationHandler(connection, mock(RedisCredentialsProvider.class), - connectionState, mock(EventBus.class), true); + true); assertTrue(handler.isSupportedConnection()); } diff --git a/src/test/java/io/lettuce/core/event/ConnectionEventsTriggeredIntegrationTests.java b/src/test/java/io/lettuce/core/event/ConnectionEventsTriggeredIntegrationTests.java index 732e5dcda..ee7fb43e1 100644 --- a/src/test/java/io/lettuce/core/event/ConnectionEventsTriggeredIntegrationTests.java +++ b/src/test/java/io/lettuce/core/event/ConnectionEventsTriggeredIntegrationTests.java @@ -11,12 +11,14 @@ import io.lettuce.core.event.connection.AuthenticateEvent; import io.lettuce.core.event.connection.ReauthenticateEvent; import io.lettuce.core.event.connection.ReauthenticateFailedEvent; +import io.lettuce.test.LettuceExtension; import io.lettuce.test.WithPassword; import io.lettuce.test.settings.TestSettings; import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; import reactor.core.publisher.Flux; import reactor.test.StepVerifier; import io.lettuce.core.RedisClient; @@ -26,11 +28,14 @@ import io.lettuce.test.resource.FastShutdown; import io.lettuce.test.resource.TestClientResources; +import javax.inject.Inject; + /** * @author Mark Paluch * @author Ivo Gaydajiev */ @Tag(INTEGRATION_TEST) +@ExtendWith(LettuceExtension.class) class ConnectionEventsTriggeredIntegrationTests extends TestSupport { @Test @@ -51,26 +56,25 @@ void testConnectionEvents() { } @Test - void testReauthenticateEvents() { + @Inject + void testReauthenticateEvents(RedisClient client) { MyStreamingRedisCredentialsProvider credentialsProvider = new MyStreamingRedisCredentialsProvider(); credentialsProvider.emitCredentials(TestSettings.username(), TestSettings.password().toString().toCharArray()); - RedisClient client = RedisClient.create(TestClientResources.get(), - RedisURI.Builder.redis(host, port).withAuthentication(credentialsProvider).build()); client.setOptions(ClientOptions.builder() .reauthenticateBehavior(ClientOptions.ReauthenticateBehavior.ON_NEW_CREDENTIALS).build()); + RedisURI uri = RedisURI.Builder.redis(host, port).withAuthentication(credentialsProvider).build(); Flux publisher = client.getResources().eventBus().get() .filter(event -> event instanceof AuthenticateEvent).cast(AuthenticateEvent.class); - - StepVerifier.create(publisher).then(() -> WithPassword.run(client, () -> client.connect().close())) + WithPassword.run(client, () -> StepVerifier.create(publisher).then(() -> client.connect(uri)) .assertNext(event -> assertThat(event).asInstanceOf(InstanceOfAssertFactories.type(ReauthenticateEvent.class)) .extracting(ReauthenticateEvent::getEpId).isNotNull()) .then(() -> credentialsProvider.emitCredentials(TestSettings.username(), "invalid".toCharArray())) .assertNext( event -> assertThat(event).asInstanceOf(InstanceOfAssertFactories.type(ReauthenticateFailedEvent.class)) .extracting(ReauthenticateFailedEvent::getEpId).isNotNull()) - .thenCancel().verify(Duration.of(1, ChronoUnit.SECONDS)); + .thenCancel().verify(Duration.of(1, ChronoUnit.SECONDS))); FastShutdown.shutdown(client); }