diff --git a/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java b/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java
index 57d164dc2..c85191a1e 100644
--- a/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java
+++ b/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java
@@ -22,7 +22,6 @@
import static io.lettuce.core.ClientOptions.DEFAULT_JSON_PARSER;
import static io.lettuce.core.protocol.CommandType.*;
-import java.nio.CharBuffer;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
@@ -355,6 +354,29 @@ public void setClientName(String clientName) {
dispatch((RedisCommand) async);
}
+ /**
+ * Authenticates the current connection using the provided credentials.
+ *
+ * Unlike using dispatch of {@link RedisAsyncCommands#auth}, this method defers the {@code AUTH} command if the connection is within an active
+ * transaction. The authentication command will only be dispatched after the enclosing {@code DISCARD} or {@code EXEC}
+ * command is executed, ensuring that authentication does not interfere with ongoing transactions.
+ *
+ *
+ * @param credentials the {@link RedisCredentials} to authenticate the connection. If {@code null}, no action is performed.
+ *
+ *
+ * Behavior:
+ *
+ * - If the provided credentials are {@code null}, the method exits immediately.
+ * - If a transaction is active (as indicated by {@code inTransaction}), the {@code AUTH} command is not dispatched
+ * immediately but deferred until the transaction ends.
+ * - If no transaction is active, the {@code AUTH} command is dispatched immediately using the provided
+ * credentials.
+ *
+ *
+ *
+ * @see RedisAsyncCommands#auth
+ */
public void setCredentials(RedisCredentials credentials) {
if (credentials == null) {
return;
@@ -363,7 +385,7 @@ public void setCredentials(RedisCredentials credentials) {
try {
credentialsRef.set(credentials);
if (!inTransaction.get()) {
- dispatchAuthCommand(credentialsRef.getAndSet(null));
+ dispatchAuth(credentialsRef.getAndSet(null));
}
} finally {
reAuthSafety.unlock();
@@ -394,16 +416,16 @@ public void setAuthenticationHandler(RedisAuthenticationHandler handler) {
authHandler = handler;
}
- private void dispatchAuthCommand(RedisCredentials credentials) {
+ protected void dispatchAuth(RedisCredentials credentials) {
if (credentials == null) {
return;
}
RedisFuture auth;
if (credentials.getUsername() != null) {
- auth = async().auth(credentials.getUsername(), CharBuffer.wrap(credentials.getPassword()));
+ auth = async().auth(credentials.getUsername(), String.valueOf(credentials.getPassword()));
} else {
- auth = async().auth(CharBuffer.wrap(credentials.getPassword()));
+ auth = async().auth(String.valueOf(credentials.getPassword()));
}
auth.thenRun(() -> {
publishReauthEvent();
@@ -441,5 +463,4 @@ private String getEpid() {
return ((Endpoint) writer).getId();
}
-
}
diff --git a/src/test/java/io/lettuce/core/StatefulRedisConnectionImplUnitTests.java b/src/test/java/io/lettuce/core/StatefulRedisConnectionImplUnitTests.java
new file mode 100644
index 000000000..245eef09b
--- /dev/null
+++ b/src/test/java/io/lettuce/core/StatefulRedisConnectionImplUnitTests.java
@@ -0,0 +1,142 @@
+package io.lettuce.core;
+
+import io.lettuce.core.codec.StringCodec;
+import io.lettuce.core.protocol.AsyncCommand;
+import io.lettuce.core.protocol.PushHandler;
+import io.lettuce.core.resource.ClientResources;
+import io.lettuce.core.tracing.Tracing;
+import io.lettuce.test.ReflectionTestUtils;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.mockito.junit.jupiter.MockitoSettings;
+import org.mockito.quality.Strictness;
+
+import java.lang.reflect.Field;
+import java.time.Duration;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+@ExtendWith(MockitoExtension.class)
+@MockitoSettings(strictness = Strictness.LENIENT)
+public class StatefulRedisConnectionImplUnitTests extends TestSupport {
+
+ RedisCommandBuilder commandBuilder = new RedisCommandBuilder<>(StringCodec.UTF8);
+ StatefulRedisConnectionImpl connection;
+
+ @Mock
+ RedisAsyncCommandsImpl asyncCommands;
+
+ @Mock
+ PushHandler pushHandler;
+
+ @Mock
+ RedisChannelWriter writer;
+
+ @Mock
+ ClientResources clientResources;
+
+ @Mock
+ Tracing tracing;
+
+ @BeforeEach
+ void setup() throws NoSuchFieldException, IllegalAccessException {
+ when(writer.getClientResources()).thenReturn(clientResources);
+ when(clientResources.tracing()).thenReturn(tracing);
+ when(tracing.isEnabled()).thenReturn(false);
+ when(asyncCommands.auth(any(CharSequence.class)))
+ .thenAnswer( invocation -> {
+ String pass = invocation.getArgument(0);
+ AsyncCommand auth = new AsyncCommand<>(commandBuilder.auth(pass));
+ auth.complete();
+ return auth;
+ });
+ when(asyncCommands.auth(anyString(), any(CharSequence.class)))
+ .thenAnswer( invocation -> {
+ String user = invocation.getArgument(0); // Capture username
+ String pass = invocation.getArgument(1); // Capture password
+ AsyncCommand auth = new AsyncCommand<>(commandBuilder.auth(user, pass));
+ auth.complete();
+ return auth;
+ });
+
+ Field asyncField = StatefulRedisConnectionImpl.class.getDeclaredField("async");
+ asyncField.setAccessible(true);
+
+
+ connection = new StatefulRedisConnectionImpl<>(writer, pushHandler, StringCodec.UTF8, Duration.ofSeconds(1));
+ asyncField.set(connection,asyncCommands);
+ }
+
+ @Test
+ public void testSetCredentialsWhenCredentialsAreNull() {
+ connection.setCredentials(null);
+
+ verify(asyncCommands, never()).auth(any(CharSequence.class));
+ verify(asyncCommands, never()).auth(anyString(), any(CharSequence.class));
+ }
+
+ @Test
+ void testSetCredentialsDispatchesAuthWhenNotInTransaction() {
+ connection.setCredentials(new StaticRedisCredentials("user", "pass".toCharArray()));
+ verify(asyncCommands).auth(eq("user"), eq("pass"));
+ }
+
+
+ @Test
+ void testSetCredentialsDoesNotDispatchAuthIfInTransaction() {
+ AtomicBoolean inTransaction = ReflectionTestUtils.getField(connection, "inTransaction");
+ inTransaction.set(true);
+
+ connection.setCredentials(new StaticRedisCredentials("user", "pass".toCharArray()));
+
+ verify(asyncCommands, never()).auth(any(CharSequence.class));
+ verify(asyncCommands, never()).auth(anyString(), any(CharSequence.class));
+ }
+
+
+ @Test
+ void testSetCredentialsDispatchesAuthAfterTransaction() {
+ AtomicBoolean inTransaction = ReflectionTestUtils.getField(connection, "inTransaction");
+
+ connection.dispatch(commandBuilder.multi());
+ assertThat(inTransaction.get()).isTrue();
+
+ connection.setCredentials(new StaticRedisCredentials("user", "pass".toCharArray()));
+ connection.dispatch(commandBuilder.discard());
+
+ assertThat(inTransaction.get()).isFalse();
+
+ verify(asyncCommands).auth(eq("user"), eq("pass"));
+ }
+
+ @Test
+ void testSetCredentialsDispatchesAuthAfterTransactionInAnotherThread() throws InterruptedException {
+ AtomicBoolean inTransaction = ReflectionTestUtils.getField(connection, "inTransaction");
+
+ connection.dispatch(commandBuilder.multi());
+ assertThat(inTransaction.get()).isTrue();
+
+ Thread thread = new Thread(() -> {
+ connection.setCredentials(new StaticRedisCredentials("user", "pass".toCharArray()));
+ });
+ thread.start();
+
+ connection.dispatch(commandBuilder.discard());
+
+ thread.join();
+
+ assertThat(inTransaction.get()).isFalse();
+ verify(asyncCommands).auth(eq("user"), eq("pass"));
+ }
+
+}