Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement lock-free probabilistic cache reloading #784

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
133 changes: 89 additions & 44 deletions src/main/groovy/io/seqera/wave/store/cache/AbstractTieredCache.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import com.github.benmanes.caffeine.cache.RemovalCause
import com.github.benmanes.caffeine.cache.RemovalListener
import groovy.transform.Canonical
import groovy.transform.CompileStatic
import groovy.transform.Memoized
import groovy.transform.ToString
import groovy.util.logging.Slf4j
import io.seqera.wave.encoder.EncodingStrategy
Expand Down Expand Up @@ -104,17 +105,45 @@ abstract class AbstractTieredCache<V extends MoshiExchange> implements TieredCac
}
}

abstract int getMaxSize()
abstract protected int getMaxSize()

abstract protected getName()

abstract protected String getPrefix()

/**
* The cache probabilistic revalidation interval, i.e. the time window before
* an entry expiration in which the entry may be refreshed ahead of time.
*
* See https://blog.cloudflare.com/sometimes-i-cache/
*
* @return
* The cache revalidation interval as a {@link Duration} value.
* When {@link Duration#ZERO} (the default) probabilistic revalidation is disabled.
*/
protected Duration getCacheRevalidationInterval() {
return Duration.ZERO
}

/**
 * The cache probabilistic revalidation steepness value.
 *
 * By default it is computed as 1 / {@link #getCacheRevalidationInterval()} (expressed in millis).
 * Subclasses can override this method to provide a different value.
 *
 * When the revalidation interval is zero or negative the reciprocal is not defined.
 * In that case {@link Double#POSITIVE_INFINITY} is returned instead of failing with a
 * division-by-zero error, so that {@code Math.exp(-steepness * t)} never yields a
 * positive probability (i.e. probabilistic revalidation stays disabled).
 *
 * See https://blog.cloudflare.com/sometimes-i-cache/
 *
 * @return The revalidation steepness value.
 */
@Memoized
protected double getRevalidationSteepness() {
    final long intervalMillis = getCacheRevalidationInterval().toMillis()
    // guard the default (Duration.ZERO) configuration: 1/0 would throw ArithmeticException
    if( intervalMillis <= 0 )
        return Double.POSITIVE_INFINITY
    // use floating point arithmetic explicitly
    return 1.0d / intervalMillis
}

