diff --git a/src/main/java/org/springframework/data/redis/core/DefaultReactiveGeoOperations.java b/src/main/java/org/springframework/data/redis/core/DefaultReactiveGeoOperations.java index 4fa364f68d..60f0ca6e1f 100644 --- a/src/main/java/org/springframework/data/redis/core/DefaultReactiveGeoOperations.java +++ b/src/main/java/org/springframework/data/redis/core/DefaultReactiveGeoOperations.java @@ -91,7 +91,8 @@ public Mono add(K key, Map memberCoordinateMap) { Mono>> serializedList = Flux .fromIterable(() -> memberCoordinateMap.entrySet().iterator()) - .map(entry -> new GeoLocation<>(rawValue(entry.getKey()), entry.getValue())).collectList(); + .map(entry -> new GeoLocation<>(rawValue(entry.getKey()), entry.getValue())) + .collectList(); return serializedList.flatMap(list -> geoCommands.geoAdd(rawKey(key), list)); }); @@ -106,7 +107,8 @@ public Mono add(K key, Iterable> geoLocations) { return createMono(geoCommands -> { Mono>> serializedList = Flux.fromIterable(geoLocations) - .map(location -> new GeoLocation<>(rawValue(location.getName()), location.getPoint())).collectList(); + .map(location -> new GeoLocation<>(rawValue(location.getName()), location.getPoint())) + .collectList(); return serializedList.flatMap(list -> geoCommands.geoAdd(rawKey(key), list)); }); @@ -220,7 +222,7 @@ public Flux>> radius(K key, V member, double radius) { return createFlux(geoCommands -> geoCommands.geoRadiusByMember(rawKey(key), rawValue(member), new Distance(radius)) // - .map(this::readGeoResult)); + .map(this::readGeoResult)); } @Override @@ -265,7 +267,7 @@ public Mono delete(K key) { Assert.notNull(key, "Key must not be null"); - return template.doCreateMono(connection -> connection.keyCommands().del(rawKey(key))).map(l -> l != 0); + return template.doCreateMono(connection -> connection.keyCommands().del(rawKey(key))).map(count -> count != 0); } @Override @@ -274,10 +276,11 @@ public Flux>> search(K key, GeoReference reference, Assert.notNull(key, "Key must not be null"); Assert.notNull(reference, "GeoReference must not be null"); + GeoReference rawReference = getGeoReference(reference); - return createFlux(geoCommands -> geoCommands - .geoSearch(rawKey(key), rawReference, geoPredicate, args).map(this::readGeoResult)); + return createFlux(geoCommands -> geoCommands.geoSearch(rawKey(key), rawReference, geoPredicate, args) + .map(this::readGeoResult)); } @Override @@ -286,6 +289,7 @@ public Mono searchAndStore(K key, K destKey, GeoReference reference, Assert.notNull(key, "Key must not be null"); Assert.notNull(reference, "GeoReference must not be null"); + GeoReference rawReference = getGeoReference(reference); return createMono(geoCommands -> geoCommands.geoSearchStore(rawKey(destKey), rawKey(key), diff --git a/src/main/java/org/springframework/data/redis/core/DefaultReactiveHashOperations.java b/src/main/java/org/springframework/data/redis/core/DefaultReactiveHashOperations.java index da1050b487..b1fd142b8c 100644 --- a/src/main/java/org/springframework/data/redis/core/DefaultReactiveHashOperations.java +++ b/src/main/java/org/springframework/data/redis/core/DefaultReactiveHashOperations.java @@ -38,6 +38,7 @@ * * @author Mark Paluch * @author Christoph Strobl + * @author John Blum * @since 2.0 */ class DefaultReactiveHashOperations implements ReactiveHashOperations { @@ -62,7 +63,8 @@ public Mono remove(H key, Object... hashKeys) { Assert.noNullElements(hashKeys, "Hash keys must not contain null elements"); return createMono(hashCommands -> Flux.fromArray(hashKeys) // - .map(o -> (HK) o).map(this::rawHashKey) // + .map(hashKey -> (HK) hashKey) + .map(this::rawHashKey) // .collectList() // .flatMap(hks -> hashCommands.hDel(rawKey(key), hks))); } @@ -84,8 +86,8 @@ public Mono get(H key, Object hashKey) { Assert.notNull(key, "Key must not be null"); Assert.notNull(hashKey, "Hash key must not be null"); - return createMono(hashCommands -> - hashCommands.hGet(rawKey(key), rawHashKey((HK) hashKey)).map(this::readHashValue)); + return createMono(hashCommands -> hashCommands.hGet(rawKey(key), rawHashKey((HK) hashKey)) + .map(this::readHashValue)); } @Override @@ -107,8 +109,7 @@ public Mono increment(H key, HK hashKey, long delta) { Assert.notNull(key, "Key must not be null"); Assert.notNull(hashKey, "Hash key must not be null"); - return template.doCreateMono(connection -> connection // - .numberCommands() // + return template.doCreateMono(connection -> connection.numberCommands() .hIncrBy(rawKey(key), rawHashKey(hashKey), delta)); } @@ -118,8 +119,7 @@ public Mono increment(H key, HK hashKey, double delta) { Assert.notNull(key, "Key must not be null"); Assert.notNull(hashKey, "Hash key must not be null"); - return template.doCreateMono(connection -> connection // - .numberCommands() // + return template.doCreateMono(connection -> connection.numberCommands() .hIncrBy(rawKey(key), rawHashKey(hashKey), delta)); } @@ -128,8 +128,8 @@ public Mono randomKey(H key) { Assert.notNull(key, "Key must not be null"); - return template.doCreateMono(connection -> connection // - .hashCommands().hRandField(rawKey(key))).map(this::readRequiredHashKey); + return template.doCreateMono(connection -> connection.hashCommands().hRandField(rawKey(key))) + .map(this::readRequiredHashKey); } @Override @@ -137,7 +137,8 @@ public Mono> randomEntry(H key) { Assert.notNull(key, "Key must not be null"); - return createMono(hashCommands ->hashCommands.hRandFieldWithValues(rawKey(key))).map(this::deserializeHashEntry); + return createMono(hashCommands -> hashCommands.hRandFieldWithValues(rawKey(key))) + .map(this::deserializeHashEntry); } @Override @@ -145,8 +146,8 @@ public Flux randomKeys(H key, long count) { Assert.notNull(key, "Key must not be null"); - return template.doCreateFlux(connection -> connection // - .hashCommands().hRandField(rawKey(key), count)).map(this::readRequiredHashKey); + return template.doCreateFlux(connection -> connection.hashCommands().hRandField(rawKey(key), count)) + .map(this::readRequiredHashKey); } @Override @@ -154,8 +155,8 @@ public Flux> randomEntries(H key, long count) { Assert.notNull(key, "Key must not be null"); - return template.doCreateFlux(connection -> connection // - .hashCommands().hRandFieldWithValues(rawKey(key), count)).map(this::deserializeHashEntry); + return template.doCreateFlux(connection -> connection.hashCommands().hRandFieldWithValues(rawKey(key), count)) + .map(this::deserializeHashEntry); } @Override @@ -211,7 +212,7 @@ public Flux values(H key) { Assert.notNull(key, "Key must not be null"); - return createFlux(connection -> connection.hVals(rawKey(key)) // + return createFlux(hashCommands -> hashCommands.hVals(rawKey(key)) // .map(this::readRequiredHashValue)); } @@ -278,28 +279,28 @@ private HK readRequiredHashKey(ByteBuffer buffer) { HK hashKey = readHashKey(buffer); - if (hashKey == null) { - throw new InvalidDataAccessApiUsageException("Deserialized hash key is null"); + if (hashKey != null) { + return hashKey; } - return hashKey; + throw new InvalidDataAccessApiUsageException("Deserialized hash key is null"); } @SuppressWarnings("unchecked") @Nullable private HV readHashValue(@Nullable ByteBuffer value) { - return (HV) (value == null ? null : serializationContext.getHashValueSerializationPair().read(value)); + return value != null ? (HV) serializationContext.getHashValueSerializationPair().read(value) : null; } private HV readRequiredHashValue(ByteBuffer buffer) { HV hashValue = readHashValue(buffer); - if (hashValue == null) { - throw new InvalidDataAccessApiUsageException("Deserialized hash value is null"); + if (hashValue != null) { + return hashValue; } - return hashValue; + throw new InvalidDataAccessApiUsageException("Deserialized hash value is null"); } private Map.Entry deserializeHashEntry(Map.Entry source) { @@ -309,9 +310,11 @@ private Map.Entry deserializeHashEntry(Map.Entry private List deserializeHashValues(List source) { List values = new ArrayList<>(source.size()); + for (ByteBuffer byteBuffer : source) { values.add(readHashValue(byteBuffer)); } + return values; } } diff --git a/src/main/java/org/springframework/data/redis/core/DefaultReactiveListOperations.java b/src/main/java/org/springframework/data/redis/core/DefaultReactiveListOperations.java index 032b47e350..cf863e7839 100644 --- a/src/main/java/org/springframework/data/redis/core/DefaultReactiveListOperations.java +++ b/src/main/java/org/springframework/data/redis/core/DefaultReactiveListOperations.java @@ -351,12 +351,12 @@ private V readValue(ByteBuffer buffer) { private V readRequiredValue(ByteBuffer buffer) { - V v = readValue(buffer); + V value = readValue(buffer); - if (v == null) { - throw new InvalidDataAccessApiUsageException("Deserialized list value is null"); + if (value != null) { + return value; } - return v; + throw new InvalidDataAccessApiUsageException("Deserialized list value is null"); } } diff --git a/src/main/java/org/springframework/data/redis/core/DefaultReactiveSetOperations.java b/src/main/java/org/springframework/data/redis/core/DefaultReactiveSetOperations.java index 07bde24301..63dc0debae 100644 --- a/src/main/java/org/springframework/data/redis/core/DefaultReactiveSetOperations.java +++ b/src/main/java/org/springframework/data/redis/core/DefaultReactiveSetOperations.java @@ -40,6 +40,7 @@ * @author Mark Paluch * @author Christoph Strobl * @author Roman Bezpalko + * @author John Blum * @since 2.0 */ class DefaultReactiveSetOperations implements ReactiveSetOperations { @@ -424,12 +425,12 @@ private V readValue(ByteBuffer buffer) { private V readRequiredValue(ByteBuffer buffer) { - V v = readValue(buffer); + V value = readValue(buffer); - if (v == null) { - throw new InvalidDataAccessApiUsageException("Deserialized set value is null"); + if (value != null) { + return value; } - return v; + throw new InvalidDataAccessApiUsageException("Deserialized set value is null"); } } diff --git a/src/main/java/org/springframework/data/redis/core/DefaultReactiveValueOperations.java b/src/main/java/org/springframework/data/redis/core/DefaultReactiveValueOperations.java index f08b74800c..dd59f5bd18 100644 --- a/src/main/java/org/springframework/data/redis/core/DefaultReactiveValueOperations.java +++ b/src/main/java/org/springframework/data/redis/core/DefaultReactiveValueOperations.java @@ -44,6 +44,7 @@ * @author Mark Paluch * @author Christoph Strobl * @author Jiahe Cai + * @author John Blum * @since 2.0 */ class DefaultReactiveValueOperations implements ReactiveValueOperations { @@ -336,13 +337,13 @@ private V readValue(ByteBuffer buffer) { private V readRequiredValue(ByteBuffer buffer) { - V v = readValue(buffer); + V value = readValue(buffer); - if (v == null) { - throw new InvalidDataAccessApiUsageException("Deserialized value is null"); + if (value != null) { + return value; } - return v; + throw new InvalidDataAccessApiUsageException("Deserialized value is null"); } private SerializationPair stringSerializationPair() { @@ -372,5 +373,4 @@ private List deserializeValues(List source) { return result; } - } diff --git a/src/main/java/org/springframework/data/redis/core/DefaultReactiveZSetOperations.java b/src/main/java/org/springframework/data/redis/core/DefaultReactiveZSetOperations.java index 696a80992f..469264ccd9 100644 --- a/src/main/java/org/springframework/data/redis/core/DefaultReactiveZSetOperations.java +++ b/src/main/java/org/springframework/data/redis/core/DefaultReactiveZSetOperations.java @@ -38,6 +38,7 @@ import org.springframework.data.redis.core.ZSetOperations.TypedTuple; import org.springframework.data.redis.serializer.RedisSerializationContext; import org.springframework.data.redis.util.ByteUtils; +import org.springframework.data.redis.util.RedisAssertions; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -744,13 +745,8 @@ private V readValue(ByteBuffer buffer) { private V readRequiredValue(ByteBuffer buffer) { - V v = readValue(buffer); - - if (v == null) { - throw new InvalidDataAccessApiUsageException("Deserialized sorted set value is null"); - } - - return v; + return RedisAssertions.requireNonNull(readValue(buffer), + () -> new InvalidDataAccessApiUsageException("Deserialized sorted set value is null")); } private TypedTuple readTypedTuple(Tuple raw) { diff --git a/src/main/java/org/springframework/data/redis/core/DefaultZSetOperations.java b/src/main/java/org/springframework/data/redis/core/DefaultZSetOperations.java index 1e290a45b7..cc663f21f5 100644 --- a/src/main/java/org/springframework/data/redis/core/DefaultZSetOperations.java +++ b/src/main/java/org/springframework/data/redis/core/DefaultZSetOperations.java @@ -42,6 +42,7 @@ * @author Wongoo (望哥) * @author Andrey Shlykov * @author Shyngys Sapraliyev + * @author John Blum */ class DefaultZSetOperations extends AbstractOperations implements ZSetOperations { @@ -54,6 +55,7 @@ public Boolean add(K key, V value, double score) { byte[] rawKey = rawKey(key); byte[] rawValue = rawValue(value); + return execute(connection -> connection.zAdd(rawKey, score, rawValue)); } @@ -74,6 +76,7 @@ protected Boolean add(K key, V value, double score, ZAddArgs args) { byte[] rawKey = rawKey(key); byte[] rawValue = rawValue(value); + return execute(connection -> connection.zAdd(rawKey, score, rawValue, args)); } @@ -82,6 +85,7 @@ public Long add(K key, Set> tuples) { byte[] rawKey = rawKey(key); Set rawValues = rawTupleValues(tuples); + return execute(connection -> connection.zAdd(rawKey, rawValues)); } @@ -102,6 +106,7 @@ protected Long add(K key, Set> tuples, ZAddArgs args) { byte[] rawKey = rawKey(key); Set rawValues = rawTupleValues(tuples); + return execute(connection -> connection.zAdd(rawKey, rawValues, args)); } @@ -110,6 +115,7 @@ public Double incrementScore(K key, V value, double delta) { byte[] rawKey = rawKey(key); byte[] rawValue = rawValue(value); + return execute(connection -> connection.zIncrBy(rawKey, delta, rawValue)); } @@ -127,8 +133,8 @@ public Set distinctRandomMembers(K key, long count) { Assert.isTrue(count > 0, "Negative count not supported; Use randomMembers to allow duplicate elements"); byte[] rawKey = rawKey(key); - List result = execute(connection -> connection.zRandMember(rawKey, count)); + return result != null ? deserializeValues(new LinkedHashSet<>(result)) : null; } @@ -138,8 +144,8 @@ public List randomMembers(K key, long count) { Assert.isTrue(count > 0, "Use a positive number for count; This method is already allowing duplicate elements"); byte[] rawKey = rawKey(key); - List result = execute(connection -> connection.zRandMember(rawKey, count)); + return deserializeValues(result); } @@ -157,8 +163,8 @@ public Set> distinctRandomMembersWithScore(K key, long count) { Assert.isTrue(count > 0, "Negative count not supported; Use randomMembers to allow duplicate elements"); byte[] rawKey = rawKey(key); - List result = execute(connection -> connection.zRandMemberWithScore(rawKey, count)); + return result != null ? deserializeTupleValues(new LinkedHashSet<>(result)) : null; } @@ -168,8 +174,8 @@ public List> randomMembersWithScore(K key, long count) { Assert.isTrue(count > 0, "Use a positive number for count; This method is already allowing duplicate elements"); byte[] rawKey = rawKey(key); - List result = execute(connection -> connection.zRandMemberWithScore(rawKey, count)); + return result != null ? deserializeTupleValues(result) : null; } @@ -229,29 +235,37 @@ public Set reverseRangeByLex(K key, Range range, Limit limit) { @Override public Long rangeAndStoreByLex(K srcKey, K dstKey, Range range, Limit limit) { + byte[] rawDstKey = rawKey(dstKey); byte[] rawSrcKey = rawKey(srcKey); + return execute(connection -> connection.zRangeStoreByLex(rawDstKey, rawSrcKey, serialize(range), limit)); } @Override public Long reverseRangeAndStoreByLex(K srcKey, K dstKey, Range range, Limit limit) { + byte[] rawDstKey = rawKey(dstKey); byte[] rawSrcKey = rawKey(srcKey); + return execute(connection -> connection.zRangeStoreRevByLex(rawDstKey, rawSrcKey, serialize(range), limit)); } @Override public Long rangeAndStoreByScore(K srcKey, K dstKey, Range range, Limit limit) { + byte[] rawDstKey = rawKey(dstKey); byte[] rawSrcKey = rawKey(srcKey); + return execute(connection -> connection.zRangeStoreByScore(rawDstKey, rawSrcKey, range, limit)); } @Override public Long reverseRangeAndStoreByScore(K srcKey, K dstKey, Range range, Limit limit) { + byte[] rawDstKey = rawKey(dstKey); byte[] rawSrcKey = rawKey(srcKey); + return execute(connection -> connection.zRangeStoreRevByScore(rawDstKey, rawSrcKey, range, limit)); } @@ -322,8 +336,8 @@ public Set> reverseRangeByScoreWithScores(K key, double min, doubl public Set> reverseRangeByScoreWithScores(K key, double min, double max, long offset, long count) { byte[] rawKey = rawKey(key); - Set rawValues = execute( - connection -> connection.zRevRangeByScoreWithScores(rawKey, min, max, offset, count)); + Set rawValues = execute(connection -> + connection.zRevRangeByScoreWithScores(rawKey, min, max, offset, count)); return deserializeTupleValues(rawValues); } @@ -365,6 +379,7 @@ public Long remove(K key, Object... values) { public Long removeRange(K key, long start, long end) { byte[] rawKey = rawKey(key); + return execute(connection -> connection.zRemRange(rawKey, start, end)); } @@ -372,6 +387,7 @@ public Long removeRange(K key, long start, long end) { public Long removeRangeByLex(K key, Range range) { byte[] rawKey = rawKey(key); + return execute(connection -> connection.zRemRangeByLex(rawKey, serialize(range))); } @@ -379,6 +395,7 @@ public Long removeRangeByLex(K key, Range range) { public Long removeRangeByScore(K key, double min, double max) { byte[] rawKey = rawKey(key); + return execute(connection -> connection.zRemRangeByScore(rawKey, min, max)); } @@ -387,6 +404,7 @@ public Double score(K key, Object o) { byte[] rawKey = rawKey(key); byte[] rawValue = rawValue(o); + return execute(connection -> connection.zScore(rawKey, rawValue)); } @@ -395,6 +413,7 @@ public List score(K key, Object... o) { byte[] rawKey = rawKey(key); byte[][] rawValues = rawValues(o); + return execute(connection -> connection.zMScore(rawKey, rawValues)); } @@ -402,6 +421,7 @@ public List score(K key, Object... o) { public Long count(K key, double min, double max) { byte[] rawKey = rawKey(key); + return execute(connection -> connection.zCount(rawKey, min, max)); } @@ -409,6 +429,7 @@ public Long count(K key, double min, double max) { public Long lexCount(K key, Range range) { byte[] rawKey = rawKey(key); + return execute(connection -> connection.zLexCount(rawKey, serialize(range))); } @@ -417,6 +438,7 @@ public Long lexCount(K key, Range range) { public TypedTuple popMin(K key) { byte[] rawKey = rawKey(key); + return deserializeTuple(execute(connection -> connection.zPopMin(rawKey))); } @@ -426,6 +448,7 @@ public Set> popMin(K key, long count) { byte[] rawKey = rawKey(key); Set result = execute(connection -> connection.zPopMin(rawKey, count)); + return deserializeTupleValues(new LinkedHashSet<>(result)); } @@ -434,6 +457,7 @@ public Set> popMin(K key, long count) { public TypedTuple popMin(K key, long timeout, TimeUnit unit) { byte[] rawKey = rawKey(key); + return deserializeTuple(execute(connection -> connection.bZPopMin(rawKey, timeout, unit))); } @@ -442,6 +466,7 @@ public TypedTuple popMin(K key, long timeout, TimeUnit unit) { public TypedTuple popMax(K key) { byte[] rawKey = rawKey(key); + return deserializeTuple(execute(connection -> connection.zPopMax(rawKey))); } @@ -451,6 +476,7 @@ public Set> popMax(K key, long count) { byte[] rawKey = rawKey(key); Set result = execute(connection -> connection.zPopMax(rawKey, count)); + return deserializeTupleValues(new LinkedHashSet<>(result)); } @@ -459,6 +485,7 @@ public Set> popMax(K key, long count) { public TypedTuple popMax(K key, long timeout, TimeUnit unit) { byte[] rawKey = rawKey(key); + return deserializeTuple(execute(connection -> connection.bZPopMax(rawKey, timeout, unit))); } @@ -471,6 +498,7 @@ public Long size(K key) { public Long zCard(K key) { byte[] rawKey = rawKey(key); + return execute(connection -> connection.zCard(rawKey)); } @@ -479,6 +507,7 @@ public Set difference(K key, Collection otherKeys) { byte[][] rawKeys = rawKeys(key, otherKeys); Set rawValues = execute(connection -> connection.zDiff(rawKeys)); + return deserializeValues(rawValues); } @@ -487,6 +516,7 @@ public Set> differenceWithScores(K key, Collection otherKeys) { byte[][] rawKeys = rawKeys(key, otherKeys); Set result = execute(connection -> connection.zDiffWithScores(rawKeys)); + return deserializeTupleValues(new LinkedHashSet<>(result)); } @@ -504,6 +534,7 @@ public Set intersect(K key, Collection otherKeys) { byte[][] rawKeys = rawKeys(key, otherKeys); Set rawValues = execute(connection -> connection.zInter(rawKeys)); + return deserializeValues(rawValues); } @@ -512,6 +543,7 @@ public Set> intersectWithScores(K key, Collection otherKeys) { byte[][] rawKeys = rawKeys(key, otherKeys); Set result = execute(connection -> connection.zInterWithScores(rawKeys)); + return deserializeTupleValues(result); } @@ -520,6 +552,7 @@ public Set> intersectWithScores(K key, Collection otherKeys, Ag byte[][] rawKeys = rawKeys(key, otherKeys); Set result = execute(connection -> connection.zInterWithScores(aggregate, weights, rawKeys)); + return deserializeTupleValues(result); } @@ -551,6 +584,7 @@ public Set union(K key, Collection otherKeys) { byte[][] rawKeys = rawKeys(key, otherKeys); Set rawValues = execute(connection -> connection.zUnion(rawKeys)); + return deserializeValues(rawValues); } @@ -559,6 +593,7 @@ public Set> unionWithScores(K key, Collection otherKeys) { byte[][] rawKeys = rawKeys(key, otherKeys); Set result = execute(connection -> connection.zUnionWithScores(rawKeys)); + return deserializeTupleValues(result); } @@ -567,6 +602,7 @@ public Set> unionWithScores(K key, Collection otherKeys, Aggreg byte[][] rawKeys = rawKeys(key, otherKeys); Set result = execute(connection -> connection.zUnionWithScores(aggregate, weights, rawKeys)); + return deserializeTupleValues(result); } @@ -605,6 +641,7 @@ public Cursor> scan(K key, ScanOptions options) { public Set rangeByScore(K key, String min, String max) { byte[] rawKey = rawKey(key); + return execute(connection -> connection.zRangeByScore(rawKey, min, max)); } @@ -632,5 +669,4 @@ private Range.Bound rawBound(Range.Bound source) { .map(it -> source.isInclusive() ? Range.Bound.inclusive(it) : Range.Bound.exclusive(it)) .orElseGet(Range.Bound::unbounded); } - } diff --git a/src/main/java/org/springframework/data/redis/core/ReactiveRedisTemplate.java b/src/main/java/org/springframework/data/redis/core/ReactiveRedisTemplate.java index 6745a47a5c..5d52051a2a 100644 --- a/src/main/java/org/springframework/data/redis/core/ReactiveRedisTemplate.java +++ b/src/main/java/org/springframework/data/redis/core/ReactiveRedisTemplate.java @@ -64,9 +64,10 @@ * @author Mark Paluch * @author Christoph Strobl * @author Petromir Dzhunev - * @since 2.0 + * @author John Blum * @param the Redis key type against which the template works (usually a String) * @param the Redis value type against which the template works + * @since 2.0 */ public class ReactiveRedisTemplate implements ReactiveRedisOperations { @@ -680,10 +681,10 @@ private K readRequiredKey(ByteBuffer buffer) { K key = readKey(buffer); - if (key == null) { - throw new InvalidDataAccessApiUsageException("Deserialized key is null"); + if (key != null) { + return key; } - return key; + throw new InvalidDataAccessApiUsageException("Deserialized key is null"); } } diff --git a/src/main/java/org/springframework/data/redis/core/ReactiveStreamOperations.java b/src/main/java/org/springframework/data/redis/core/ReactiveStreamOperations.java index 9314bc2383..73bf8f0d25 100644 --- a/src/main/java/org/springframework/data/redis/core/ReactiveStreamOperations.java +++ b/src/main/java/org/springframework/data/redis/core/ReactiveStreamOperations.java @@ -26,11 +26,22 @@ import org.springframework.data.domain.Range; import org.springframework.data.redis.connection.Limit; import org.springframework.data.redis.connection.RedisStreamCommands.XClaimOptions; -import org.springframework.data.redis.connection.stream.*; +import org.springframework.data.redis.connection.stream.ByteBufferRecord; +import org.springframework.data.redis.connection.stream.Consumer; +import org.springframework.data.redis.connection.stream.MapRecord; +import org.springframework.data.redis.connection.stream.ObjectRecord; +import org.springframework.data.redis.connection.stream.PendingMessage; +import org.springframework.data.redis.connection.stream.PendingMessages; +import org.springframework.data.redis.connection.stream.PendingMessagesSummary; +import org.springframework.data.redis.connection.stream.ReadOffset; import org.springframework.data.redis.connection.stream.Record; +import org.springframework.data.redis.connection.stream.RecordId; import org.springframework.data.redis.connection.stream.StreamInfo.XInfoConsumer; import org.springframework.data.redis.connection.stream.StreamInfo.XInfoGroup; import org.springframework.data.redis.connection.stream.StreamInfo.XInfoStream; +import org.springframework.data.redis.connection.stream.StreamOffset; +import org.springframework.data.redis.connection.stream.StreamReadOptions; +import org.springframework.data.redis.connection.stream.StreamRecords; import org.springframework.data.redis.hash.HashMapper; import org.springframework.lang.Nullable; import org.springframework.util.Assert; diff --git a/src/main/java/org/springframework/data/redis/core/script/DefaultReactiveScriptExecutor.java b/src/main/java/org/springframework/data/redis/core/script/DefaultReactiveScriptExecutor.java index b85c13047a..a0c446d58a 100644 --- a/src/main/java/org/springframework/data/redis/core/script/DefaultReactiveScriptExecutor.java +++ b/src/main/java/org/springframework/data/redis/core/script/DefaultReactiveScriptExecutor.java @@ -40,6 +40,7 @@ * * @author Mark Paluch * @author Christoph Strobl + * @author John Blum * @param The type of keys that may be passed during script execution * @since 2.0 */ @@ -105,44 +106,41 @@ protected Flux eval(ReactiveRedisConnection connection, RedisScript sc Flux result = connection.scriptingCommands().evalSha(script.getSha1(), returnType, numKeys, keysAndArgs); - result = result.onErrorResume(e -> { + result = result.onErrorResume(cause -> { - if (ScriptUtils.exceptionContainsNoScriptError(e)) { + if (ScriptUtils.exceptionContainsNoScriptError(cause)) { return connection.scriptingCommands().eval(scriptBytes(script), returnType, numKeys, keysAndArgs); } - return Flux - .error(e instanceof RuntimeException ? (RuntimeException) e : new RedisSystemException(e.getMessage(), e)); + return Flux.error(cause instanceof RuntimeException ? cause + : new RedisSystemException(cause.getMessage(), cause)); }); return script.returnsRawValue() ? result : deserializeResult(resultReader, result); } - @SuppressWarnings("Convert2MethodRef") + @SuppressWarnings({ "Convert2MethodRef", "rawtypes", "unchecked" }) protected ByteBuffer[] keysAndArgs(RedisElementWriter argsWriter, List keys, List args) { return Stream.concat(keys.stream().map(t -> keySerializer().getWriter().write(t)), args.stream().map(t -> argsWriter.write(t))).toArray(size -> new ByteBuffer[size]); } - /** - * @param script - * @return - */ protected ByteBuffer scriptBytes(RedisScript script) { return serializationContext.getStringSerializationPair().getWriter().write(script.getScriptAsString()); } protected Flux deserializeResult(RedisElementReader reader, Flux result) { + return result.map(it -> { T value = ScriptUtils.deserializeResult(reader, it); - if (value == null) { - throw new InvalidDataAccessApiUsageException("Deserialized script result is null"); + if (value != null) { + return value; } - return value; + throw new InvalidDataAccessApiUsageException("Deserialized script result is null"); }); } @@ -169,6 +167,6 @@ private Flux execute(ReactiveRedisCallback action) { } public ReactiveRedisConnectionFactory getConnectionFactory() { - return connectionFactory; + return this.connectionFactory; } } diff --git a/src/main/java/org/springframework/data/redis/util/RedisAssertions.java b/src/main/java/org/springframework/data/redis/util/RedisAssertions.java index 2fed009c81..bc63a3b9b3 100644 --- a/src/main/java/org/springframework/data/redis/util/RedisAssertions.java +++ b/src/main/java/org/springframework/data/redis/util/RedisAssertions.java @@ -57,6 +57,25 @@ public static T requireNonNull(@Nullable T target, Supplier message) return target; } + /** + * Asserts the given {@link Object} is not {@literal null} throwing the given {@link RuntimeException} + * if {@link Object} is {@literal null}. + * + * @param {@link Class type} of {@link Object} being asserted. + * @param target {@link Object} to evaluate. + * @param cause {@link Supplier} of a {@link RuntimeException} to throw + * if the given {@link Object} is {@literal null}. + * @return the given {@link Object}. + */ + public static T requireNonNull(@Nullable T target, RuntimeExceptionSupplier cause) { + + if (target == null) { + throw cause.get(); + } + + return target; + } + /** * Asserts the given {@link Object} is not {@literal null}. * @@ -85,4 +104,7 @@ public static T requireState(@Nullable T target, Supplier message) { Assert.state(target != null, message); return target; } + + public interface RuntimeExceptionSupplier extends Supplier { } + } diff --git a/src/test/java/org/springframework/data/redis/core/ReactiveStringRedisTemplateIntegrationTests.java b/src/test/java/org/springframework/data/redis/core/ReactiveStringRedisTemplateIntegrationTests.java index 0b3c8db611..503f3d8d69 100644 --- a/src/test/java/org/springframework/data/redis/core/ReactiveStringRedisTemplateIntegrationTests.java +++ b/src/test/java/org/springframework/data/redis/core/ReactiveStringRedisTemplateIntegrationTests.java @@ -36,7 +36,7 @@ @ExtendWith(LettuceConnectionFactoryExtension.class) public class ReactiveStringRedisTemplateIntegrationTests { - private ReactiveRedisConnectionFactory connectionFactory; + private final ReactiveRedisConnectionFactory connectionFactory; private ReactiveStringRedisTemplate template; @@ -66,17 +66,16 @@ void keysFailsOnNullElements() { template.opsForValue().set("a", "1").as(StepVerifier::create).expectNext(true).verifyComplete(); template.opsForValue().set("b", "1").as(StepVerifier::create).expectNext(true).verifyComplete(); - RedisElementWriter writer = RedisElementWriter.from(StringRedisSerializer.UTF_8); RedisElementReader reader = RedisElementReader.from(StringRedisSerializer.UTF_8); + RedisElementWriter writer = RedisElementWriter.from(StringRedisSerializer.UTF_8); + RedisSerializationContext nullReadingContext = RedisSerializationContext - . newSerializationContext(StringRedisSerializer.UTF_8).key(buffer -> { + .newSerializationContext(StringRedisSerializer.UTF_8).key(buffer -> { String read = reader.read(buffer); - if ("a".equals(read)) { - return null; - } - return read; + return "a".equals(read) ? null : read; + }, writer).build(); ReactiveRedisTemplate customTemplate = new ReactiveRedisTemplate<>(template.getConnectionFactory(), diff --git a/src/test/java/org/springframework/data/redis/util/RedisAssertionsUnitTests.java b/src/test/java/org/springframework/data/redis/util/RedisAssertionsUnitTests.java index ea157930f8..b0e613d491 100644 --- a/src/test/java/org/springframework/data/redis/util/RedisAssertionsUnitTests.java +++ b/src/test/java/org/springframework/data/redis/util/RedisAssertionsUnitTests.java @@ -16,6 +16,7 @@ package org.springframework.data.redis.util; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.mockito.Mockito.doReturn; @@ -31,6 +32,8 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.dao.InvalidDataAccessApiUsageException; + /** * Unit Tests for {@link RedisAssertions}. * @@ -78,6 +81,24 @@ void requireNonNullWithSupplierThrowsIllegalArgumentException() { verifyNoMoreInteractions(this.mockSupplier); } + @Test + void requireNonNullWithRuntimeExceptionSupplierIsSuccessful() { + + assertThat(RedisAssertions.requireNonNull("mock", () -> new InvalidDataAccessApiUsageException("TEST"))) + .isEqualTo("mock"); + } + + @Test + @SuppressWarnings("all") + void requireNonNullWithThrowsRuntimeException() { + + assertThatExceptionOfType(InvalidDataAccessApiUsageException.class) + .isThrownBy(() -> RedisAssertions.requireNonNull(null, + () -> new InvalidDataAccessApiUsageException("TEST"))) + .withMessage("TEST") + .withNoCause(); + } + @Test void requireStateWithMessageAndArgumentsIsSuccessful() { assertThat(RedisAssertions.requireState("test", "Mock message")).isEqualTo("test");