Skip to content

Commit

Permalink
Closes redis#2940 Sharded PubSub subscriptions not recovered after di…
Browse files Browse the repository at this point in the history
…sconnection and re-connection. (redis#3024)
  • Loading branch information
ggivo authored Oct 24, 2024
1 parent 29afe13 commit 8114354
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 5 deletions.
8 changes: 8 additions & 0 deletions src/main/java/io/lettuce/core/pubsub/PubSubEndpoint.java
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ public Set<K> getChannels() {
return unwrap(this.channels);
}

public boolean hasShardChannelSubscriptions() {
return !shardChannels.isEmpty();
}

public Set<K> getShardChannels() {
return unwrap(this.shardChannels);
}

public boolean hasPatternSubscriptions() {
return !patterns.isEmpty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ protected List<RedisFuture<Void>> resubscribe() {
result.add(async().subscribe(toArray(endpoint.getChannels())));
}

if (endpoint.hasShardChannelSubscriptions()) {
result.add(async().ssubscribe(toArray(endpoint.getShardChannels())));
}

if (endpoint.hasPatternSubscriptions()) {
result.add(async().psubscribe(toArray(endpoint.getPatterns())));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class PubSubCommandIntegrationTests extends AbstractRedisClientTest {

BlockingQueue<Long> counts = listener.getCounts();

BlockingQueue<Long> shardCounts = listener.getShardCounts();

String channel = "channel0";

String shardChannel = "shard-channel";
Expand Down Expand Up @@ -523,6 +525,24 @@ void resubscribePatternsOnReconnect() throws Exception {
assertThat(messages.take()).isEqualTo(message);
}

@Test
void resubscribeShardChannelsOnReconnect() throws Exception {
pubsub.ssubscribe(shardChannel);
assertThat(shardChannels.take()).isEqualTo(shardChannel);
assertThat((long) shardCounts.take()).isEqualTo(1);

pubsub.quit();

assertThat(shardChannels.take()).isEqualTo(shardChannel);
assertThat((long) shardCounts.take()).isEqualTo(1);

Wait.untilTrue(pubsub::isOpen).waitOrTimeout();

redis.spublish(shardChannel, shardMessage);
assertThat(shardChannels.take()).isEqualTo(shardChannel);
assertThat(messages.take()).isEqualTo(shardMessage);
}

@Test
void adapter() throws Exception {
final BlockingQueue<Long> localCounts = LettuceFactories.newBlockingQueue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.junit.jupiter.api.Test;

import static io.lettuce.TestTags.UNIT_TEST;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.mockito.Mockito.*;

Expand Down Expand Up @@ -81,6 +81,7 @@ void resubscribeChannelSubscription() {
when(mockedEndpoint.hasChannelSubscriptions()).thenReturn(true);
when(mockedEndpoint.getChannels()).thenReturn(new HashSet<>(Arrays.asList(new String[] { "channel1", "channel2" })));
when(mockedEndpoint.hasPatternSubscriptions()).thenReturn(false);
when(mockedEndpoint.hasShardChannelSubscriptions()).thenReturn(false);

List<RedisFuture<Void>> subscriptions = connection.resubscribe();
RedisFuture<Void> commandFuture = subscriptions.get(0);
Expand All @@ -90,17 +91,35 @@ void resubscribeChannelSubscription() {
}

@Test
void resubscribeChannelAndPatternSubscription() {
void resubscribeShardChannelSubscription() {
when(mockedEndpoint.hasShardChannelSubscriptions()).thenReturn(true);
when(mockedEndpoint.getShardChannels())
.thenReturn(new HashSet<>(Arrays.asList(new String[] { "shard_channel1", "shard_channel2" })));
when(mockedEndpoint.hasChannelSubscriptions()).thenReturn(false);
when(mockedEndpoint.hasPatternSubscriptions()).thenReturn(false);

List<RedisFuture<Void>> subscriptions = connection.resubscribe();
RedisFuture<Void> commandFuture = subscriptions.get(0);

assertEquals(1, subscriptions.size());
assertInstanceOf(AsyncCommand.class, commandFuture);
}

@Test
void resubscribeChannelAndPatternAndShardChanelSubscription() {
when(mockedEndpoint.hasChannelSubscriptions()).thenReturn(true);
when(mockedEndpoint.getChannels()).thenReturn(new HashSet<>(Arrays.asList(new String[] { "channel1", "channel2" })));
when(mockedEndpoint.hasPatternSubscriptions()).thenReturn(true);
when(mockedEndpoint.hasShardChannelSubscriptions()).thenReturn(true);
when(mockedEndpoint.getChannels()).thenReturn(new HashSet<>(Arrays.asList(new String[] { "channel1", "channel2" })));
when(mockedEndpoint.getPatterns()).thenReturn(new HashSet<>(Arrays.asList(new String[] { "bcast*", "echo" })));

when(mockedEndpoint.getShardChannels())
.thenReturn(new HashSet<>(Arrays.asList(new String[] { "shard_channel1", "shard_channel2" })));
List<RedisFuture<Void>> subscriptions = connection.resubscribe();

assertEquals(2, subscriptions.size());
assertEquals(3, subscriptions.size());
assertInstanceOf(AsyncCommand.class, subscriptions.get(0));
assertInstanceOf(AsyncCommand.class, subscriptions.get(1));
assertInstanceOf(AsyncCommand.class, subscriptions.get(1));
}

}

0 comments on commit 8114354

Please sign in to comment.