Skip to content

Commit

Permalink
Terminate stream with error on null values returned by `RedisElemen…
Browse files Browse the repository at this point in the history
…tReader` for top-level elements.

We now emit InvalidDataAccessApiUsageException when a RedisElementReader returns null in the context of a top-level stream to indicate invalid API usage although RedisElementReader.read can generally return null values if these are being collected in a container or value wrapper or parent complex object.
  • Loading branch information
mp911de committed Aug 17, 2023
1 parent 7d3e805 commit 549f815
Show file tree
Hide file tree
Showing 20 changed files with 253 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ public static Object parse(Object source, String sourcePath, Map<String, Class<?
* @return
* @since 2.6
*/
public static <K, V> Map.Entry<K, V> entryOf(K key, V value) {
public static <K, V> Map.Entry<K, V> entryOf(@Nullable K key, @Nullable V value) {
return new AbstractMap.SimpleImmutableEntry<>(key, value);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import java.util.stream.Collectors;

import org.reactivestreams.Publisher;

import org.springframework.data.geo.Circle;
import org.springframework.data.geo.Distance;
import org.springframework.data.geo.GeoResult;
Expand All @@ -40,6 +39,7 @@
import org.springframework.data.redis.domain.geo.GeoReference.GeoMemberReference;
import org.springframework.data.redis.domain.geo.GeoShape;
import org.springframework.data.redis.serializer.RedisSerializationContext;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
Expand Down Expand Up @@ -321,6 +321,7 @@ private ByteBuffer rawValue(V value) {
return serializationContext.getValueSerializationPair().write(value);
}

@Nullable
private V readValue(ByteBuffer buffer) {
return serializationContext.getValueSerializationPair().read(buffer);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
import java.util.function.Function;

import org.reactivestreams.Publisher;

import org.springframework.dao.InvalidDataAccessApiUsageException;
import org.springframework.data.redis.connection.ReactiveHashCommands;
import org.springframework.data.redis.connection.convert.Converters;
import org.springframework.data.redis.serializer.RedisSerializationContext;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
Expand Down Expand Up @@ -127,7 +128,7 @@ public Mono<HK> randomKey(H key) {
Assert.notNull(key, "Key must not be null");

return template.doCreateMono(connection -> connection //
.hashCommands().hRandField(rawKey(key))).map(this::readHashKey);
.hashCommands().hRandField(rawKey(key))).map(this::readRequiredHashKey);
}

@Override
Expand All @@ -145,7 +146,7 @@ public Flux<HK> randomKeys(H key, long count) {
Assert.notNull(key, "Key must not be null");

return template.doCreateFlux(connection -> connection //
.hashCommands().hRandField(rawKey(key), count)).map(this::readHashKey);
.hashCommands().hRandField(rawKey(key), count)).map(this::readRequiredHashKey);
}

@Override
Expand All @@ -163,7 +164,7 @@ public Flux<HK> keys(H key) {
Assert.notNull(key, "Key must not be null");

return createFlux(connection -> connection.hKeys(rawKey(key)) //
.map(this::readHashKey));
.map(this::readRequiredHashKey));
}

@Override
Expand Down Expand Up @@ -211,7 +212,7 @@ public Flux<HV> values(H key) {
Assert.notNull(key, "Key must not be null");

return createFlux(connection -> connection.hVals(rawKey(key)) //
.map(this::readHashValue));
.map(this::readRequiredHashValue));
}

@Override
Expand Down Expand Up @@ -268,13 +269,37 @@ private ByteBuffer rawHashValue(HV key) {
}

@SuppressWarnings("unchecked")
@Nullable
private HK readHashKey(ByteBuffer value) {
return (HK) serializationContext.getHashKeySerializationPair().read(value);
}

private HK readRequiredHashKey(ByteBuffer buffer) {

HK hashKey = readHashKey(buffer);

if (hashKey == null) {
throw new InvalidDataAccessApiUsageException("Deserialized hash key is null");
}

return hashKey;
}

