Skip to content

Commit

Permalink
fix inTransaction lock with dispatch command batch
Browse files Browse the repository at this point in the history
  • Loading branch information
ggivo committed Dec 17, 2024
1 parent 110eb1a commit 8e9ab48
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 11 deletions.
23 changes: 22 additions & 1 deletion src/main/java/io/lettuce/core/RedisAuthenticationHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
import reactor.core.Disposable;
import reactor.core.publisher.Flux;

import java.util.Collection;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantLock;

import static io.lettuce.core.protocol.CommandType.AUTH;
import static io.lettuce.core.protocol.CommandType.DISCARD;
import static io.lettuce.core.protocol.CommandType.EXEC;
import static io.lettuce.core.protocol.CommandType.MULTI;

/**
* Redis authentication handler. Internally used to authenticate a Redis connection. This class is part of the internal API.
Expand Down Expand Up @@ -189,6 +191,25 @@ public void postProcess(RedisCommand<K, V, ?> toSend) {
}
}

public void postProcess(Collection<? extends RedisCommand<K, V, ?>> dispatched) {
Boolean transactionComplete = null;
for (RedisCommand<K, V, ?> command : dispatched) {
if (command.getType() == EXEC || command.getType() == DISCARD) {
transactionComplete = true;
}
if (command.getType() == MULTI) {
transactionComplete = false;
}
}

if (transactionComplete != null) {
if (transactionComplete) {
inTransaction.set(false);
setCredentials(credentialsRef.getAndSet(null));
}
}
}

/**
* Marks that the current connection has started a transaction.
* <p>
Expand Down Expand Up @@ -257,7 +278,7 @@ protected void dispatchAuth(RedisCredentials credentials) {
}

// dispatch directly to avoid AUTH preprocessing overrides credentials provider
RedisCommand<K, V, ?> auth = connection.dispatch(authCommand(credentials));
RedisCommand<K, V, ?> auth = connection.getChannelWriter().write(authCommand(credentials));
if (auth instanceof CompleteableCommand) {
((CompleteableCommand<?>) auth).onComplete((status, throwable) -> {
if (throwable != null) {
Expand Down
26 changes: 18 additions & 8 deletions src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -192,24 +192,34 @@ public <T> RedisCommand<K, V, T> dispatch(RedisCommand<K, V, T> command) {
@Override
public Collection<RedisCommand<K, V, ?>> dispatch(Collection<? extends RedisCommand<K, V, ?>> commands) {

List<RedisCommand<K, V, ?>> sentCommands = new ArrayList<>(commands.size());
Collection<RedisCommand<K, V, ?>> sentCommands = preProcessCommands(commands);

commands.forEach(o -> {
RedisCommand<K, V, ?> preprocessed = preProcessCommand(o);
sentCommands.add(preprocessed);
});
Collection<RedisCommand<K, V, ?>> dispatchedCommands = super.dispatch(sentCommands);

super.dispatch(sentCommands);
return this.postProcessCommands(dispatchedCommands);
}

sentCommands.forEach(this::postProcessCommand);
return sentCommands;
protected Collection<RedisCommand<K, V, ?>> postProcessCommands(Collection<RedisCommand<K, V, ?>> commands) {
authHandler.postProcess(commands);
return commands;
}

protected <T> RedisCommand<K, V, T> postProcessCommand(RedisCommand<K, V, T> command) {
authHandler.postProcess(command);
return command;
}

protected Collection<RedisCommand<K, V, ?>> preProcessCommands(Collection<? extends RedisCommand<K, V, ?>> commands) {
List<RedisCommand<K, V, ?>> sentCommands = new ArrayList<>(commands.size());

commands.forEach(o -> {
RedisCommand<K, V, ?> preprocessed = preProcessCommand(o);
sentCommands.add(preprocessed);
});

return sentCommands;
}

// TODO [tihomir.mateev] Refactor to include as part of the Command interface
// All these if statements clearly indicate this is a problem best solve by each command
// (defining a pre and post processing behaviour of the command)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ public class RedisAuthenticationHandlerUnitTests {

private StatefulRedisConnectionImpl<String, String> connection;

RedisChannelWriter writer;

ClientResources resources;

EventBus eventBus;
Expand All @@ -47,6 +49,7 @@ public class RedisAuthenticationHandlerUnitTests {
@BeforeEach
void setUp() {
eventBus = new DefaultEventBus(Schedulers.immediate());
writer = mock(RedisChannelWriter.class);
connection = mock(StatefulRedisConnectionImpl.class);
resources = mock(ClientResources.class);
when(resources.eventBus()).thenReturn(eventBus);
Expand All @@ -55,6 +58,7 @@ void setUp() {
when(connection.getResources()).thenReturn(resources);
when(connection.getCodec()).thenReturn(StringCodec.UTF8);
when(connection.getConnectionState()).thenReturn(connectionState);
when(connection.getChannelWriter()).thenReturn(writer);
}

@SuppressWarnings("unchecked")
Expand All @@ -70,7 +74,7 @@ void subscribeWithStreamingCredentialsProviderInvokesReauth() {
credentialsProvider.emitCredentials("newuser", "newpassword".toCharArray());

ArgumentCaptor<AsyncCommand<String, String, String>> captor = ArgumentCaptor.forClass(AsyncCommand.class);
verify(connection).dispatch(captor.capture());
verify(writer).write(captor.capture());

AsyncCommand<String, String, String> credentialsCommand = captor.getValue();
assertThat(credentialsCommand.getType()).isEqualTo(AUTH);
Expand Down Expand Up @@ -176,7 +180,7 @@ void testSetCredentialsDoesNotDispatchAuthIfInTransaction() {
handler.endTransaction();

ArgumentCaptor<AsyncCommand<String, String, String>> captor = ArgumentCaptor.forClass(AsyncCommand.class);
verify(connection).dispatch(captor.capture());
verify(writer).write(captor.capture());

AsyncCommand<String, String, String> credentialsCommand = captor.getValue();
assertThat(credentialsCommand.getType()).isEqualTo(AUTH);
Expand Down

0 comments on commit 8e9ab48

Please sign in to comment.