From 4f1c72773ae84edd16a727948fe0972ea7b82bbd Mon Sep 17 00:00:00 2001 From: Andrey G Date: Tue, 18 Jun 2024 18:16:27 +0300 Subject: [PATCH] FMWK-48 Spring Data Aerospike Cacheable sync option (#755) --- .../data/aerospike/cache/AerospikeCache.java | 49 +++++++++++++++---- ...AerospikeCacheManagerIntegrationTests.java | 35 +++++++++++++ 2 files changed, 74 insertions(+), 10 deletions(-) diff --git a/src/main/java/org/springframework/data/aerospike/cache/AerospikeCache.java b/src/main/java/org/springframework/data/aerospike/cache/AerospikeCache.java index a0631b210..99614fa28 100644 --- a/src/main/java/org/springframework/data/aerospike/cache/AerospikeCache.java +++ b/src/main/java/org/springframework/data/aerospike/cache/AerospikeCache.java @@ -112,20 +112,49 @@ public Object getNativeCache() { * @return The value (bins) to which this cache maps the specified key. */ @Override - @SuppressWarnings({"unchecked", "NullableProblems"}) + @SuppressWarnings("NullableProblems") public T get(Object key, Callable valueLoader) { - T value = (T) client.get(null, getKey(key)).getValue(VALUE); - if (Objects.isNull(value)) { - try { - value = valueLoader.call(); - if (Objects.nonNull(value)) { - put(key, value); + if (valueLoader != null) { + Key dbKey = getKey(key); + Record record = client.get(null, dbKey); + if (record == null) { + synchronized (this) { + record = client.get(null, dbKey); + if (record == null) { + T value = callValueLoader(valueLoader, key); + if (Objects.nonNull(value)) { + put(key, value); + } + return value; + } } - } catch (Exception e) { - throw new Cache.ValueRetrievalException(key, valueLoader, e); } + if (record.getValue(VALUE) != null) { + AerospikeReadData data = AerospikeReadData.forRead(dbKey, record); + Class type = getValueType(valueLoader); // determine the class of T + return aerospikeConverter.read(type, data); + } + } + return null; + } + + private T callValueLoader(Callable valueLoader, Object key) { + try { + return valueLoader.call(); + } catch (Exception e) { + throw new Cache.ValueRetrievalException(key, valueLoader, e); + } + } + + // Helper method to determine the class of T + @SuppressWarnings("unchecked") + private static Class getValueType(Callable valueLoader) { + try { + // Use reflection to get the return type of the Callable + return (Class) valueLoader.getClass().getMethod("call").getReturnType(); + } catch (NoSuchMethodException e) { + throw new IllegalArgumentException("Cannot determine value type", e); } - return value; } /** diff --git a/src/test/java/org/springframework/data/aerospike/cache/AerospikeCacheManagerIntegrationTests.java b/src/test/java/org/springframework/data/aerospike/cache/AerospikeCacheManagerIntegrationTests.java index e99c72f30..156af2ed7 100644 --- a/src/test/java/org/springframework/data/aerospike/cache/AerospikeCacheManagerIntegrationTests.java +++ b/src/test/java/org/springframework/data/aerospike/cache/AerospikeCacheManagerIntegrationTests.java @@ -66,6 +66,35 @@ public void shouldCache() { assertThat(cachingComponent.getNoOfCalls()).isEqualTo(1); } + @Test + public void testCacheableMethodSync() throws InterruptedException { + assertThat(cachingComponent.getNoOfCalls() == 0).isTrue(); + + // Creating two threads that will call cacheableMethod concurrently + Thread thread1 = new Thread(() -> { + CachedObject response = cachingComponent.cacheableMethodSynchronized(KEY); + assertThat(response).isNotNull(); + assertThat(response.getValue()).isEqualTo(VALUE); + }); + + Thread thread2 = new Thread(() -> { + CachedObject response = cachingComponent.cacheableMethodSynchronized(KEY); + assertThat(response).isNotNull(); + assertThat(response.getValue()).isEqualTo(VALUE); + }); + + // Starting both threads + thread1.start(); + thread2.start(); + + // Waiting for both threads to complete + thread1.join(); + thread2.join(); + + // Expecting method to be called only once due to synchronization + assertThat(cachingComponent.getNoOfCalls() == 1).isTrue(); + } + @Test public void shouldEvictCache() { CachedObject response1 = cachingComponent.cacheableMethod(KEY); @@ -185,6 +214,12 @@ public CachedObject cacheableMethod(String param) { return new CachedObject("id", VALUE); } + @Cacheable(value = "TEST", sync = true) + public CachedObject cacheableMethodSynchronized(String param) { + noOfCalls++; + return new CachedObject("id", VALUE); + } + @Cacheable(value = "CACHE-WITH-TTL") public CachedObject cacheableMethodWithTTL(String param) { noOfCalls++;