diff --git a/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java b/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java index 57d164dc2..c85191a1e 100644 --- a/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java +++ b/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java @@ -22,7 +22,6 @@ import static io.lettuce.core.ClientOptions.DEFAULT_JSON_PARSER; import static io.lettuce.core.protocol.CommandType.*; -import java.nio.CharBuffer; import java.time.Duration; import java.util.ArrayList; import java.util.Collection; @@ -355,6 +354,29 @@ public void setClientName(String clientName) { dispatch((RedisCommand) async); } + /** + * Authenticates the current connection using the provided credentials. + *

+ * Unlike using dispatch of {@link RedisAsyncCommands#auth}, this method defers the {@code AUTH} command if the connection is within an active + * transaction. The authentication command will only be dispatched after the enclosing {@code DISCARD} or {@code EXEC} + * command is executed, ensuring that authentication does not interfere with ongoing transactions. + *

+ * + * @param credentials the {@link RedisCredentials} to authenticate the connection. If {@code null}, no action is performed. + * + *

+ * Behavior: + *

+ *

+ * + * @see RedisAsyncCommands#auth + */ public void setCredentials(RedisCredentials credentials) { if (credentials == null) { return; @@ -363,7 +385,7 @@ public void setCredentials(RedisCredentials credentials) { try { credentialsRef.set(credentials); if (!inTransaction.get()) { - dispatchAuthCommand(credentialsRef.getAndSet(null)); + dispatchAuth(credentialsRef.getAndSet(null)); } } finally { reAuthSafety.unlock(); @@ -394,16 +416,16 @@ public void setAuthenticationHandler(RedisAuthenticationHandler handler) { authHandler = handler; } - private void dispatchAuthCommand(RedisCredentials credentials) { + protected void dispatchAuth(RedisCredentials credentials) { if (credentials == null) { return; } RedisFuture auth; if (credentials.getUsername() != null) { - auth = async().auth(credentials.getUsername(), CharBuffer.wrap(credentials.getPassword())); + auth = async().auth(credentials.getUsername(), String.valueOf(credentials.getPassword())); } else { - auth = async().auth(CharBuffer.wrap(credentials.getPassword())); + auth = async().auth(String.valueOf(credentials.getPassword())); } auth.thenRun(() -> { publishReauthEvent(); @@ -441,5 +463,4 @@ private String getEpid() { return ((Endpoint) writer).getId(); } - } diff --git a/src/test/java/io/lettuce/core/StatefulRedisConnectionImplUnitTests.java b/src/test/java/io/lettuce/core/StatefulRedisConnectionImplUnitTests.java new file mode 100644 index 000000000..245eef09b --- /dev/null +++ b/src/test/java/io/lettuce/core/StatefulRedisConnectionImplUnitTests.java @@ -0,0 +1,142 @@ +package io.lettuce.core; + +import io.lettuce.core.codec.StringCodec; +import io.lettuce.core.protocol.AsyncCommand; +import io.lettuce.core.protocol.PushHandler; +import io.lettuce.core.resource.ClientResources; +import io.lettuce.core.tracing.Tracing; +import io.lettuce.test.ReflectionTestUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; + +import java.lang.reflect.Field; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +public class StatefulRedisConnectionImplUnitTests extends TestSupport { + + RedisCommandBuilder commandBuilder = new RedisCommandBuilder<>(StringCodec.UTF8); + StatefulRedisConnectionImpl connection; + + @Mock + RedisAsyncCommandsImpl asyncCommands; + + @Mock + PushHandler pushHandler; + + @Mock + RedisChannelWriter writer; + + @Mock + ClientResources clientResources; + + @Mock + Tracing tracing; + + @BeforeEach + void setup() throws NoSuchFieldException, IllegalAccessException { + when(writer.getClientResources()).thenReturn(clientResources); + when(clientResources.tracing()).thenReturn(tracing); + when(tracing.isEnabled()).thenReturn(false); + when(asyncCommands.auth(any(CharSequence.class))) + .thenAnswer( invocation -> { + String pass = invocation.getArgument(0); + AsyncCommand auth = new AsyncCommand<>(commandBuilder.auth(pass)); + auth.complete(); + return auth; + }); + when(asyncCommands.auth(anyString(), any(CharSequence.class))) + .thenAnswer( invocation -> { + String user = invocation.getArgument(0); // Capture username + String pass = invocation.getArgument(1); // Capture password + AsyncCommand auth = new AsyncCommand<>(commandBuilder.auth(user, pass)); + auth.complete(); + return auth; + }); + + Field asyncField = StatefulRedisConnectionImpl.class.getDeclaredField("async"); + asyncField.setAccessible(true); + + + connection = new StatefulRedisConnectionImpl<>(writer, pushHandler, StringCodec.UTF8, Duration.ofSeconds(1)); + asyncField.set(connection,asyncCommands); + } + + @Test + public void testSetCredentialsWhenCredentialsAreNull() { + connection.setCredentials(null); + + verify(asyncCommands, never()).auth(any(CharSequence.class)); + verify(asyncCommands, never()).auth(anyString(), any(CharSequence.class)); + } + + @Test + void testSetCredentialsDispatchesAuthWhenNotInTransaction() { + connection.setCredentials(new StaticRedisCredentials("user", "pass".toCharArray())); + verify(asyncCommands).auth(eq("user"), eq("pass")); + } + + + @Test + void testSetCredentialsDoesNotDispatchAuthIfInTransaction() { + AtomicBoolean inTransaction = ReflectionTestUtils.getField(connection, "inTransaction"); + inTransaction.set(true); + + connection.setCredentials(new StaticRedisCredentials("user", "pass".toCharArray())); + + verify(asyncCommands, never()).auth(any(CharSequence.class)); + verify(asyncCommands, never()).auth(anyString(), any(CharSequence.class)); + } + + + @Test + void testSetCredentialsDispatchesAuthAfterTransaction() { + AtomicBoolean inTransaction = ReflectionTestUtils.getField(connection, "inTransaction"); + + connection.dispatch(commandBuilder.multi()); + assertThat(inTransaction.get()).isTrue(); + + connection.setCredentials(new StaticRedisCredentials("user", "pass".toCharArray())); + connection.dispatch(commandBuilder.discard()); + + assertThat(inTransaction.get()).isFalse(); + + verify(asyncCommands).auth(eq("user"), eq("pass")); + } + + @Test + void testSetCredentialsDispatchesAuthAfterTransactionInAnotherThread() throws InterruptedException { + AtomicBoolean inTransaction = ReflectionTestUtils.getField(connection, "inTransaction"); + + connection.dispatch(commandBuilder.multi()); + assertThat(inTransaction.get()).isTrue(); + + Thread thread = new Thread(() -> { + connection.setCredentials(new StaticRedisCredentials("user", "pass".toCharArray())); + }); + thread.start(); + + connection.dispatch(commandBuilder.discard()); + + thread.join(); + + assertThat(inTransaction.get()).isFalse(); + verify(asyncCommands).auth(eq("user"), eq("pass")); + } + +}