diff --git a/src/main/java/io/lettuce/core/RedisCommandBuilder.java b/src/main/java/io/lettuce/core/RedisCommandBuilder.java index 46e5cb809b..68520174e3 100644 --- a/src/main/java/io/lettuce/core/RedisCommandBuilder.java +++ b/src/main/java/io/lettuce/core/RedisCommandBuilder.java @@ -2919,7 +2919,7 @@ Command stralgoLcs(StrAlgoArgs strAlgoArgs) { CommandArgs args = new CommandArgs<>(codec); strAlgoArgs.build(args); - return createCommand(STRALGO, new StringMatchResultOutput<>(codec, strAlgoArgs.isWithIdx()), args); + return createCommand(STRALGO, new StringMatchResultOutput<>(codec), args); } Command> sunion(K... keys) { diff --git a/src/main/java/io/lettuce/core/dynamic/output/OutputRegistry.java b/src/main/java/io/lettuce/core/dynamic/output/OutputRegistry.java index 0ce4557ef1..8d59004cf3 100644 --- a/src/main/java/io/lettuce/core/dynamic/output/OutputRegistry.java +++ b/src/main/java/io/lettuce/core/dynamic/output/OutputRegistry.java @@ -56,6 +56,8 @@ public class OutputRegistry { register(registry, StringListOutput.class, StringListOutput::new); register(registry, VoidOutput.class, VoidOutput::new); + register(registry, StringMatchResultOutput.class, StringMatchResultOutput::new); + BUILTIN.putAll(registry); } diff --git a/src/main/java/io/lettuce/core/output/StringMatchResultOutput.java b/src/main/java/io/lettuce/core/output/StringMatchResultOutput.java index 2653f8f2cc..2217217139 100644 --- a/src/main/java/io/lettuce/core/output/StringMatchResultOutput.java +++ b/src/main/java/io/lettuce/core/output/StringMatchResultOutput.java @@ -19,15 +19,16 @@ */ package io.lettuce.core.output; -import static io.lettuce.core.StringMatchResult.MatchedPosition; -import static io.lettuce.core.StringMatchResult.Position; +import io.lettuce.core.StringMatchResult; +import io.lettuce.core.codec.RedisCodec; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; -import io.lettuce.core.StringMatchResult; -import io.lettuce.core.codec.RedisCodec; +import static io.lettuce.core.StringMatchResult.MatchedPosition; +import static io.lettuce.core.StringMatchResult.Position; /** * Command output for {@code STRALGO} returning {@link StringMatchResult}. @@ -37,7 +38,7 @@ */ public class StringMatchResultOutput extends CommandOutput { - private final boolean withIdx; + private static final ByteBuffer LEN = StandardCharsets.US_ASCII.encode("len"); private String matchString; @@ -45,30 +46,31 @@ public class StringMatchResultOutput extends CommandOutput positions; + private boolean readingLen = true; + private final List matchedPositions = new ArrayList<>(); - public StringMatchResultOutput(RedisCodec codec, boolean withIdx) { + public StringMatchResultOutput(RedisCodec codec) { super(codec, null); - this.withIdx = withIdx; } @Override public void set(ByteBuffer bytes) { - - if (!withIdx && matchString == null) { - matchString = (String) codec.decodeKey(bytes); - } + matchString = (String) codec.decodeKey(bytes); + readingLen = LEN.equals(bytes); } @Override public void set(long integer) { - - this.len = (int) integer; - - if (positions == null) { - positions = new ArrayList<>(); + if (readingLen) { + this.len = (int) integer; + } else { + if (positions == null) { + positions = new ArrayList<>(); + } + positions.add(integer); } - positions.add(integer); + matchString = null; } @Override diff --git a/src/test/java/io/lettuce/core/commands/StringCommandIntegrationTests.java b/src/test/java/io/lettuce/core/commands/StringCommandIntegrationTests.java index 0bee9425fd..59af685964 100644 --- a/src/test/java/io/lettuce/core/commands/StringCommandIntegrationTests.java +++ b/src/test/java/io/lettuce/core/commands/StringCommandIntegrationTests.java @@ -20,9 +20,11 @@ package io.lettuce.core.commands; import static io.lettuce.core.SetArgs.Builder.*; -import static io.lettuce.core.StringMatchResult.*; -import static org.assertj.core.api.Assertions.*; +import static io.lettuce.core.StringMatchResult.Position; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import java.lang.reflect.Proxy; import java.time.Duration; import java.time.Instant; import java.util.LinkedHashMap; @@ -31,19 +33,20 @@ import javax.inject.Inject; +import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.extension.ExtendWith; -import io.lettuce.core.GetExArgs; -import io.lettuce.core.KeyValue; -import io.lettuce.core.RedisException; -import io.lettuce.core.SetArgs; -import io.lettuce.core.StrAlgoArgs; -import io.lettuce.core.StringMatchResult; -import io.lettuce.core.TestSupport; +import io.lettuce.core.*; +import io.lettuce.core.api.StatefulConnection; +import io.lettuce.core.api.StatefulRedisConnection; import io.lettuce.core.api.sync.RedisCommands; +import io.lettuce.core.dynamic.Commands; +import io.lettuce.core.dynamic.RedisCommandFactory; +import io.lettuce.core.dynamic.annotation.Command; +import io.lettuce.core.dynamic.annotation.Param; import io.lettuce.test.KeyValueStreamingAdapter; import io.lettuce.test.LettuceExtension; import io.lettuce.test.condition.EnabledOnCommand; @@ -373,4 +376,117 @@ void strAlgoWithIdx() { assertThat(matchResult.getLen()).isEqualTo(6); } + @Test + @EnabledOnCommand("LCS") + void lcs() { + redis.set("key1", "ohmytext"); + redis.set("key2", "mynewtext"); + + // LCS key1 key2 + CustomStringCommands commands = CustomStringCommands.instance(getConnection()); + StringMatchResult matchResult = commands.lcs("key1", "key2"); + assertThat(matchResult.getMatchString()).isEqualTo("mytext"); + + // LCS a b IDX MINMATCHLEN 4 WITHMATCHLEN + // Keys don't exist. + matchResult = commands.lcsMinMatchLenWithMatchLen("a", "b", 4); + assertThat(matchResult.getMatchString()).isNullOrEmpty(); + assertThat(matchResult.getLen()).isEqualTo(0); + } + + @Test + @EnabledOnCommand("LCS") + void lcsUsingKeys() { + + redis.set("key1{k}", "ohmytext"); + redis.set("key2{k}", "mynewtext"); + + CustomStringCommands commands = CustomStringCommands.instance(getConnection()); + + StringMatchResult matchResult = commands.lcs("key1{k}", "key2{k}"); + assertThat(matchResult.getMatchString()).isEqualTo("mytext"); + + // STRALGO LCS STRINGS a b + matchResult = commands.lcsMinMatchLenWithMatchLen("a", "b", 4); + assertThat(matchResult.getMatchString()).isNullOrEmpty(); + assertThat(matchResult.getLen()).isEqualTo(0); + } + + @Test + @EnabledOnCommand("LCS") + void lcsJustLen() { + redis.set("one", "ohmytext"); + redis.set("two", "mynewtext"); + + CustomStringCommands commands = CustomStringCommands.instance(getConnection()); + + StringMatchResult matchResult = commands.lcsLen("one", "two"); + + assertThat(matchResult.getLen()).isEqualTo(6); + } + + @Test + @EnabledOnCommand("LCS") + void lcsWithMinMatchLen() { + redis.set("key1", "ohmytext"); + redis.set("key2", "mynewtext"); + + CustomStringCommands commands = CustomStringCommands.instance(getConnection()); + + StringMatchResult matchResult = commands.lcsMinMatchLen("key1", "key2", 4); + + assertThat(matchResult.getMatchString()).isEqualTo("mytext"); + } + + @Test + @EnabledOnCommand("LCS") + void lcsMinMatchLenIdxMatchLen() { + redis.set("key1", "ohmytext"); + redis.set("key2", "mynewtext"); + + CustomStringCommands commands = CustomStringCommands.instance(getConnection()); + + // LCS key1 key2 IDX MINMATCHLEN 4 WITHMATCHLEN + StringMatchResult matchResult = commands.lcsMinMatchLenWithMatchLen("key1", "key2", 4); + + assertThat(matchResult.getMatches()).hasSize(1); + assertThat(matchResult.getMatches().get(0).getMatchLen()).isEqualTo(4); + + Position a = matchResult.getMatches().get(0).getA(); + Position b = matchResult.getMatches().get(0).getB(); + + assertThat(a.getStart()).isEqualTo(4); + assertThat(a.getEnd()).isEqualTo(7); + assertThat(b.getStart()).isEqualTo(5); + assertThat(b.getEnd()).isEqualTo(8); + assertThat(matchResult.getLen()).isEqualTo(6); + } + + protected StatefulConnection getConnection() { + StatefulRedisConnection src = redis.getStatefulConnection(); + Assumptions.assumeFalse(Proxy.isProxyClass(src.getClass()), "Redis connection is proxy, skipping."); + return src; + } + + private interface CustomStringCommands extends Commands { + + @Command("LCS :k1 :k2") + StringMatchResult lcs(@Param("k1") String k1, @Param("k2") String k2); + + @Command("LCS :k1 :k2 LEN") + StringMatchResult lcsLen(@Param("k1") String k1, @Param("k2") String k2); + + @Command("LCS :k1 :k2 MINMATCHLEN :mml") + StringMatchResult lcsMinMatchLen(@Param("k1") String k1, @Param("k2") String k2, @Param("mml") int mml); + + @Command("LCS :k1 :k2 IDX MINMATCHLEN :mml WITHMATCHLEN") + StringMatchResult lcsMinMatchLenWithMatchLen(@Param("k1") String k1, @Param("k2") String k2, @Param("mml") int mml); + + static CustomStringCommands instance(StatefulConnection conn) { + RedisCommandFactory factory = new RedisCommandFactory(conn); + return factory.getCommands(CustomStringCommands.class); + } + + } + } diff --git a/src/test/java/io/lettuce/core/output/StringMatchResultOutputUnitTests.java b/src/test/java/io/lettuce/core/output/StringMatchResultOutputUnitTests.java new file mode 100644 index 0000000000..428b1c0a85 --- /dev/null +++ b/src/test/java/io/lettuce/core/output/StringMatchResultOutputUnitTests.java @@ -0,0 +1,122 @@ +package io.lettuce.core.output; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import io.netty.buffer.Unpooled; +import org.junit.jupiter.api.Test; + +import io.lettuce.core.StringMatchResult; +import io.lettuce.core.codec.StringCodec; +import io.lettuce.core.protocol.ProtocolVersion; +import io.lettuce.core.protocol.RedisStateMachine; + +public class StringMatchResultOutputUnitTests { + + @Test + void parseManually() { + byte[] rawOne = "%2\r\n$7\r\nmatches\r\n*1\r\n*3\r\n*2\r\n:4\r\n:7\r\n*2\r\n:5\r\n:8\r\n:4\r\n$3\r\nlen\r\n:6\r\n" + .getBytes(StandardCharsets.US_ASCII); + byte[] rawTwo = "%2\r\n$3\r\nlen\r\n:6\r\n$7\r\nmatches\r\n*1\r\n*3\r\n*2\r\n:4\r\n:7\r\n*2\r\n:5\r\n:8\r\n:4\r\n" + .getBytes(StandardCharsets.US_ASCII); + RedisStateMachine rsm = new RedisStateMachine(); + rsm.setProtocolVersion(ProtocolVersion.RESP3); + + StringMatchResultOutput o1 = new StringMatchResultOutput<>(StringCodec.ASCII); + assertThat(rsm.decode(Unpooled.wrappedBuffer(rawOne), o1)).isTrue(); + + StringMatchResultOutput o2 = new StringMatchResultOutput<>(StringCodec.ASCII); + assertThat(rsm.decode(Unpooled.wrappedBuffer(rawTwo), o2)).isTrue(); + + Map res1 = transform(o1.get()); + Map res2 = transform(o2.get()); + + assertThat(res1).isEqualTo(res2); + } + + private Map transform(StringMatchResult result) { + Map obj = new HashMap<>(); + List matches = new ArrayList<>(); + for (StringMatchResult.MatchedPosition match : result.getMatches()) { + Map intra = new HashMap<>(); + Map a = new HashMap<>(); + Map b = new HashMap<>(); + a.put("start", match.getA().getStart()); + a.put("end", match.getA().getEnd()); + + b.put("start", match.getB().getStart()); + b.put("end", match.getB().getEnd()); + intra.put("a", a); + intra.put("b", b); + intra.put("matchLen", match.getMatchLen()); + matches.add(intra); + } + obj.put("matches", matches); + obj.put("len", result.getLen()); + return obj; + } + + @Test + void parseOnlyStringMatch() { + StringMatchResultOutput output = new StringMatchResultOutput<>(StringCodec.ASCII); + + String matchString = "some-string"; + output.set(ByteBuffer.wrap(matchString.getBytes())); + output.complete(0); + + StringMatchResult result = output.get(); + assertThat(result.getMatchString()).isEqualTo(matchString); + assertThat(result.getMatches()).isEmpty(); + assertThat(result.getLen()).isZero(); + } + + @Test + void parseOnlyLen() { + StringMatchResultOutput output = new StringMatchResultOutput<>(StringCodec.ASCII); + + output.set(42); + output.complete(0); + + StringMatchResult result = output.get(); + assertThat(result.getMatchString()).isNull(); + assertThat(result.getMatches()).isEmpty(); + assertThat(result.getLen()).isEqualTo(42); + } + + @Test + void parseLenAndMatchesWithIdx() { + StringMatchResultOutput output = new StringMatchResultOutput<>(StringCodec.ASCII); + + output.set(ByteBuffer.wrap("len".getBytes())); + output.set(42); + + output.set(ByteBuffer.wrap("matches".getBytes())); + output.set(0); + output.set(5); + output.set(10); + output.set(15); + + output.complete(2); + output.complete(0); + + StringMatchResult result = output.get(); + + assertThat(result.getMatchString()).isNull(); + assertThat(result.getLen()).isEqualTo(42); + assertThat(result.getMatches()).hasSize(1).satisfies(m -> assertMatchedPositions(m.get(0), 0, 5, 10, 15)); + } + + private void assertMatchedPositions(StringMatchResult.MatchedPosition match, int... expected) { + assertThat(match.getA().getStart()).isEqualTo(expected[0]); + assertThat(match.getA().getEnd()).isEqualTo(expected[1]); + assertThat(match.getB().getStart()).isEqualTo(expected[2]); + assertThat(match.getB().getEnd()).isEqualTo(expected[3]); + } + +}