@SuppressWarnings("unchecked")
private HV readHashValue(ByteBuffer value) {
return (HV) (value == null ? value : serializationContext.getHashValueSerializationPair().read(value));
@Nullable
private HV readHashValue(@Nullable ByteBuffer value) {
return (HV) (value == null ? null : serializationContext.getHashValueSerializationPair().read(value));
}

private HV readRequiredHashValue(ByteBuffer buffer) {

HV hashValue = readHashValue(buffer);

if (hashValue == null) {
throw new InvalidDataAccessApiUsageException("Deserialized hash value is null");
}

return hashValue;
}

private Map.Entry<HK, HV> deserializeHashEntry(Map.Entry<ByteBuffer, ByteBuffer> source) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@
import java.util.function.Function;

import org.reactivestreams.Publisher;
import org.springframework.dao.InvalidDataAccessApiUsageException;
import org.springframework.data.redis.connection.ReactiveListCommands;
import org.springframework.data.redis.connection.ReactiveListCommands.Direction;
import org.springframework.data.redis.connection.ReactiveListCommands.LPosCommand;
import org.springframework.data.redis.connection.RedisListCommands.Position;
import org.springframework.data.redis.serializer.RedisSerializationContext;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
Expand All @@ -58,7 +60,7 @@ public Flux<V> range(K key, long start, long end) {

Assert.notNull(key, "Key must not be null");

return createFlux(connection -> connection.lRange(rawKey(key), start, end).map(this::readValue));
return createFlux(connection -> connection.lRange(rawKey(key), start, end).map(this::readRequiredValue));
}

@Override
Expand Down Expand Up @@ -170,7 +172,8 @@ public Mono<V> move(K sourceKey, Direction from, K destinationKey, Direction to)
Assert.notNull(to, "To direction must not be null");

return createMono(
connection -> connection.lMove(rawKey(sourceKey), rawKey(destinationKey), from, to).map(this::readValue));
connection -> connection.lMove(rawKey(sourceKey), rawKey(destinationKey), from, to)
.map(this::readRequiredValue));
}

@Override
Expand All @@ -183,7 +186,7 @@ public Mono<V> move(K sourceKey, Direction from, K destinationKey, Direction to,
Assert.notNull(timeout, "Timeout must not be null");

return createMono(connection -> connection.bLMove(rawKey(sourceKey), rawKey(destinationKey), from, to, timeout)
.map(this::readValue));
.map(this::readRequiredValue));
}

@Override
Expand All @@ -208,7 +211,7 @@ public Mono<V> index(K key, long index) {

Assert.notNull(key, "Key must not be null");

return createMono(connection -> connection.lIndex(rawKey(key), index).map(this::readValue));
return createMono(connection -> connection.lIndex(rawKey(key), index).map(this::readRequiredValue));
}

@Override
Expand All @@ -232,7 +235,7 @@ public Mono<V> leftPop(K key) {

Assert.notNull(key, "Key must not be null");

return createMono(connection -> connection.lPop(rawKey(key)).map(this::readValue));
return createMono(connection -> connection.lPop(rawKey(key)).map(this::readRequiredValue));

}

Expand All @@ -244,15 +247,15 @@ public Mono<V> leftPop(K key, Duration timeout) {
Assert.isTrue(isZeroOrGreater1Second(timeout), "Duration must be either zero or greater or equal to 1 second");

return createMono(connection -> connection.blPop(Collections.singletonList(rawKey(key)), timeout)
.map(popResult -> readValue(popResult.getValue())));
.mapNotNull(popResult -> readValue(popResult.getValue())));
}

@Override
public Mono<V> rightPop(K key) {

Assert.notNull(key, "Key must not be null");

return createMono(connection -> connection.rPop(rawKey(key)).map(this::readValue));
return createMono(connection -> connection.rPop(rawKey(key)).map(this::readRequiredValue));
}

@Override
Expand All @@ -263,7 +266,7 @@ public Mono<V> rightPop(K key, Duration timeout) {
Assert.isTrue(isZeroOrGreater1Second(timeout), "Duration must be either zero or greater or equal to 1 second");

return createMono(connection -> connection.brPop(Collections.singletonList(rawKey(key)), timeout)
.map(popResult -> readValue(popResult.getValue())));
.mapNotNull(popResult -> readValue(popResult.getValue())));
}

@Override
Expand All @@ -273,7 +276,7 @@ public Mono<V> rightPopAndLeftPush(K sourceKey, K destinationKey) {
Assert.notNull(destinationKey, "Destination key must not be null");

return createMono(
connection -> connection.rPopLPush(rawKey(sourceKey), rawKey(destinationKey)).map(this::readValue));
connection -> connection.rPopLPush(rawKey(sourceKey), rawKey(destinationKey)).map(this::readRequiredValue));
}

@Override
Expand All @@ -285,7 +288,8 @@ public Mono<V> rightPopAndLeftPush(K sourceKey, K destinationKey, Duration timeo
Assert.isTrue(isZeroOrGreater1Second(timeout), "Duration must be either zero or greater or equal to 1 second");

return createMono(
connection -> connection.bRPopLPush(rawKey(sourceKey), rawKey(destinationKey), timeout).map(this::readValue));
connection -> connection.bRPopLPush(rawKey(sourceKey), rawKey(destinationKey), timeout)
.map(this::readRequiredValue));
}

