diff --git a/src/main/java/io/lettuce/core/protocol/DefaultEndpoint.java b/src/main/java/io/lettuce/core/protocol/DefaultEndpoint.java index d3b9e751fe..3ef913c086 100644 --- a/src/main/java/io/lettuce/core/protocol/DefaultEndpoint.java +++ b/src/main/java/io/lettuce/core/protocol/DefaultEndpoint.java @@ -15,8 +15,6 @@ */ package io.lettuce.core.protocol; -import static io.lettuce.core.protocol.CommandHandler.*; - import java.io.IOException; import java.nio.channels.ClosedChannelException; import java.util.ArrayList; @@ -52,6 +50,8 @@ import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; +import static io.lettuce.core.protocol.CommandHandler.SUPPRESS_IO_EXCEPTION_MESSAGES; + /** * Default {@link Endpoint} implementation. * @@ -63,11 +63,11 @@ public class DefaultEndpoint implements RedisChannelWriter, Endpoint, PushHandle private static final AtomicLong ENDPOINT_COUNTER = new AtomicLong(); - private static final AtomicIntegerFieldUpdater QUEUE_SIZE = AtomicIntegerFieldUpdater - .newUpdater(DefaultEndpoint.class, "queueSize"); + private static final AtomicIntegerFieldUpdater QUEUE_SIZE = AtomicIntegerFieldUpdater.newUpdater( + DefaultEndpoint.class, "queueSize"); - private static final AtomicIntegerFieldUpdater STATUS = AtomicIntegerFieldUpdater - .newUpdater(DefaultEndpoint.class, "status"); + private static final AtomicIntegerFieldUpdater STATUS = AtomicIntegerFieldUpdater.newUpdater( + DefaultEndpoint.class, "status"); private static final int ST_OPEN = 0; @@ -176,8 +176,9 @@ public List getPushListeners() { public RedisCommand write(RedisCommand command) { LettuceAssert.notNull(command, "Command must not be null"); + final boolean isActivationCommand = inActivation || ActivationCommand.isActivationCommand(command); - RedisException validation = validateWrite(1); + RedisException validation = validateWrite(1, isActivationCommand); if (validation != null) { command.completeExceptionally(validation); return command; @@ -190,10 +191,11 @@ public RedisCommand write(RedisCommand command) { command = processActivationCommand(command); } - if (autoFlushCommands) { + if (autoFlushCommands || isActivationCommand /* activation command should not get dangled in writeBuffer */) { - if (isConnected()) { - writeToChannelAndFlush(command); + final Channel chan = this.channel; + if (isConnected(chan)) { + writeToChannelAndFlush(chan, command); } else { writeToDisconnectedBuffer(command); } @@ -213,11 +215,12 @@ public RedisCommand write(RedisCommand command) { @SuppressWarnings("unchecked") @Override - public Collection> write(Collection> commands) { + public Collection> write( + Collection> commands /* commands should not contains activation command */) { LettuceAssert.notNull(commands, "Commands must not be null"); - RedisException validation = validateWrite(commands.size()); + RedisException validation = validateWrite(commands.size(), inActivation); if (validation != null) { commands.forEach(it -> it.completeExceptionally(validation)); @@ -231,10 +234,11 @@ public RedisCommand write(RedisCommand command) { commands = processActivationCommands(commands); } - if (autoFlushCommands) { + if (autoFlushCommands || inActivation /* activation commands should not get dangled in writeBuffer */) { - if (isConnected()) { - writeToChannelAndFlush(commands); + final Channel chan = this.channel; + if (isConnected(chan)) { + writeToChannelAndFlush(chan, commands); } else { writeToDisconnectedBuffer(commands); } @@ -278,13 +282,13 @@ private RedisCommand processActivationCommand(RedisCommand command) { commandBuffer.add(command); } - private void writeToChannelAndFlush(RedisCommand command) { + private void writeToChannelAndFlush(Channel channel, RedisCommand command) { QUEUE_SIZE.incrementAndGet(this); - ChannelFuture channelFuture = channelWriteAndFlush(command); + ChannelFuture channelFuture = channelWriteAndFlush(channel, command); if (reliability == Reliability.AT_MOST_ONCE) { // cancel on exceptions and remove from queue, because there is no housekeeping @@ -379,11 +384,11 @@ private void writeToChannelAndFlush(RedisCommand command) { if (reliability == Reliability.AT_LEAST_ONCE) { // commands are ok to stay within the queue, reconnect will retrigger them - channelFuture.addListener(RetryListener.newInstance(this, command)); + channelFuture.addListener(RetryListener.newInstance(this, command, channel)); } } - private void writeToChannelAndFlush(Collection> commands) { + private void writeToChannelAndFlush(Channel channel, Collection> commands) { QUEUE_SIZE.addAndGet(this, commands.size()); @@ -391,7 +396,7 @@ private void writeToChannelAndFlush(Collection> // cancel on exceptions and remove from queue, because there is no housekeeping for (RedisCommand command : commands) { - channelWrite(command).addListener(AtMostOnceWriteListener.newInstance(this, command)); + channelWrite(channel, command).addListener(AtMostOnceWriteListener.newInstance(this, command)); } } @@ -399,14 +404,14 @@ private void writeToChannelAndFlush(Collection> // commands are ok to stay within the queue, reconnect will retrigger them for (RedisCommand command : commands) { - channelWrite(command).addListener(RetryListener.newInstance(this, command)); + channelWrite(channel, command).addListener(RetryListener.newInstance(this, command, channel)); } } - channelFlush(); + channelFlush(channel); } - private void channelFlush() { + private void channelFlush(Channel channel) { if (debugEnabled) { logger.debug("{} write() channelFlush", logPrefix()); @@ -415,7 +420,7 @@ private void channelFlush() { channel.flush(); } - private ChannelFuture channelWrite(RedisCommand command) { + private ChannelFuture channelWrite(Channel channel, RedisCommand command) { if (debugEnabled) { logger.debug("{} write() channelWrite command {}", logPrefix(), command); @@ -424,7 +429,7 @@ private ChannelFuture channelWrite(RedisCommand command) { return channel.write(command); } - private ChannelFuture channelWriteAndFlush(RedisCommand command) { + private ChannelFuture channelWriteAndFlush(Channel channel, RedisCommand command) { if (debugEnabled) { logger.debug("{} write() writeAndFlush command {}", logPrefix(), command); @@ -549,7 +554,8 @@ private void flushCommands(Queue> queue) { logger.debug("{} flushCommands()", logPrefix()); } - if (isConnected()) { + Channel chan = this.channel; + if (isConnected(chan)) { List> commands = sharedLock.doExclusive(() -> { @@ -565,7 +571,7 @@ private void flushCommands(Queue> queue) { } if (!commands.isEmpty()) { - writeToChannelAndFlush(commands); + writeToChannelAndFlush(chan, commands); } } } @@ -764,8 +770,12 @@ protected T doExclusive(Supplier supplier) { RedisCommand cmd; while ((cmd = source.poll()) != null) { - if (!cmd.isDone() && !ActivationCommand.isActivationCommand(cmd)) { - target.add(cmd); + if (!cmd.isDone()) { + if (!ActivationCommand.isActivationCommand(cmd)) { + target.add(cmd); + } else { + cmd.completeExceptionally(new RedisException("Activation command not processed")); + } } } @@ -793,6 +803,10 @@ private boolean isConnected() { return channel != null && channel.isActive(); } + private boolean isConnected(Channel channel) { + return channel != null && channel.isActive(); + } + protected String logPrefix() { if (logPrefix != null) { @@ -829,30 +843,16 @@ private static boolean isRejectCommand(ClientOptions clientOptions) { static class ListenerSupport { - Collection> sentCommands; - RedisCommand sentCommand; DefaultEndpoint endpoint; void dequeue() { - - if (sentCommand != null) { - QUEUE_SIZE.decrementAndGet(endpoint); - } else { - QUEUE_SIZE.addAndGet(endpoint, -sentCommands.size()); - } + QUEUE_SIZE.decrementAndGet(endpoint); } protected void complete(Throwable t) { - - if (sentCommand != null) { - sentCommand.completeExceptionally(t); - } else { - for (RedisCommand sentCommand : sentCommands) { - sentCommand.completeExceptionally(t); - } - } + sentCommand.completeExceptionally(t); } } @@ -879,31 +879,20 @@ static AtMostOnceWriteListener newInstance(DefaultEndpoint endpoint, RedisComman AtMostOnceWriteListener entry = RECYCLER.get(); entry.endpoint = endpoint; - entry.sentCommand = command; - - return entry; - } - - static AtMostOnceWriteListener newInstance(DefaultEndpoint endpoint, - Collection> commands) { - - AtMostOnceWriteListener entry = RECYCLER.get(); - - entry.endpoint = endpoint; - entry.sentCommands = commands; return entry; } @Override public void operationComplete(ChannelFuture future) { - try { - dequeue(); - if (!future.isSuccess() && future.cause() != null) { - complete(future.cause()); + if (!future.isSuccess()) { + Throwable cause = future.cause(); + if (cause != null) { + complete(cause); + } } } finally { recycle(); @@ -914,7 +903,6 @@ private void recycle() { this.endpoint = null; this.sentCommand = null; - this.sentCommands = null; handle.recycle(this); } @@ -937,26 +925,19 @@ protected RetryListener newObject(Handle handle) { private final Recycler.Handle handle; + private Channel writeChan; + RetryListener(Recycler.Handle handle) { this.handle = handle; } - static RetryListener newInstance(DefaultEndpoint endpoint, RedisCommand command) { + static RetryListener newInstance(DefaultEndpoint endpoint, RedisCommand command, Channel writeChan) { RetryListener entry = RECYCLER.get(); entry.endpoint = endpoint; entry.sentCommand = command; - - return entry; - } - - static RetryListener newInstance(DefaultEndpoint endpoint, Collection> commands) { - - RetryListener entry = RECYCLER.get(); - - entry.endpoint = endpoint; - entry.sentCommands = commands; + entry.writeChan = writeChan; return entry; } @@ -973,9 +954,6 @@ public void operationComplete(Future future) { } private void doComplete(Future future) { - - Throwable cause = future.cause(); - boolean success = future.isSuccess(); dequeue(); @@ -983,17 +961,22 @@ private void doComplete(Future future) { return; } + Throwable cause = future.cause(); if (cause instanceof EncoderException || cause instanceof Error || cause.getCause() instanceof Error) { complete(cause); return; } - Channel channel = endpoint.channel; + final Channel channel = endpoint.channel; + if (channel != writeChan && ActivationCommand.isActivationCommand( + sentCommand) /* activation command should never be retried in a different connection */) { + complete(cause); + return; + } // Capture values before recycler clears these. RedisCommand sentCommand = this.sentCommand; - Collection> sentCommands = this.sentCommands; - potentiallyRequeueCommands(channel, sentCommand, sentCommands); + potentiallyRequeueCommands(channel, sentCommand); if (!(cause instanceof ClosedChannelException)) { @@ -1013,59 +996,26 @@ private void doComplete(Future future) { * * @param channel * @param sentCommand - * @param sentCommands */ - private void potentiallyRequeueCommands(Channel channel, RedisCommand sentCommand, - Collection> sentCommands) { - - if (sentCommand != null && sentCommand.isDone()) { + private void potentiallyRequeueCommands(Channel channel, RedisCommand sentCommand) { + if (sentCommand.isDone()) { return; } - if (sentCommands != null) { - - boolean foundToSend = false; - - for (RedisCommand command : sentCommands) { - if (!command.isDone()) { - foundToSend = true; - break; - } - } - - if (!foundToSend) { - return; - } - } - if (channel != null) { DefaultEndpoint endpoint = this.endpoint; - channel.eventLoop().submit(() -> { - requeueCommands(sentCommand, sentCommands, endpoint); - }); + channel.eventLoop().submit(() -> requeueCommands(sentCommand, endpoint)); } else { - requeueCommands(sentCommand, sentCommands, endpoint); + requeueCommands(sentCommand, endpoint); } } @SuppressWarnings({ "unchecked", "rawtypes" }) - private void requeueCommands(RedisCommand sentCommand, - Collection> sentCommands, DefaultEndpoint endpoint) { - - if (sentCommand != null) { - try { - endpoint.write(sentCommand); - } catch (Exception e) { - sentCommand.completeExceptionally(e); - } - } else { - try { - endpoint.write((Collection) sentCommands); - } catch (Exception e) { - for (RedisCommand command : sentCommands) { - command.completeExceptionally(e); - } - } + private void requeueCommands(RedisCommand sentCommand, DefaultEndpoint endpoint) { + try { + endpoint.write(sentCommand); + } catch (Exception e) { + sentCommand.completeExceptionally(e); } } @@ -1073,7 +1023,7 @@ private void recycle() { this.endpoint = null; this.sentCommand = null; - this.sentCommands = null; + this.writeChan = null; handle.recycle(this); } @@ -1088,8 +1038,6 @@ static class Lazy implements Supplier { private static final Lazy EMPTY = new Lazy<>(() -> null, null, true); - static final String UNRESOLVED = "[Unresolved]"; - private final Supplier supplier; private T value; diff --git a/src/test/java/io/lettuce/core/protocol/DefaultEndpointUnitTests.java b/src/test/java/io/lettuce/core/protocol/DefaultEndpointUnitTests.java index e7e36bd955..5b31c0b525 100644 --- a/src/test/java/io/lettuce/core/protocol/DefaultEndpointUnitTests.java +++ b/src/test/java/io/lettuce/core/protocol/DefaultEndpointUnitTests.java @@ -15,10 +15,6 @@ */ package io.lettuce.core.protocol; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.ArgumentMatchers.*; -import static org.mockito.Mockito.*; - import java.nio.channels.ClosedChannelException; import java.util.Collection; import java.util.Collections; @@ -30,6 +26,21 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; +import io.lettuce.core.ClientOptions; +import io.lettuce.core.RedisException; +import io.lettuce.core.codec.StringCodec; +import io.lettuce.core.internal.LettuceFactories; +import io.lettuce.core.output.StatusOutput; +import io.lettuce.core.resource.ClientResources; +import io.lettuce.test.ConnectionTestUtil; +import io.lettuce.test.ReflectionTestUtils; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; +import io.netty.channel.EventLoop; +import io.netty.handler.codec.EncoderException; +import io.netty.util.concurrent.ImmediateEventExecutor; import org.apache.logging.log4j.Level; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.core.LoggerContext; @@ -46,21 +57,13 @@ import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; -import io.lettuce.core.ClientOptions; -import io.lettuce.core.RedisException; -import io.lettuce.core.codec.StringCodec; -import io.lettuce.core.internal.LettuceFactories; -import io.lettuce.core.output.StatusOutput; -import io.lettuce.core.resource.ClientResources; -import io.lettuce.test.ConnectionTestUtil; -import io.lettuce.test.ReflectionTestUtils; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelPromise; -import io.netty.channel.DefaultChannelPromise; -import io.netty.channel.EventLoop; -import io.netty.handler.codec.EncoderException; -import io.netty.util.concurrent.ImmediateEventExecutor; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; /** * @author Mark Paluch @@ -324,7 +327,7 @@ void closeAllowsOnlyOneCall() { @Test void retryListenerCompletesSuccessfullyAfterDeferredRequeue() { - DefaultEndpoint.RetryListener listener = DefaultEndpoint.RetryListener.newInstance(sut, command); + DefaultEndpoint.RetryListener listener = DefaultEndpoint.RetryListener.newInstance(sut, command, mock(Channel.class)); ChannelFuture future = mock(ChannelFuture.class); EventLoop eventLoopGroup = mock(EventLoop.class); @@ -350,7 +353,7 @@ void retryListenerCompletesSuccessfullyAfterDeferredRequeue() { @Test void retryListenerDoesNotRetryCompletedCommands() { - DefaultEndpoint.RetryListener listener = DefaultEndpoint.RetryListener.newInstance(sut, command); + DefaultEndpoint.RetryListener listener = DefaultEndpoint.RetryListener.newInstance(sut, command, mock(Channel.class)); when(channel.eventLoop()).thenReturn(mock(EventLoop.class)); @@ -396,11 +399,9 @@ void shouldNotReplayActivationCommands() { when(channel.isActive()).thenReturn(true); ConnectionTestUtil.getDisconnectedBuffer(sut) - .add(new ActivationCommand<>( - new Command<>(CommandType.SELECT, new StatusOutput<>(StringCodec.UTF8)))); + .add(new ActivationCommand<>(new Command<>(CommandType.SELECT, new StatusOutput<>(StringCodec.UTF8)))); ConnectionTestUtil.getDisconnectedBuffer(sut).add(new LatencyMeteredCommand<>( - new ActivationCommand<>( - new Command<>(CommandType.SUBSCRIBE, new StatusOutput<>(StringCodec.UTF8))))); + new ActivationCommand<>(new Command<>(CommandType.SUBSCRIBE, new StatusOutput<>(StringCodec.UTF8))))); doAnswer(i -> {