private RemovalListener removalListener0() {
new RemovalListener() {
@Override
void onRemoval(@Nullable key, @Nullable value, RemovalCause cause) {
if( log.isTraceEnabled( )) {
if( log.isTraceEnabled() ) {
log.trace "Cache '${name}' removing key=$key; value=$value; cause=$cause"
}
}
Expand Down Expand Up @@ -170,47 +199,59 @@ abstract class AbstractTieredCache<V extends MoshiExchange> implements TieredCac

private V getOrCompute0(String key, Function<String, Tuple2<V,Duration>> loader) {
assert key!=null, "Argument key cannot be null"

if( log.isTraceEnabled() )
log.trace "Cache '${name}' checking key=$key"
final ts = Instant.now()
// Try L1 cache first
V value = l1Get(key)
if (value != null) {
Entry entry = l1Get(key)
Boolean needsRevalidation = entry ? shouldRevalidate(entry.expiresAt, ts) : null
if( entry && !needsRevalidation ) {
if( log.isTraceEnabled() )
log.trace "Cache '${name}' L1 hit (a) - key=$key => value=$value"
return value
log.trace "Cache '${name}' L1 hit (a) - key=$key => entry=$entry"
return (V) entry.value
}

final sync = locks.get(key).get()
sync.lock()
try {
value = l1Get(key)
if (value != null) {
// check again L1 cache once in the sync block
if( !entry ) {
entry = l1Get(key)
needsRevalidation = entry ? shouldRevalidate(entry.expiresAt, ts) : null
}
if( entry && !needsRevalidation ) {
if( log.isTraceEnabled() )
log.trace "Cache '${name}' L1 hit (b) - key=$key => value=$value"
return value
log.trace "Cache '${name}' L1 hit (b) - key=$key => entry=$entry"
return (V)entry.value
}

// Fallback to L2 cache
final entry = l2GetEntry(key)
if (entry != null) {
if( !entry ) {
entry = l2Get(key)
needsRevalidation = entry ? shouldRevalidate(entry.expiresAt, ts) : null
}
if( entry && !needsRevalidation ) {
if( log.isTraceEnabled() )
log.trace "Cache '${name}' L2 hit - key=$key => entry=$entry"
log.trace "Cache '${name}' L2 hit (c) - key=$key => entry=$entry"
// Rehydrate L1 cache
l1.put(key, entry)
return (V) entry.value
}

// still not value found, use loader function to fetch the value
if( value==null && loader!=null ) {
if( log.isTraceEnabled() )
// still not entry found or cache revalidation needed
// use the loader function to fetch the value
V value = null
if( loader!=null ) {
if( entry && needsRevalidation )
log.debug "Cache '${name}' invoking loader - entry=$entry needs refresh"
else if( log.isTraceEnabled() )
log.trace "Cache '${name}' invoking loader - key=$key"
final ret = loader.apply(key)
value = ret?.v1
Duration ttl = ret?.v2
if( value!=null && ttl!=null ) {
final exp = Instant.now().plus(ttl).toEpochMilli()
final newEntry = new Entry(value,exp)
final newEntry = new Entry(value, exp)
l1Put(key, newEntry)
l2Put(key, newEntry, ttl)
}
Expand Down Expand Up @@ -240,47 +281,25 @@ abstract class AbstractTieredCache<V extends MoshiExchange> implements TieredCac

/**
 * Compose the fully qualified key by prepending the cache prefix
 * and a colon separator to the given key.
 *
 * @param k The cache key
 * @return The key in the form {@code <prefix>:<k>}
 */
protected String key0(String k) {
    final String prefix = getPrefix()
    return prefix + ':' + k
}

protected V l1Get(String key) {
return (V) l1GetEntry(key)?.value
}

protected Entry l1GetEntry(String key) {
final entry = l1.getIfPresent(key)
if( entry == null )
return null

if( System.currentTimeMillis() > entry.expiresAt ) {
if( log.isTraceEnabled() )
log.trace "Cache '${name}' L1 expired - key=$key => entry=$entry"
return null
}
return entry
protected Entry l1Get(String key) {
return l1.getIfPresent(key)
}

/**
* Store the given entry in the L1 cache under the specified key.
*
* @param key The cache key
* @param entry The entry to be stored
*/
protected void l1Put(String key, Entry entry) {
l1.put(key, entry)
}

protected Entry l2GetEntry(String key) {
protected Entry l2Get(String key) {
if( l2 == null )
return null

final raw = l2.get(key0(key))
if( raw == null )
return null

final Entry entry = encoder.decode(raw)
if( System.currentTimeMillis() > entry.expiresAt ) {
if( log.isTraceEnabled() )
log.trace "Cache '${name}' L2 expired - key=$key => value=${entry}"
return null
}
return entry
return encoder.decode(raw)
}

protected V l2Get(String key) {
return (V) l2GetEntry(key)?.value
}

protected void l2Put(String key, Entry entry, Duration ttl) {
if( l2 != null ) {
Expand All @@ -293,4 +312,30 @@ abstract class AbstractTieredCache<V extends MoshiExchange> implements TieredCac
l1.invalidateAll()
}

/**
 * Determine whether a cache entry should be revalidated (reloaded).
 *
 * @param expiration
 *      The entry expiration timestamp as epoch millis.
 * @param time
 *      The instant against which the expiration is evaluated (default: now).
 * @return
 *      {@code true} when the entry is expired, {@code false} when the remaining time
 *      is longer than {@link #getCacheRevalidationInterval()}, otherwise the outcome of
 *      {@link #randomRevalidate(long)} evaluated for the remaining time.
 */
protected boolean shouldRevalidate(long expiration, Instant time=Instant.now()) {
    // when 'remainingCacheTime' is less than or equal to zero, it means
    // the current time is beyond the expiration time, therefore a cache revalidation is needed
    final remainingCacheTime = expiration - time.toEpochMilli()
    if( remainingCacheTime <= 0 ) {
        return true
    }

    // otherwise, when the remaining time is greater than the cache revalidation interval
    // no revalidation is needed
    final cacheRevalidationMills = cacheRevalidationInterval.toMillis()
    if( cacheRevalidationMills < remainingCacheTime ) {
        return false
    }

    // finally the remaining time is shorter than the revalidation interval,
    // i.e. it's approaching the cache expiration; in this case the need
    // for cache revalidation is determined in a probabilistic manner.
    // The *remaining* time must be passed, so that the probability grows
    // as the entry gets closer to its expiration (and not the other way around)
    // see https://blog.cloudflare.com/sometimes-i-cache/
    return randomRevalidate(remainingCacheTime)
}

/**
 * Randomly decide whether a revalidation should be performed, with a probability
 * of {@code exp(-revalidationSteepness * remainingTime)}.
 *
 * @param remainingTime The time value (millis) the probability is computed for
 * @return {@code true} when the random sample falls below the computed probability
 */
protected boolean randomRevalidate(long remainingTime) {
    final double sample = Math.random()
    final double probability = Math.exp(-revalidationSteepness * remainingTime)
    return sample < probability
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class ClientCache extends AbstractTieredCache {
}

@Override
int getMaxSize() {
protected int getMaxSize() {
return maxSize
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@

package io.seqera.wave.store.cache

import spock.lang.Retry

import java.time.Duration
import java.time.Instant

import com.squareup.moshi.JsonAdapter
import com.squareup.moshi.adapters.PolymorphicJsonAdapterFactory
Expand Down Expand Up @@ -102,7 +105,7 @@ class AbstractTieredCacheTest extends Specification implements RedisTestContaine
cache1.put(k, value, TTL)

then:
def entry1 = cache1.l1GetEntry(k)
def entry1 = cache1.l1Get(k)
and:
entry1.expiresAt > begin
then:
Expand Down Expand Up @@ -223,4 +226,83 @@ class AbstractTieredCacheTest extends Specification implements RedisTestContaine
cache.get(k2) == null
}

// Verifies the three branches of 'shouldRevalidate': expired entry, entry outside
// the revalidation interval, and entry within the interval (delegated to 'randomRevalidate')
def 'should validate revalidation logic' () {
given:
def REVALIDATION_INTERVAL_SECS = 10
def now = Instant.now()
def cache = Spy(MyCache)
cache.getCacheRevalidationInterval() >> Duration.ofSeconds(REVALIDATION_INTERVAL_SECS)

when:
// when the expiration is in the past, then 'revalidate' should be true
def expiration = now.minusSeconds(1)
def revalidate = cache.shouldRevalidate(expiration.toEpochMilli(), now)
then:
0 * cache.randomRevalidate(_) >> null
and:
revalidate

when:
// when the remaining time is longer than the revalidation interval, then 'revalidate' is false
expiration = now.plusSeconds(REVALIDATION_INTERVAL_SECS +1)
revalidate = cache.shouldRevalidate(expiration.toEpochMilli(), now)
then:
0 * cache.randomRevalidate(_) >> null
and:
!revalidate

when:
// when the remaining time is equal to the revalidation interval, then 'revalidate' is computed randomly
expiration = now.plusSeconds(REVALIDATION_INTERVAL_SECS)
revalidate = cache.shouldRevalidate(expiration.toEpochMilli(), now)
then:
1 * cache.randomRevalidate(_) >> true
and:
revalidate

when:
// when the remaining time is less than the revalidation interval, then 'revalidate' is computed randomly
expiration = now.plusSeconds(REVALIDATION_INTERVAL_SECS -1)
revalidate = cache.shouldRevalidate(expiration.toEpochMilli(), now)
then:
1 * cache.randomRevalidate(_) >> false
and:
!revalidate
}

// Checks 'randomRevalidate' at the boundary value 0, where the
// exponent is zero and the computed probability is exp(0) = 1
def 'should validate random function' () {
given:
def now = Instant.now()
def cache = Spy(MyCache)
cache.getCacheRevalidationInterval() >> Duration.ofSeconds(10)
expect:
cache.randomRevalidate(0)
}

// The outcome is random, hence the test is retried up to 5 times on failure
@Retry(count = 5)
def 'should validate random revalidate with interval 10s' () {
given:
def now = Instant.now()
def cache = Spy(MyCache)
cache.getCacheRevalidationInterval() >> Duration.ofSeconds(10)
expect:
// when the remaining time is approaching 0
// the function is expected to return true with high probability
cache.randomRevalidate(10) // 10 millis
cache.randomRevalidate(100) // 100 millis
}

// The outcome is random, hence the test is retried up to 5 times on failure
@Retry(count = 5)
def 'should validate random revalidate with interval 300s' () {
given:
def now = Instant.now()
def cache = Spy(MyCache)
cache.getCacheRevalidationInterval() >> Duration.ofSeconds(300)
expect:
// when the remaining time is approaching 0
// the function is expected to return true with high probability
cache.randomRevalidate(10) // 10 millis
cache.randomRevalidate(100) // 100 millis
cache.randomRevalidate(500) // 500 millis
}
}
Loading