@Override
Expand Down Expand Up @@ -322,7 +326,19 @@ private ByteBuffer rawValue(V value) {
return serializationContext.getValueSerializationPair().write(value);
}

@Nullable
private V readValue(ByteBuffer buffer) {
return serializationContext.getValueSerializationPair().read(buffer);
}

private V readRequiredValue(ByteBuffer buffer) {

V v = readValue(buffer);

if (v == null) {
throw new InvalidDataAccessApiUsageException("Deserialized list value is null");
}

return v;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
import java.util.function.Function;

import org.reactivestreams.Publisher;
import org.springframework.dao.InvalidDataAccessApiUsageException;
import org.springframework.data.redis.connection.ReactiveSetCommands;
import org.springframework.data.redis.serializer.RedisSerializationContext;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
Expand Down Expand Up @@ -88,15 +90,15 @@ public Mono<V> pop(K key) {

Assert.notNull(key, "Key must not be null");

return createMono(connection -> connection.sPop(rawKey(key)).map(this::readValue));
return createMono(connection -> connection.sPop(rawKey(key)).map(this::readRequiredValue));
}

@Override
public Flux<V> pop(K key, long count) {

Assert.notNull(key, "Key must not be null");

return createFlux(connection -> connection.sPop(rawKey(key), count).map(this::readValue));
return createFlux(connection -> connection.sPop(rawKey(key), count).map(this::readRequiredValue));
}

@Override
Expand Down Expand Up @@ -176,7 +178,7 @@ public Flux<V> intersect(Collection<K> keys) {
.map(this::rawKey) //
.collectList() //
.flatMapMany(connection::sInter) //
.map(this::readValue));
.map(this::readRequiredValue));
}

@Override
Expand Down Expand Up @@ -238,7 +240,7 @@ public Flux<V> union(Collection<K> keys) {
.map(this::rawKey) //
.collectList() //
.flatMapMany(connection::sUnion) //
.map(this::readValue));
.map(this::readRequiredValue));
}

@Override
Expand Down Expand Up @@ -300,7 +302,7 @@ public Flux<V> difference(Collection<K> keys) {
.map(this::rawKey) //
.collectList() //
.flatMapMany(connection::sDiff) //
.map(this::readValue));
.map(this::readRequiredValue));
}

@Override
Expand Down Expand Up @@ -340,7 +342,7 @@ public Flux<V> members(K key) {

Assert.notNull(key, "Key must not be null");

return createFlux(connection -> connection.sMembers(rawKey(key)).map(this::readValue));
return createFlux(connection -> connection.sMembers(rawKey(key)).map(this::readRequiredValue));
}

@Override
Expand All @@ -349,31 +351,31 @@ public Flux<V> scan(K key, ScanOptions options) {
Assert.notNull(key, "Key must not be null");
Assert.notNull(options, "ScanOptions must not be null");

return createFlux(connection -> connection.sScan(rawKey(key), options).map(this::readValue));
return createFlux(connection -> connection.sScan(rawKey(key), options).map(this::readRequiredValue));
}

@Override
public Mono<V> randomMember(K key) {

Assert.notNull(key, "Key must not be null");

return createMono(connection -> connection.sRandMember(rawKey(key)).map(this::readValue));
return createMono(connection -> connection.sRandMember(rawKey(key)).map(this::readRequiredValue));
}

@Override
public Flux<V> distinctRandomMembers(K key, long count) {

Assert.isTrue(count > 0, "Negative count not supported; Use randomMembers to allow duplicate elements");

return createFlux(connection -> connection.sRandMember(rawKey(key), count).map(this::readValue));
return createFlux(connection -> connection.sRandMember(rawKey(key), count).map(this::readRequiredValue));
}

@Override
public Flux<V> randomMembers(K key, long count) {

Assert.isTrue(count > 0, "Use a positive number for count; This method is already allowing duplicate elements");

return createFlux(connection -> connection.sRandMember(rawKey(key), -count).map(this::readValue));
return createFlux(connection -> connection.sRandMember(rawKey(key), -count).map(this::readRequiredValue));
}

@Override
Expand Down Expand Up @@ -416,7 +418,19 @@ private ByteBuffer rawValue(V value) {
return serializationContext.getValueSerializationPair().write(value);
}

@Nullable
private V readValue(ByteBuffer buffer) {
return serializationContext.getValueSerializationPair().read(buffer);
}

private V readRequiredValue(ByteBuffer buffer) {

V v = readValue(buffer);

if (v == null) {
throw new InvalidDataAccessApiUsageException("Deserialized set value is null");
}

return v;
}
}
Loading

0 comments on commit 549f815

Please sign in to comment.