diff --git a/build/go/integration_test.sh b/build/go/integration_test.sh index 5ad51afc1..d4f61898a 100755 --- a/build/go/integration_test.sh +++ b/build/go/integration_test.sh @@ -4,4 +4,4 @@ # Note the use of `-p 1` is required to prevent multiple test packages from running in # parallel (default), ensuring access to any shared resource (e.g., dynamodb-local) # is serialized. -gotestsum -f testname -- -p 1 -race -coverprofile coverage.out -v ./integrationtest/... +gotestsum -f testname -- -p 1 -race -coverprofile coverage.out -v `go list ./integrationtest/... | grep -v traces` diff --git a/go/appencryption/.gitignore b/go/appencryption/.gitignore index 730925ca2..d1a9e135c 100644 --- a/go/appencryption/.gitignore +++ b/go/appencryption/.gitignore @@ -12,3 +12,10 @@ cmd/example/log.log hack .DS_Store **/checkstyle-result.xml + +integrationtest/traces/data/*.dat +integrationtest/traces/data/*.bz2 +integrationtest/traces/data/*.gz +integrationtest/traces/data/*.xz +integrationtest/traces/data/*.tgz +integrationtest/traces/out/ diff --git a/go/appencryption/.versionfile b/go/appencryption/.versionfile index 9e11b32fc..1d0ba9ea1 100644 --- a/go/appencryption/.versionfile +++ b/go/appencryption/.versionfile @@ -1 +1 @@ -0.3.1 +0.4.0 diff --git a/go/appencryption/cache.go b/go/appencryption/cache.go deleted file mode 100644 index c2a698d25..000000000 --- a/go/appencryption/cache.go +++ /dev/null @@ -1,282 +0,0 @@ -package appencryption - -import ( - "fmt" - "sync" - "time" - - "github.com/godaddy/asherah/go/appencryption/internal" - "github.com/godaddy/asherah/go/appencryption/pkg/log" -) - -// cacheEntry contains a key and the time it was loaded from the metastore. -type cacheEntry struct { - loadedAt time.Time - key *internal.CryptoKey -} - -// newCacheEntry returns a cacheEntry with the current time and key. -func newCacheEntry(k *internal.CryptoKey) cacheEntry { - return cacheEntry{ - loadedAt: time.Now(), - key: k, - } -} - -// cacheKey formats an id and create timestamp to a usable -// key for storage in a cache. -func cacheKey(id string, create int64) string { - return fmt.Sprintf("%s-%d", id, create) -} - -// keyLoaderFunc is an adapter to allow the use of ordinary functions as key loaders. -// If f is a function with the appropriate signature, keyLoaderFunc(f) is a keyLoader -// that calls f. -type keyLoaderFunc func() (*internal.CryptoKey, error) - -// Load calls f(). -func (f keyLoaderFunc) Load() (*internal.CryptoKey, error) { - return f() -} - -// keyLoader is used by cache objects to retrieve keys on an as-needed basis. -type keyLoader interface { - Load() (*internal.CryptoKey, error) -} - -// keyReloader extends keyLoader by adding the ability to inspect loaded keys -// and reload them when needed -type keyReloader interface { - keyLoader - - // IsInvalid returns true if the provided key is no longer valid - IsInvalid(*internal.CryptoKey) bool -} - -// cache contains cached keys for reuse. -type cache interface { - GetOrLoad(id KeyMeta, loader keyLoader) (*internal.CryptoKey, error) - GetOrLoadLatest(id string, loader keyLoader) (*internal.CryptoKey, error) - Close() error -} - -// Verify keyCache implements the cache interface. -var _ cache = (*keyCache)(nil) - -// keyCache is used to persist session based keys and destroys them on a call to close. -type keyCache struct { - once sync.Once - rw sync.RWMutex - policy *CryptoPolicy - keys map[string]cacheEntry -} - -// newKeyCache constructs a cache object that is ready to use. 
-func newKeyCache(policy *CryptoPolicy) *keyCache { - keys := make(map[string]cacheEntry) - - return &keyCache{ - policy: policy, - keys: keys, - } -} - -// isReloadRequired returns true if the check interval has elapsed -// since the timestamp provided. -func isReloadRequired(entry cacheEntry, checkInterval time.Duration) bool { - if entry.key.Revoked() { - // this key is revoked so no need to reload it again. - return false - } - - return entry.loadedAt.Add(checkInterval).Before(time.Now()) -} - -// GetOrLoad returns a key from the cache if it's already been loaded. If the key -// is not present in the cache it will retrieve the key using the provided keyLoader -// and store the key if an error is not returned. -func (c *keyCache) GetOrLoad(id KeyMeta, loader keyLoader) (*internal.CryptoKey, error) { - // get with "light" lock - c.rw.RLock() - k, ok := c.get(id) - c.rw.RUnlock() - - if ok { - return k, nil - } - - // load with heavy lock - c.rw.Lock() - defer c.rw.Unlock() - // exit early if the key doesn't need to be reloaded just in case it has been loaded by rw lock in front of us - if k, ok := c.get(id); ok { - return k, nil - } - - return c.load(id, loader) -} - -// get returns a key from the cache if present AND fresh. -// A cached value is considered stale if its time in cache -// has exceeded the RevokeCheckInterval. -// The second return value indicates the successful retrieval of a -// fresh key. -func (c *keyCache) get(id KeyMeta) (*internal.CryptoKey, bool) { - key := cacheKey(id.ID, id.Created) - - if e, ok := c.read(key); ok && !isReloadRequired(e, c.policy.RevokeCheckInterval) { - return e.key, true - } - - return nil, false -} - -// load returns a key from the cache if it's already been loaded. If the key is -// not present in the cache, or the cached entry needs to be reloaded, it will -// retrieve the key using the provided keyLoader and cache the key for future use. -// load maintains the latest entry for each distinct ID which can be accessed using -// id.Created == 0. -func (c *keyCache) load(id KeyMeta, loader keyLoader) (*internal.CryptoKey, error) { - key := cacheKey(id.ID, id.Created) - - k, err := loader.Load() - if err != nil { - return nil, err - } - - e, ok := c.read(key) - if ok && e.key.Created() == k.Created() { - // existing key in cache. update revoked status and last loaded time and close key - // we just loaded since we don't need it - e.key.SetRevoked(k.Revoked()) - e.loadedAt = time.Now() - c.write(key, e) - - k.Close() - } else { - // first time loading this key into cache or we have an ID-only key with mismatched - // create timestamps - e = newCacheEntry(k) - c.write(key, e) - } - - latestKey := cacheKey(id.ID, 0) - if key == latestKey { - // we've loaded a key using ID-only, ensure we've got a cache entry with a fully - // qualified cache key - c.write(cacheKey(id.ID, k.Created()), e) - } else if latest, ok := c.read(latestKey); !ok || latest.key.Created() < k.Created() { - // we've loaded a key using a fully qualified cache key and the ID-only entry is - // either missing or stale - c.write(latestKey, e) - } - - return e.key, nil -} - -// read retrieves the entry from the cache matching the provided ID if present. The second -// return value indicates whether or not the key was present in the cache. -func (c *keyCache) read(id string) (cacheEntry, bool) { - e, ok := c.keys[id] - - if !ok { - log.Debugf("%s miss -- id: %s\n", c, id) - } - - return e, ok -} - -// write entry e to the cache using id as the key. 
-func (c *keyCache) write(id string, e cacheEntry) { - if existing, ok := c.keys[id]; ok { - log.Debugf("%s update -> old: %s, new: %s, id: %s\n", c, existing.key, e.key, id) - } - - log.Debugf("%s write -> key: %s, id: %s\n", c, e.key, id) - c.keys[id] = e -} - -// GetOrLoadLatest returns the latest key from the cache matching the provided ID -// if it's already been loaded. If the key is not present in the cache it will -// retrieve the key using the provided KeyLoader and store the key if an error is not returned. -// If the provided loader implements the optional keyReloader interface then retrieved keys -// will be inspected for validity and reloaded if necessary. -func (c *keyCache) GetOrLoadLatest(id string, loader keyLoader) (*internal.CryptoKey, error) { - c.rw.Lock() - defer c.rw.Unlock() - - meta := KeyMeta{ID: id} - - key, ok := c.get(meta) - if !ok { - log.Debugf("%s.GetOrLoadLatest get miss -- id: %s\n", c, id) - - var err error - key, err = c.load(meta, loader) - - if err != nil { - return nil, err - } - } - - if reloader, ok := loader.(keyReloader); ok && reloader.IsInvalid(key) { - reloaded, ok := loader.Load() - log.Debugf("%s.GetOrLoadLatest reload -- invalid: %s, new: %s, id: %s\n", c, key, reloaded, id) - - e := newCacheEntry(reloaded) - - // update latest - latest := cacheKey(id, 0) - c.write(latest, e) - - // ensure we've got a cache entry with a fully qualified cache key - c.write(cacheKey(id, reloaded.Created()), e) - - return reloaded, ok - } - - return key, nil -} - -// Close frees all memory locked by the keys in this cache. -// It MUST be called after a session is complete to avoid -// running into MEMLOCK limits. -func (c *keyCache) Close() error { - c.once.Do(c.close) - - return nil -} - -func (c *keyCache) close() { - c.rw.Lock() - defer c.rw.Unlock() - - for k := range c.keys { - c.keys[k].key.Close() - } -} - -func (c *keyCache) String() string { - return fmt.Sprintf("keyCache(%p)", c) -} - -// Verify neverCache implements the cache interface. -var _ cache = (*neverCache)(nil) - -type neverCache struct { -} - -// GetOrLoad always executes the provided function to load the value. It never actually caches. -func (neverCache) GetOrLoad(id KeyMeta, loader keyLoader) (*internal.CryptoKey, error) { - return loader.Load() -} - -// GetOrLoadLatest always executes the provided function to load the latest value. It never actually caches. 
-func (neverCache) GetOrLoadLatest(id string, loader keyLoader) (*internal.CryptoKey, error) { - return loader.Load() -} - -// Close is a no-op function to satisfy the cache interface -func (neverCache) Close() error { - return nil -} diff --git a/go/appencryption/cache_benchmark_test.go b/go/appencryption/cache_benchmark_test.go deleted file mode 100644 index 7c4ff727d..000000000 --- a/go/appencryption/cache_benchmark_test.go +++ /dev/null @@ -1,334 +0,0 @@ -package appencryption - -import ( - "fmt" - "sync/atomic" - "testing" - "time" - - "github.com/godaddy/asherah/go/securememory/memguard" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - - "github.com/godaddy/asherah/go/appencryption/internal" -) - -var ( - secretFactory = new(memguard.SecretFactory) - created = time.Now().Unix() -) - -func BenchmarkKeyCache_GetOrLoad_MultipleThreadsReadExistingKey(b *testing.B) { - c := newKeyCache(NewCryptoPolicy()) - - c.keys[cacheKey(testKey, created)] = cacheEntry{ - key: internal.NewCryptoKeyForTest(created, false), - loadedAt: time.Now(), - } - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - key, err := c.GetOrLoad(KeyMeta{testKey, created}, keyLoaderFunc(func() (key *internal.CryptoKey, e error) { - // The passed function is irrelevant because we'll always find the value in the cache - return nil, nil - })) - - assert.NoError(b, err) - assert.Equal(b, created, key.Created()) - } - }) -} - -func BenchmarkKeyCache_GetOrLoad_MultipleThreadsWriteSameKey(b *testing.B) { - c := newKeyCache(NewCryptoPolicy()) - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - _, err := c.GetOrLoad(KeyMeta{testKey, created}, keyLoaderFunc(func() (key *internal.CryptoKey, e error) { - // Add a delay to simulate time spent in performing a metastore read - time.Sleep(5 * time.Millisecond) - return internal.NewCryptoKeyForTest(created, false), nil - })) - - assert.NoError(b, err) - assert.Equal(b, created, c.keys[cacheKey(testKey, 0)].key.Created()) - } - }) -} - -func BenchmarkKeyCache_GetOrLoad_MultipleThreadsWriteUniqueKeys(b *testing.B) { - var ( - c = newKeyCache(NewCryptoPolicy()) - i int64 - ) - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - curr := atomic.AddInt64(&i, 1) - _, err := c.GetOrLoad(KeyMeta{cacheKey(testKey, curr), created}, keyLoaderFunc(func() (key *internal.CryptoKey, e error) { - // Add a delay to simulate time spent in performing a metastore read - time.Sleep(5 * time.Millisecond) - return internal.NewCryptoKeyForTest(created, false), nil - })) - assert.NoError(b, err) - - // ensure we have a "latest" entry for this key as well - latest, err := c.GetOrLoadLatest(cacheKey(testKey, curr), keyLoaderFunc(func() (*internal.CryptoKey, error) { - return nil, errors.New("loader should not be executed") - })) - assert.NoError(b, err) - assert.NotNil(b, latest) - } - }) - assert.NotNil(b, c.keys) - assert.Equal(b, i*2, int64(len(c.keys))) -} - -func BenchmarkKeyCache_GetOrLoad_MultipleThreadsReadRevokedKey(b *testing.B) { - var ( - c = newKeyCache(NewCryptoPolicy()) - created = time.Now().Add(-(time.Minute * 100)).Unix() - ) - - key, err := internal.NewCryptoKey(secretFactory, created, false, []byte("testing")) - - assert.NoError(b, err) - - cacheEntry := cacheEntry{ - key: key, - loadedAt: time.Unix(created, 0), - } - - defer c.Close() - c.keys[cacheKey(testKey, created)] = cacheEntry - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - _, err := c.GetOrLoad(KeyMeta{testKey, 
created}, keyLoaderFunc(func() (key *internal.CryptoKey, e error) { - // Add a delay to simulate time spent in performing a metastore read - time.Sleep(5 * time.Millisecond) - key, err2 := internal.NewCryptoKey(secretFactory, created, true, []byte("testing")) - if err2 != nil { - return nil, err2 - } - - return key, nil - })) - - assert.NoError(b, err) - assert.Equal(b, created, c.keys[cacheKey(testKey, 0)].key.Created()) - assert.True(b, c.keys[cacheKey(testKey, 0)].key.Revoked()) - assert.True(b, c.keys[cacheKey(testKey, created)].key.Revoked()) - } - }) -} - -func BenchmarkKeyCache_GetOrLoad_MultipleThreadsRead_NeedReloadKey(b *testing.B) { - var ( - c = newKeyCache(NewCryptoPolicy()) - created = time.Now().Add(-(time.Minute * 100)).Unix() - ) - - key, err := internal.NewCryptoKey(secretFactory, created, false, []byte("testing")) - - assert.NoError(b, err) - - cacheEntry := cacheEntry{ - key: key, - loadedAt: time.Unix(created, 0), - } - - defer c.Close() - c.keys[cacheKey(testKey, created)] = cacheEntry - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - k, err := c.GetOrLoad(KeyMeta{testKey, created}, keyLoaderFunc(func() (*internal.CryptoKey, error) { - // Note: this function should only happen on first load (although could execute more than once currently), if it doesn't, then something is broken - - // Add a delay to simulate time spent in performing a metastore read - time.Sleep(5 * time.Millisecond) - - return internal.NewCryptoKey(secretFactory, created, false, []byte("testing")) - })) - - if err != nil { - b.Error(err) - } - if created != k.Created() { - b.Error("created mismatch") - } - } - }) -} - -func BenchmarkKeyCache_GetOrLoad_MultipleThreadsReadUniqueKeys(b *testing.B) { - var ( - c = newKeyCache(NewCryptoPolicy()) - i int64 - ) - - for ; i < int64(b.N); i++ { - c.keys[cacheKey(fmt.Sprintf(testKey+"-%d", i), created)] = cacheEntry{ - key: internal.NewCryptoKeyForTest(created, false), - loadedAt: time.Now(), - } - } - - i = 0 - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - curr := atomic.LoadInt64(&i) - key, err := c.GetOrLoad(KeyMeta{cacheKey(testKey, curr), created}, keyLoaderFunc(func() (key *internal.CryptoKey, e error) { - // The passed function is irrelevant because we'll always find the value in the cache - return nil, nil - })) - assert.NoError(b, err) - assert.Equal(b, created, key.Created()) - - atomic.AddInt64(&i, 1) - } - }) -} - -func BenchmarkKeyCache_GetOrLoadLatest_MultipleThreadsReadExistingKey(b *testing.B) { - c := newKeyCache(NewCryptoPolicy()) - - c.keys[cacheKey(testKey, 0)] = cacheEntry{ - key: internal.NewCryptoKeyForTest(created, false), - loadedAt: time.Now(), - } - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - key, err := c.GetOrLoadLatest(testKey, keyLoaderFunc(func() (*internal.CryptoKey, error) { - // The passed function is irrelevant because we'll always find the value in the cache - return nil, nil - })) - assert.NoError(b, err) - assert.Equal(b, created, key.Created()) - } - }) -} - -func BenchmarkKeyCache_GetOrLoadLatest_MultipleThreadsWriteSameKey(b *testing.B) { - c := newKeyCache(NewCryptoPolicy()) - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - _, err := c.GetOrLoadLatest(testKey, keyLoaderFunc(func() (*internal.CryptoKey, error) { - // Add a delay to simulate time spent in performing a metastore read - time.Sleep(5 * time.Millisecond) - return internal.NewCryptoKeyForTest(created, false), nil - })) - assert.NoError(b, 
err) - assert.Equal(b, created, c.keys[cacheKey(testKey, 0)].key.Created()) - } - }) -} - -func BenchmarkKeyCache_GetOrLoadLatest_MultipleThreadsWriteUniqueKey(b *testing.B) { - var ( - c = newKeyCache(NewCryptoPolicy()) - i int64 - ) - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - curr := atomic.AddInt64(&i, 1) - _, err := c.GetOrLoadLatest(cacheKey(testKey, curr), keyLoaderFunc(func() (*internal.CryptoKey, error) { - // Add a delay to simulate time spent in performing a metastore read - time.Sleep(5 * time.Millisecond) - - return internal.NewCryptoKeyForTest(created, false), nil - })) - assert.NoError(b, err) - - // ensure we actually have a "latest" entry for this key in the cache - latest, err := c.GetOrLoadLatest(cacheKey(testKey, curr), keyLoaderFunc(func() (*internal.CryptoKey, error) { - return nil, errors.New("loader should not be executed") - })) - assert.NoError(b, err) - assert.NotNil(b, latest) - } - }) - assert.NotNil(b, c.keys) - assert.Equal(b, i*2, int64(len(c.keys))) -} - -func BenchmarkKeyCache_GetOrLoadLatest_MultipleThreadsReadRevokedKey(b *testing.B) { - var ( - c = newKeyCache(NewCryptoPolicy()) - created = time.Now().Add(-(time.Minute * 100)).Unix() - ) - - key, err := internal.NewCryptoKey(secretFactory, created, false, []byte("testing")) - cacheEntry := cacheEntry{ - key: key, - loadedAt: time.Unix(created, 0), - } - - assert.NoError(b, err) - - defer c.Close() - - c.keys[cacheKey(testKey, 0)] = cacheEntry - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - _, err := c.GetOrLoadLatest(testKey, keyLoaderFunc(func() (key *internal.CryptoKey, e error) { - // Add a delay to simulate time spent in performing a metastore read - time.Sleep(5 * time.Millisecond) - - return internal.NewCryptoKey(secretFactory, created, true, []byte("testing")) - })) - - assert.NoError(b, err) - assert.Equal(b, created, c.keys[cacheKey(testKey, 0)].key.Created()) - assert.True(b, c.keys[cacheKey(testKey, 0)].key.Revoked()) - } - }) -} - -func BenchmarkKeyCache_GetOrLoadLatest_MultipleThreadsReadUniqueKeys(b *testing.B) { - var ( - c = newKeyCache(NewCryptoPolicy()) - i int64 - ) - - for ; i < int64(b.N); i++ { - c.keys[cacheKey(fmt.Sprintf(testKey+"-%d", i), 0)] = cacheEntry{ - key: internal.NewCryptoKeyForTest(created, false), - loadedAt: time.Now(), - } - } - - i = 0 - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - curr := atomic.LoadInt64(&i) - key, err := c.GetOrLoadLatest(fmt.Sprintf(testKey+"-%d", curr), keyLoaderFunc(func() (key *internal.CryptoKey, e error) { - // The passed function is irrelevant because we'll always find the value in the cache - return nil, nil - })) - - assert.NoError(b, err) - assert.Equal(b, created, key.Created()) - - atomic.AddInt64(&i, 1) - } - }) -} diff --git a/go/appencryption/cmd/example/main.go b/go/appencryption/cmd/example/main.go index 075dc99d5..026f3eeec 100644 --- a/go/appencryption/cmd/example/main.go +++ b/go/appencryption/cmd/example/main.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "log" + "math/rand" "net/http" _ "net/http/pprof" "os" @@ -63,6 +64,13 @@ type Options struct { CheckInterval time.Duration `long:"check" description:"Interval to check for expired keys"` ConnectionString string `short:"C" long:"conn" description:"MySQL Connection String"` NoExit bool `short:"x" long:"no-exit" description:"Prevent app from closing once tests are completed. 
Especially useful for profiling."` + + SessionCacheSize int `long:"session-cache-size" description:"Number of sessions to cache in the shared session cache."` + SessionCacheExpiry time.Duration `long:"session-cache-expiry" description:"Duration after which a session is evicted from the shared session cache."` + EnableSharedIKCache bool `long:"enable-shared-ik-cache" description:"Enables the shared IK cache."` + IKCacheSize int `long:"ik-cache-size" description:"Number of IKs to cache in the IK cache."` + SKCacheSize int `long:"sk-cache-size" description:"Number of SKs to cache in the SK cache."` + RandomizePartition bool `long:"randomize-partition" description:"Randomize the partition ID for each session using a Zipfian distribution."` } var ( @@ -232,17 +240,37 @@ func main() { return func(*appencryption.CryptoPolicy) { /* noop */ } } + policy := appencryption.NewCryptoPolicy( + appencryption.WithExpireAfterDuration(expireAfter), + appencryption.WithRevokeCheckInterval(checkInterval), + withCacheOption(), + withSessionCacheOption(), + ) + + if opts.SessionCacheSize > 0 { + policy.SessionCacheMaxSize = opts.SessionCacheSize + } + + if opts.SessionCacheExpiry > 0 { + policy.SessionCacheDuration = opts.SessionCacheExpiry + } + + if opts.IKCacheSize > 0 { + policy.IntermediateKeyCacheMaxSize = opts.IKCacheSize + } + + if opts.SKCacheSize > 0 { + policy.SystemKeyCacheMaxSize = opts.SKCacheSize + } + + policy.SharedIntermediateKeyCache = opts.EnableSharedIKCache + keyManager := CreateKMS() conf := &appencryption.Config{ Service: "exampleService", Product: "productId", - Policy: appencryption.NewCryptoPolicy( - appencryption.WithExpireAfterDuration(expireAfter), - appencryption.WithRevokeCheckInterval(checkInterval), - withCacheOption(), - withSessionCacheOption(), - ), + Policy: policy, } secrets := new(memguard.SecretFactory) @@ -293,9 +321,26 @@ func main() { }() } + var partitioner func() int + if opts.RandomizePartition { + r := rand.New(rand.NewSource(1)) + zipf := rand.NewZipf(r, 1.01, 1.0, 1<<16-1) + + partitioner = func() int { + return int(zipf.Uint64()) + } + } + for i := 0; i < opts.Iterations; i++ { - log.Println("Run iteration:", i) - RunSessionIteration(time.Now(), factory) + if opts.Verbose { + log.Printf( + "[run iteration %d] secrets: allocs=%d, inuse=%d\n", + i, + securememory.AllocCounter.Count(), + securememory.InUseCounter.Count()) + } + + RunSessionIteration(time.Now(), factory, partitioner) } done <- true @@ -364,14 +409,30 @@ func main() { } } -func RunSessionIteration(start time.Time, factory *appencryption.SessionFactory) { +func RunSessionIteration(start time.Time, factory *appencryption.SessionFactory, partitioner func() int) { var wg sync.WaitGroup for i := 0; i < opts.Sessions; i++ { wg.Add(1) + partitionID := i + + if partitioner != nil { + partitionID = partitioner() + } + go func(i int) { - defer wg.Done() + defer func() { + if r := recover(); r != nil { + log.Printf( + "[panic] secrets: allocs=%d, inuse=%d\n", + securememory.AllocCounter.Count(), + securememory.InUseCounter.Count()) + panic(r) + } + + wg.Done() + }() runFunc := func(shopper string) { session, err := factory.GetSession(shopper) @@ -407,7 +468,7 @@ func RunSessionIteration(start time.Time, factory *appencryption.SessionFactory) runFunc(shopper) } - }(i) + }(partitionID) } wg.Wait() diff --git a/go/appencryption/envelope.go b/go/appencryption/envelope.go index 8c859d952..8b01a4569 100644 --- a/go/appencryption/envelope.go +++ b/go/appencryption/envelope.go @@ -3,7 +3,6 @@ package appencryption 
import ( "context" "fmt" - "sync" "time" "github.com/godaddy/asherah/go/securememory" @@ -11,6 +10,7 @@ import ( "github.com/rcrowley/go-metrics" "github.com/godaddy/asherah/go/appencryption/internal" + "github.com/godaddy/asherah/go/appencryption/pkg/log" ) // MetricsPrefix prefixes all metrics names @@ -33,6 +33,19 @@ func (m KeyMeta) String() string { return fmt.Sprintf("KeyMeta [keyId=%s created=%d]", m.ID, m.Created) } +// IsLatest returns true if the key meta is the latest version of the key. +func (m KeyMeta) IsLatest() bool { + return m.Created == 0 +} + +// AsLatest returns a copy of the key meta with the Created timestamp set to 0. +func (m KeyMeta) AsLatest() KeyMeta { + return KeyMeta{ + ID: m.ID, + Created: 0, + } +} + // DataRowRecord contains the encrypted key and provided data, as well as the information // required to decrypt the key encryption key. This struct should be stored in your // data persistence as it's required to decrypt data. @@ -57,14 +70,14 @@ var _ Encryption = (*envelopeEncryption)(nil) // envelopeEncryption is used to encrypt and decrypt data related to a specific partition ID. type envelopeEncryption struct { - partition partition - Metastore Metastore - KMS KeyManagementService - Policy *CryptoPolicy - Crypto AEAD - SecretFactory securememory.SecretFactory - systemKeys cache - intermediateKeys cache + partition partition + Metastore Metastore + KMS KeyManagementService + Policy *CryptoPolicy + Crypto AEAD + SecretFactory securememory.SecretFactory + skCache keyCacher + ikCache keyCacher } // loadSystemKey fetches a known system key from the metastore and decrypts it using the key management service. @@ -91,12 +104,17 @@ func (e *envelopeEncryption) systemKeyFromEKR(ctx context.Context, ekr *Envelope return internal.NewCryptoKey(e.SecretFactory, ekr.Created, ekr.Revoked, bytes) } +type accessorRevokable interface { + internal.Revokable + internal.BytesFuncAccessor +} + // intermediateKeyFromEKR decrypts ekr using sk and returns a new CryptoKey containing the decrypted key data. -func (e *envelopeEncryption) intermediateKeyFromEKR(sk *internal.CryptoKey, ekr *EnvelopeKeyRecord) (*internal.CryptoKey, error) { +func (e *envelopeEncryption) intermediateKeyFromEKR(sk accessorRevokable, ekr *EnvelopeKeyRecord) (*internal.CryptoKey, error) { if ekr != nil && ekr.ParentKeyMeta != nil && sk.Created() != ekr.ParentKeyMeta.Created { - //In this case, the system key just rotated and this EKR was encrypted with the prior SK. - //A duplicate IK would have been attempted to create with the correct SK but would create a duplicate so is discarded. - //Lookup the correct system key so the ik decryption can succeed. + // In this case, the system key just rotated and this EKR was encrypted with the prior SK. + // A duplicate IK would have been attempted to create with the correct SK but would create a duplicate so is discarded. + // Lookup the correct system key so the ik decryption can succeed. skLoaded, err := e.getOrLoadSystemKey(context.Background(), *ekr.ParentKeyMeta) if err != nil { return nil, err @@ -177,105 +195,10 @@ func (e *envelopeEncryption) tryStoreSystemKey(ctx context.Context, sk *internal return e.tryStore(ctx, ekr), nil } -var _ keyReloader = (*reloader)(nil) - -type reloader struct { - loadedKeys []*internal.CryptoKey - mu sync.Mutex - loader keyLoader - isInvalidFunc func(key *internal.CryptoKey) bool - keyID string - isCached bool -} - -// Load implements keyLoader. 
-func (r *reloader) Load() (*internal.CryptoKey, error) { - k, err := r.loader.Load() - if err != nil { - return nil, err - } - - r.append(k) - - return k, nil -} - -// append a key to the list of loaded keys. A call to -// Close will close all appended keys. -func (r *reloader) append(key *internal.CryptoKey) { - r.mu.Lock() - r.loadedKeys = append(r.loadedKeys, key) - r.mu.Unlock() -} - -// IsInvalid implements keyReloader -func (r *reloader) IsInvalid(key *internal.CryptoKey) bool { - return r.isInvalidFunc(key) -} - -// Close calls maybeCloseKey for all keys previously loaded by a reloader instance. -func (r *reloader) Close() { - r.mu.Lock() - defer r.mu.Unlock() - - for k := range r.loadedKeys { - key := r.loadedKeys[k] - - maybeCloseKey(r.isCached, key) - } -} - -// GetOrLoadLatest wraps the GetOrLoadLatest of c using r as the loader. -func (r *reloader) GetOrLoadLatest(c cache) (*internal.CryptoKey, error) { - return c.GetOrLoadLatest(r.keyID, r) -} - -// newIntermediateKeyReloader returns a new reloader for intermediate keys. -func (e *envelopeEncryption) newIntermediateKeyReloader(ctx context.Context) *reloader { - return e.newKeyReloader( - ctx, - e.partition.IntermediateKeyID(), - e.Policy.CacheIntermediateKeys, - e.loadLatestOrCreateIntermediateKey, - ) -} - -// newSystemKeyReloader returns a new reloader for system keys. -func (e *envelopeEncryption) newSystemKeyReloader(ctx context.Context) *reloader { - return e.newKeyReloader( - ctx, - e.partition.SystemKeyID(), - e.Policy.CacheSystemKeys, - e.loadLatestOrCreateSystemKey, - ) -} - -// newKeyReloader returns a new reloader. -func (e *envelopeEncryption) newKeyReloader( - ctx context.Context, - id string, - isCached bool, - loader func(context.Context, string) (*internal.CryptoKey, error), -) *reloader { - return &reloader{ - keyID: id, - isCached: isCached, - loader: keyLoaderFunc(func() (*internal.CryptoKey, error) { - return loader(ctx, id) - }), - isInvalidFunc: e.isKeyInvalid, - } -} - -// isKeyInvalid checks if the key is revoked or expired. -func (e *envelopeEncryption) isKeyInvalid(key *internal.CryptoKey) bool { - return key.Revoked() || isKeyExpired(key.Created(), e.Policy.ExpireKeyAfter) -} - // isEnvelopeInvalid checks if the envelope key record is revoked or has an expired key. func (e *envelopeEncryption) isEnvelopeInvalid(ekr *EnvelopeKeyRecord) bool { // TODO Add key rotation policy check. If not inline, then can return valid even if expired - return e == nil || isKeyExpired(ekr.Created, e.Policy.ExpireKeyAfter) || ekr.Revoked + return e == nil || internal.IsKeyExpired(ekr.Created, e.Policy.ExpireKeyAfter) || ekr.Revoked } func (e *envelopeEncryption) generateKey() (*internal.CryptoKey, error) { @@ -318,15 +241,16 @@ func (e *envelopeEncryption) mustLoadLatest(ctx context.Context, id string) (*En // createIntermediateKey creates a new IK and attempts to persist the new key to the metastore. // If unsuccessful createIntermediateKey will attempt to fetch the latest IK from the metastore. func (e *envelopeEncryption) createIntermediateKey(ctx context.Context) (*internal.CryptoKey, error) { - r := e.newSystemKeyReloader(ctx) - defer r.Close() - // Try to get latest from cache. 
- sk, err := r.GetOrLoadLatest(e.systemKeys) + sk, err := e.skCache.GetOrLoadLatest(e.partition.SystemKeyID(), func(meta KeyMeta) (*internal.CryptoKey, error) { + return e.loadLatestOrCreateSystemKey(ctx, meta.ID) + }) if err != nil { return nil, err } + defer sk.Close() + ik, err := e.generateKey() if err != nil { return nil, err @@ -358,7 +282,7 @@ func (e *envelopeEncryption) createIntermediateKey(ctx context.Context) (*intern // tryStoreIntermediateKey attempts to persist the encrypted ik to the metastore ignoring all persistence related errors. // err will be non-nil only if encryption fails. -func (e *envelopeEncryption) tryStoreIntermediateKey(ctx context.Context, ik, sk *internal.CryptoKey) (success bool, err error) { +func (e *envelopeEncryption) tryStoreIntermediateKey(ctx context.Context, ik, sk accessorRevokable) (success bool, err error) { encBytes, err := internal.WithKeyFunc(ik, func(keyBytes []byte) ([]byte, error) { return internal.WithKeyFunc(sk, func(systemKeyBytes []byte) ([]byte, error) { return e.Crypto.Encrypt(keyBytes, systemKeyBytes) @@ -398,7 +322,7 @@ func (e *envelopeEncryption) loadLatestOrCreateIntermediateKey(ctx context.Conte return e.createIntermediateKey(ctx) } - defer maybeCloseKey(e.Policy.CacheSystemKeys, sk) + defer sk.Close() // Only use the loaded IK if it and its parent key is valid. if ik := e.getValidIntermediateKey(sk, ikEkr); ik != nil { @@ -411,18 +335,16 @@ func (e *envelopeEncryption) loadLatestOrCreateIntermediateKey(ctx context.Conte // getOrLoadSystemKey returns a system key from cache if it's already been loaded. Otherwise it retrieves the key // from the metastore. -func (e *envelopeEncryption) getOrLoadSystemKey(ctx context.Context, meta KeyMeta) (*internal.CryptoKey, error) { - loader := keyLoaderFunc(func() (*internal.CryptoKey, error) { - return e.loadSystemKey(ctx, meta) +func (e *envelopeEncryption) getOrLoadSystemKey(ctx context.Context, meta KeyMeta) (*cachedCryptoKey, error) { + return e.skCache.GetOrLoad(meta, func(m KeyMeta) (*internal.CryptoKey, error) { + return e.loadSystemKey(ctx, m) }) - - return e.systemKeys.GetOrLoad(meta, loader) } // getValidIntermediateKey returns a new CryptoKey constructed from ekr. It returns nil if sk is invalid or if key initialization fails. -func (e *envelopeEncryption) getValidIntermediateKey(sk *internal.CryptoKey, ekr *EnvelopeKeyRecord) *internal.CryptoKey { +func (e *envelopeEncryption) getValidIntermediateKey(sk accessorRevokable, ekr *EnvelopeKeyRecord) *internal.CryptoKey { // IK is only valid if its parent is valid - if e.isKeyInvalid(sk) { + if internal.IsKeyInvalid(sk, e.Policy.ExpireKeyAfter) { return nil } @@ -436,7 +358,7 @@ func (e *envelopeEncryption) getValidIntermediateKey(sk *internal.CryptoKey, ekr } // decryptRow decrypts drr using ik as the parent key and returns the decrypted data. -func decryptRow(ik *internal.CryptoKey, drr DataRowRecord, crypto AEAD) ([]byte, error) { +func decryptRow(ik internal.BytesFuncAccessor, drr DataRowRecord, crypto AEAD) ([]byte, error) { return internal.WithKeyFunc(ik, func(bytes []byte) ([]byte, error) { // TODO Consider having separate DecryptKey that is functional and handles wiping bytes rawDrk, err := crypto.Decrypt(drr.Key.EncryptedKey, bytes) @@ -450,31 +372,30 @@ func decryptRow(ik *internal.CryptoKey, drr DataRowRecord, crypto AEAD) ([]byte, }) } -// maybeCloseKey closes key if isCached is false. 
-func maybeCloseKey(isCached bool, key *internal.CryptoKey) { - if !isCached { - key.Close() - } -} - // EncryptPayload encrypts a provided slice of bytes and returns the data with the data row key and required // parent information to decrypt the data in the future. It also takes a context used for cancellation. func (e *envelopeEncryption) EncryptPayload(ctx context.Context, data []byte) (*DataRowRecord, error) { defer encryptTimer.UpdateSince(time.Now()) - reloader := e.newIntermediateKeyReloader(ctx) - defer reloader.Close() + loader := func(meta KeyMeta) (*internal.CryptoKey, error) { + log.Debugf("[EncryptPayload] loadLatestOrCreateIntermediateKey: %s", meta.ID) + return e.loadLatestOrCreateIntermediateKey(ctx, meta.ID) + } // Try to get latest from cache. - ik, err := reloader.GetOrLoadLatest(e.intermediateKeys) + ik, err := e.ikCache.GetOrLoadLatest(e.partition.IntermediateKeyID(), loader) if err != nil { + log.Debugf("[EncryptPayload] GetOrLoadLatest failed: %s", err.Error()) return nil, err } + defer ik.Close() + // Note the id doesn't mean anything for DRK. Don't need to truncate created since that is intended // to prevent excessive IK/SK creation (we always create new DRK on each write, so not a concern there) drk, err := internal.GenerateKey(e.SecretFactory, time.Now().Unix(), AES256KeySize) if err != nil { + log.Debugf("[EncryptPayload] GenerateKey failed: %s", err.Error()) return nil, err } @@ -484,6 +405,7 @@ func (e *envelopeEncryption) EncryptPayload(ctx context.Context, data []byte) (* return e.Crypto.Encrypt(data, bytes) }) if err != nil { + log.Debugf("[EncryptPayload] WithKeyFunc failed to encrypt data using DRK: %s", err.Error()) return nil, err } @@ -493,6 +415,7 @@ func (e *envelopeEncryption) EncryptPayload(ctx context.Context, data []byte) (* }) }) if err != nil { + log.Debugf("[EncryptPayload] WithKeyFunc failed to encrypt DRK using IK: %s", err.Error()) return nil, err } @@ -526,16 +449,20 @@ func (e *envelopeEncryption) DecryptDataRowRecord(ctx context.Context, drr DataR return nil, errors.New("unable to decrypt record") } - loader := keyLoaderFunc(func() (*internal.CryptoKey, error) { - return e.loadIntermediateKey(ctx, *drr.Key.ParentKeyMeta) - }) + loader := func(meta KeyMeta) (*internal.CryptoKey, error) { + log.Debugf("[DecryptDataRowRecord] loadIntermediateKey: %s", meta.ID) - ik, err := e.intermediateKeys.GetOrLoad(*drr.Key.ParentKeyMeta, loader) + return e.loadIntermediateKey(ctx, meta) + } + + ik, err := e.ikCache.GetOrLoad(*drr.Key.ParentKeyMeta, loader) if err != nil { + log.Debugf("[DecryptDataRowRecord] GetOrLoad IK failed: %s", err.Error()) + return nil, err } - defer maybeCloseKey(e.Policy.CacheIntermediateKeys, ik) + defer ik.Close() return decryptRow(ik, drr, e.Crypto) } @@ -556,7 +483,7 @@ func (e *envelopeEncryption) loadIntermediateKey(ctx context.Context, meta KeyMe return nil, err } - defer maybeCloseKey(e.Policy.CacheSystemKeys, sk) + defer sk.Close() return e.intermediateKeyFromEKR(sk, ekr) } @@ -564,5 +491,9 @@ func (e *envelopeEncryption) loadIntermediateKey(ctx context.Context, meta KeyMe // Close frees all memory locked by the keys in the session. It should be called // as soon as its no longer in use. 
func (e *envelopeEncryption) Close() error { - return e.intermediateKeys.Close() + if e.Policy != nil && e.Policy.SharedIntermediateKeyCache { + return nil + } + + return e.ikCache.Close() } diff --git a/go/appencryption/envelope_test.go b/go/appencryption/envelope_test.go index cdec8c8ed..1d69f65c7 100644 --- a/go/appencryption/envelope_test.go +++ b/go/appencryption/envelope_test.go @@ -29,8 +29,8 @@ var ( type EnvelopeSuite struct { suite.Suite crypto AEAD - ikCache cache - skCache cache + ikCache keyCacher + skCache keyCacher partition partition e envelopeEncryption metastore Metastore @@ -53,14 +53,14 @@ func (suite *EnvelopeSuite) SetupTest() { suite.secretFactory = new(MockSecretFactory) suite.e = envelopeEncryption{ - partition: suite.partition, - Metastore: suite.metastore, - KMS: suite.kms, - Policy: NewCryptoPolicy(), - Crypto: suite.crypto, - SecretFactory: suite.secretFactory, - systemKeys: suite.skCache, - intermediateKeys: suite.ikCache, + partition: suite.partition, + Metastore: suite.metastore, + KMS: suite.kms, + Policy: NewCryptoPolicy(), + Crypto: suite.crypto, + SecretFactory: suite.secretFactory, + skCache: suite.skCache, + ikCache: suite.ikCache, } var err error @@ -874,66 +874,6 @@ func (suite *EnvelopeSuite) TestEnvelopeEncryption_DecryptDataRowRecord_ReturnsE mock.AssertExpectationsForObjects(suite.T(), suite.ikCache) } -func (suite *EnvelopeSuite) Test_KeyReloader_Load() { - called := false - - reloader := &reloader{ - loader: keyLoaderFunc(func() (*internal.CryptoKey, error) { - called = true - return nil, nil - }), - } - - k, err := reloader.Load() - assert.Nil(suite.T(), k) - assert.NoError(suite.T(), err) - assert.True(suite.T(), called) -} - -func (suite *EnvelopeSuite) Test_KeyReloader_IsInvalid() { - k, _ := getKeyAndKeyBytes(suite.T()) - called := false - - reloader := &reloader{ - isInvalidFunc: func(key *internal.CryptoKey) bool { - called = true - - assert.Equal(suite.T(), k, key) - return false - }, - } - - reloader.IsInvalid(k) - - assert.True(suite.T(), called) -} - -func (suite *EnvelopeSuite) Test_KeyReloader_Close() { - reloader := &reloader{ - loader: keyLoaderFunc(func() (*internal.CryptoKey, error) { - k, _ := getKeyAndKeyBytes(suite.T()) - return k, nil - }), - } - loadTestKey := func() *internal.CryptoKey { - k, _ := reloader.Load() - return k - } - - var keys []*internal.CryptoKey - keys = append(keys, loadTestKey(), loadTestKey()) - - for _, k := range keys { - assert.False(suite.T(), k.IsClosed()) - } - - reloader.Close() - - for _, k := range keys { - assert.True(suite.T(), k.IsClosed()) - } -} - func TestKeyMeta_String(t *testing.T) { meta := KeyMeta{ Created: someTimestamp, @@ -968,17 +908,19 @@ func TestEnvelopeEncryption_Close(t *testing.T) { sec, err := secretFactory.New(data) if assert.NoError(t, err) { m := new(MockSecretFactory) + m.On("New", data).Return(sec, nil) - cache := newKeyCache(NewCryptoPolicy()) + cache := newKeyCache(CacheTypeIntermediateKeys, NewCryptoPolicy()) + key, _ := internal.NewCryptoKey(m, 123456, false, data) - cache.keys["testing"] = cacheEntry{ - key: key, - } + cache.keys.Set("testing", cacheEntry{ + key: &cachedCryptoKey{CryptoKey: key, refs: 1}, + }) e := &envelopeEncryption{ - intermediateKeys: cache, + ikCache: cache, } assert.False(t, sec.IsClosed()) diff --git a/go/appencryption/go.mod b/go/appencryption/go.mod index e257de196..12566a83c 100644 --- a/go/appencryption/go.mod +++ b/go/appencryption/go.mod @@ -4,7 +4,6 @@ go 1.19 require ( github.com/aws/aws-sdk-go v1.46.7 - github.com/goburrow/cache 
v0.1.4 github.com/godaddy/asherah/go/securememory v0.1.5 github.com/google/uuid v1.4.0 github.com/pkg/errors v0.9.1 diff --git a/go/appencryption/go.sum b/go/appencryption/go.sum index ff7b57929..918b9dca6 100644 --- a/go/appencryption/go.sum +++ b/go/appencryption/go.sum @@ -2,10 +2,6 @@ github.com/awnumar/memcall v0.1.2 h1:7gOfDTL+BJ6nnbtAp9+HQzUFjtP1hEseRQq8eP055QY github.com/awnumar/memcall v0.1.2/go.mod h1:S911igBPR9CThzd/hYQQmTc9SWNu3ZHIlCGaWsWsoJo= github.com/awnumar/memguard v0.22.3 h1:b4sgUXtbUjhrGELPbuC62wU+BsPQy+8lkWed9Z+pj0Y= github.com/awnumar/memguard v0.22.3/go.mod h1:mmGunnffnLHlxE5rRgQc3j+uwPZ27eYb61ccr8Clz2Y= -github.com/aws/aws-sdk-go v1.44.190 h1:QC+Pf/Ooj7Waf2obOPZbIQOqr00hy4h54j3ZK9mvHcc= -github.com/aws/aws-sdk-go v1.44.190/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= -github.com/aws/aws-sdk-go v1.44.250 h1:IuGUO2Hafv/b0yYKI5UPLQShYDx50BCIQhab/H1sX2M= -github.com/aws/aws-sdk-go v1.44.250/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= github.com/aws/aws-sdk-go v1.44.265 h1:rlBuD8OYjM5Vfcf7jDa264oVHqlPqY7y7o+JmrjNFUc= github.com/aws/aws-sdk-go v1.44.265/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= github.com/aws/aws-sdk-go v1.46.7 h1:IjvAWeiJZlbETOemOwvheN5L17CvKvKW0T1xOC6d3Sc= @@ -14,8 +10,6 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/goburrow/cache v0.1.4 h1:As4KzO3hgmzPlnaMniZU9+VmoNYseUhuELbxy9mRBfw= -github.com/goburrow/cache v0.1.4/go.mod h1:cDFesZDnIlrHoNlMYqqMpCRawuXulgx+y7mXU8HZ+/c= github.com/godaddy/asherah/go/securememory v0.1.4 h1:1UlEPE5Q2wK1fbGwjIBtlGO02teLFBFk7dNIvdWOzNQ= github.com/godaddy/asherah/go/securememory v0.1.4/go.mod h1:grCFdMhT5CY8h+E+Qb1Abhd6uBDIxrwVh6Tulsc9gj4= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= diff --git a/go/appencryption/go.work.sum b/go/appencryption/go.work.sum index 7d4baf243..560dcbb5c 100644 --- a/go/appencryption/go.work.sum +++ b/go/appencryption/go.work.sum @@ -384,47 +384,7 @@ golang.org/x/image v0.0.0-20190802002840-cff245a6509b h1:+qEpEAPhDZ1o0x3tHzZTQDA golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k= golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 h1:4+4C/Iv2U4fMZBiMCc98MG1In4gJY5YRhtpDNeDeHWs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783 h1:nt+Q6cXKz4MosCSpnbMtqiQ8Oz0pxTef2B4Vca2lvfk= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.4.0 h1:O7UWfv5+A2qiuulQk30kVinPoMtoIPeVaKLEgLpVkvg= -golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= -golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= -google.golang.org/api v0.22.0 h1:J1Pl9P2lnmYFSJvgs70DKELqHNh8CNWXPbud4njEE2s= -google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= -google.golang.org/cloud 
v0.0.0-20151119220103-975617b05ea8 h1:Cpp2P6TPjujNoC5M2KHY6g7wfyLYfIWRZaSdIKfDasA= -gopkg.in/airbrake/gobrake.v2 v2.0.9 h1:7z2uVWwn7oVeeugY1DtlPAy5H+KYgB1KeKTnqjNatLo= -gopkg.in/alecthomas/kingpin.v2 v2.2.6 h1:jMFz6MfLP0/4fUyZle81rXUoxOBFi19VUFKVDOQfozc= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/cheggaaa/pb.v1 v1.0.25 h1:Ev7yu1/f6+d+b3pi5vPdRPc6nNtP1umSfcWiEfRqv6I= -gopkg.in/errgo.v2 v2.1.0 h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8= -gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= -gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.2 h1:OAj3g0cR6Dx/R07QgQe8wkA9RNjB2u4i700xBkIT4e0= -gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= -gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= -gopkg.in/resty.v1 v1.12.0 h1:CuXP0Pjfw9rOuY6EP+UvtNvt5DSqHpIxILZKT/quCZI= -gopkg.in/square/go-jose.v2 v2.5.1 h1:7odma5RETjNHWJnR32wx8t+Io4djHE1PqxCFx3iiZ2w= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= -gotest.tools/gotestsum v1.8.2 h1:szU3TaSz8wMx/uG+w/A2+4JUPwH903YYaMI9yOOYAyI= -honnef.co/go/tools v0.0.1-2020.1.3 h1:sXmLre5bzIR6ypkjXCDI3jHPssRhc8KD/Ome589sc3U= -k8s.io/api v0.22.5 h1:xk7C+rMjF/EGELiD560jdmwzrB788mfcHiNbMQLIVI8= -k8s.io/apimachinery v0.22.5 h1:cIPwldOYm1Slq9VLBRPtEYpyhjIm1C6aAMAoENuvN9s= -k8s.io/apiserver v0.22.5 h1:71krQxCUz218ecb+nPhfDsNB6QgP1/4EMvi1a2uYBlg= -k8s.io/client-go v0.22.5 h1:I8Zn/UqIdi2r02aZmhaJ1hqMxcpfJ3t5VqvHtctHYFo= -k8s.io/code-generator v0.19.7 h1:kM/68Y26Z/u//TFc1ggVVcg62te8A2yQh57jBfD0FWQ= -k8s.io/component-base v0.22.5 h1:U0eHqZm7mAFE42hFwYhY6ze/MmVaW00JpMrzVsQmzYE= -k8s.io/cri-api v0.25.0 h1:INwdXsCDSA/0hGNdPxdE2dQD6ft/5K1EaKXZixvSQxg= -k8s.io/gengo v0.0.0-20201113003025-83324d819ded h1:JApXBKYyB7l9xx+DK7/+mFjC7A9Bt5A93FPvFD0HIFE= -k8s.io/klog/v2 v2.30.0 h1:bUO6drIvCIsvZ/XFgfxoGFQU/a4Qkh0iAlvUR7vlHJw= -k8s.io/kube-openapi v0.0.0-20201113171705-d219536bb9fd h1:sOHNzJIkytDF6qadMNKhhDRpc6ODik8lVC6nOur7B2c= -k8s.io/kubernetes v1.13.0 h1:qTfB+u5M92k2fCCCVP2iuhgwwSOv1EkAkvQY1tQODD8= -k8s.io/utils v0.0.0-20210930125809-cb0fa318a74b h1:wxEMGetGMur3J1xuGLQY7GEQYg9bZxKn3tKo5k/eYcs= -lukechampine.com/frand v1.4.2 h1:RzFIpOvkMXuPMBb9maa4ND4wjBn71E1Jpf8BzJHMaVw= -rsc.io/binaryregexp v0.2.0 h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE= -rsc.io/quote/v3 v3.1.0 h1:9JKUTTIUgS6kzR9mK1YuGKv6Nl+DijDNIc0ghT58FaY= -rsc.io/sampler v1.3.0 h1:7uVkIFmeBqHfdjD+gZwtXXI+RODJ2Wc4O7MPEh/QiW4= -sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.0.15 h1:4uqm9Mv+w2MmBYD+F4qf/v6tDFUdPOk29C095RbU5mY= -sigs.k8s.io/structured-merge-diff/v4 v4.1.2 h1:Hr/htKFmJEbtMgS/UD0N+gtgctAqz81t3nu+sPzynno= -sigs.k8s.io/yaml v1.2.0 h1:kr/MCeFWJWTwyaHoR9c8EjH9OumOmoF9YGiZd7lFm/Q= diff --git a/go/appencryption/integrationtest/traces/README.md b/go/appencryption/integrationtest/traces/README.md new file mode 100644 index 000000000..f86b8903d --- /dev/null +++ b/go/appencryption/integrationtest/traces/README.md @@ -0,0 +1,23 @@ +# SessionFactory Performance Report + +This package benchmarks the performance of the `SessionFactory` class and +its dependencies. It compares Metastore and KMS access patterns with +different cache configurations. + +The source code for this package is derived from the package of the same +name in the [Mango Cache](https://github.com/goburrow/cache) project. See +[NOTICE](../../pkg/cache/internal/NOTICE) for copyright and +licensing information. 
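+
+As a rough usage sketch (assuming the trace data has already been fetched with
+the download scripts under `data/`, and that the tests are run directly from
+this directory with the standard `go test` workflow), a single trace-driven
+test can be exercised with something like:
+
+    # fetch the cache2k traces referenced by the OLTP tests
+    (cd data && ./dl-cache2k.sh)
+
+    # run the request-rate test across the configured cache policies
+    go test -run TestRequestOLTP -v .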
+ +## Traces + +Name | Source +------------ | ------ +Glimpse | Authors of the LIRS algorithm - retrieved from [Cache2k](https://github.com/cache2k/cache2k-benchmark) +Multi2 | Authors of the LIRS algorithm - retrieved from [Cache2k](https://github.com/cache2k/cache2k-benchmark) +OLTP | Authors of the ARC algorithm - retrieved from [Cache2k](https://github.com/cache2k/cache2k-benchmark) +ORMBusy | GmbH - retrieved from [Cache2k](https://github.com/cache2k/cache2k-benchmark) +Sprite | Authors of the LIRS algorithm - retrieved from [Cache2k](https://github.com/cache2k/cache2k-benchmark) +Wikipedia | [WikiBench](http://www.wikibench.eu/) +YouTube | [University of Massachusetts](http://traces.cs.umass.edu/index.php/Network/Network) +WebSearch | [University of Massachusetts](http://traces.cs.umass.edu/index.php/Storage/Storage) diff --git a/go/appencryption/integrationtest/traces/cache2k.go b/go/appencryption/integrationtest/traces/cache2k.go new file mode 100644 index 000000000..6ab86a36e --- /dev/null +++ b/go/appencryption/integrationtest/traces/cache2k.go @@ -0,0 +1,40 @@ +package traces + +import ( + "bufio" + "context" + "encoding/binary" + "io" +) + +type cache2kProvider struct { + r *bufio.Reader +} + +// NewCache2kProvider returns a Provider which items are from traces +// in Cache2k repository (https://github.com/cache2k/cache2k-benchmark). +func NewCache2kProvider(r io.Reader) Provider { + return &cache2kProvider{ + r: bufio.NewReader(r), + } +} + +func (p *cache2kProvider) Provide(ctx context.Context, keys chan<- interface{}) { + defer close(keys) + + v := make([]byte, 4) + + for { + _, err := p.r.Read(v) + if err != nil { + return + } + + k := binary.LittleEndian.Uint32(v) + select { + case <-ctx.Done(): + return + case keys <- k: + } + } +} diff --git a/go/appencryption/integrationtest/traces/cache2k_test.go b/go/appencryption/integrationtest/traces/cache2k_test.go new file mode 100644 index 000000000..eeafd8b31 --- /dev/null +++ b/go/appencryption/integrationtest/traces/cache2k_test.go @@ -0,0 +1,201 @@ +package traces + +import "testing" + +func TestRequestORMBusy(t *testing.T) { + for _, p := range policies { + p := p + t.Run(p, func(t *testing.T) { + t.Parallel() + opt := options{ + policy: p, + cacheSize: 1000, + reportInterval: 40000, + maxItems: 4000000, + } + testRequest(t, NewCache2kProvider, opt, + "trace-mt-db-*-busy.trc.bin.bz2", "request_ormbusy-"+p+".txt") + }) + } +} + +func TestSizeORMBusy(t *testing.T) { + for _, p := range policies { + p := p + t.Run(p, func(t *testing.T) { + t.Parallel() + opt := options{ + policy: p, + cacheSize: 250, + maxItems: 1000000, + } + testSize(t, NewCache2kProvider, opt, + "trace-mt-db-*-busy.trc.bin.bz2", "size_ormbusy-"+p+".txt") + }) + } +} + +func TestRequestORMNight(t *testing.T) { + for _, p := range policies { + p := p + t.Run(p, func(t *testing.T) { + t.Parallel() + opt := options{ + policy: p, + cacheSize: 1000, + reportInterval: 40000, + maxItems: 4000000, + } + testRequest(t, NewCache2kProvider, opt, + "trace-mt-db-*-night.trc.bin.bz2", "request_ormnight-"+p+".txt") + }) + } +} + +func TestSizeORMNight(t *testing.T) { + for _, p := range policies { + p := p + t.Run(p, func(t *testing.T) { + t.Parallel() + opt := options{ + policy: p, + cacheSize: 250, + maxItems: 1000000, + } + testSize(t, NewCache2kProvider, opt, + "trace-mt-db-*-night.trc.bin.bz2", "size_ormnight-"+p+".txt") + }) + } +} + +func TestRequestGlimpse(t *testing.T) { + for _, p := range policies { + p := p + t.Run(p, func(t *testing.T) { + t.Parallel() + opt 
:= options{ + policy: p, + cacheSize: 512, + reportInterval: 100, + maxItems: 6000, + } + testRequest(t, NewCache2kProvider, opt, + "trace-glimpse.trc.bin.gz", "request_glimpse-"+p+".txt") + }) + } +} + +func TestSizeGlimpse(t *testing.T) { + for _, p := range policies { + p := p + t.Run(p, func(t *testing.T) { + t.Parallel() + opt := options{ + policy: p, + cacheSize: 125, + maxItems: 6000, + } + testSize(t, NewCache2kProvider, opt, + "trace-glimpse.trc.bin.gz", "size_glimpse-"+p+".txt") + }) + } +} + +func TestRequestOLTP(t *testing.T) { + for _, p := range policies { + p := p + t.Run(p, func(t *testing.T) { + t.Parallel() + opt := options{ + policy: p, + cacheSize: 1000, + reportInterval: 1000, + maxItems: 900000, + } + testRequest(t, NewCache2kProvider, opt, + "trace-oltp.trc.bin.gz", "request_oltp-"+p+".txt") + }) + } +} + +func TestSizeOLTP(t *testing.T) { + for _, p := range policies { + p := p + t.Run(p, func(t *testing.T) { + t.Parallel() + opt := options{ + policy: p, + cacheSize: 250, + maxItems: 500000, + } + testSize(t, NewCache2kProvider, opt, + "trace-oltp.trc.bin.gz", "size_oltp-"+p+".txt") + }) + } +} + +func TestRequestSprite(t *testing.T) { + for _, p := range policies { + p := p + t.Run(p, func(t *testing.T) { + t.Parallel() + opt := options{ + policy: p, + cacheSize: 1000, + reportInterval: 1000, + maxItems: 120000, + } + testRequest(t, NewCache2kProvider, opt, + "trace-sprite.trc.bin.gz", "request_sprite-"+p+".txt") + }) + } +} + +func TestSizeSprite(t *testing.T) { + for _, p := range policies { + p := p + t.Run(p, func(t *testing.T) { + t.Parallel() + opt := options{ + policy: p, + cacheSize: 25, + maxItems: 120000, + } + testSize(t, NewCache2kProvider, opt, + "trace-sprite.trc.bin.gz", "size_sprite-"+p+".txt") + }) + } +} + +func TestRequestMulti2(t *testing.T) { + for _, p := range policies { + p := p + t.Run(p, func(t *testing.T) { + t.Parallel() + opt := options{ + policy: p, + cacheSize: 1000, + reportInterval: 200, + maxItems: 25000, + } + testRequest(t, NewCache2kProvider, opt, + "trace-multi2.trc.bin.gz", "request_multi2-"+p+".txt") + }) + } +} + +func TestSizeMulti2(t *testing.T) { + for _, p := range policies { + p := p + t.Run(p, func(t *testing.T) { + t.Parallel() + opt := options{ + policy: p, + cacheSize: 250, + maxItems: 25000, + } + testSize(t, NewCache2kProvider, opt, + "trace-multi2.trc.bin.gz", "size_multi2-"+p+".txt") + }) + } +} diff --git a/go/appencryption/integrationtest/traces/combine-png.sh b/go/appencryption/integrationtest/traces/combine-png.sh new file mode 100755 index 000000000..08c49d8d7 --- /dev/null +++ b/go/appencryption/integrationtest/traces/combine-png.sh @@ -0,0 +1,10 @@ +#!/bin/sh +set -e +# NAMES="financial zipf" +NAMES="financial oltp ormbusy ormnight multi2 youtube websearch zipf" +FORMAT="png" +FILES="" +for N in $NAMES; do + FILES="$FILES out/$N-requests.$FORMAT out/$N-cachesize.$FORMAT" +done +gm montage -mode concatenate -tile 4x $FILES "out/report-session-cache.$FORMAT" diff --git a/go/appencryption/integrationtest/traces/combine.sh b/go/appencryption/integrationtest/traces/combine.sh new file mode 100755 index 000000000..3c16a36b4 --- /dev/null +++ b/go/appencryption/integrationtest/traces/combine.sh @@ -0,0 +1,11 @@ +#!/bin/sh +set -e + +if [ -z "$FORMAT" ]; then + FORMAT="png" +else + FORMAT="${FORMAT%% *}" +fi + +FILES=$(ls out/*-requests.$FORMAT out/*-cachesize.$FORMAT | sort) +gm montage -mode concatenate -tile 4x $FILES "out/report.$FORMAT" diff --git a/go/appencryption/integrationtest/traces/data/dl-address.sh 
b/go/appencryption/integrationtest/traces/data/dl-address.sh new file mode 100755 index 000000000..b7e05af6b --- /dev/null +++ b/go/appencryption/integrationtest/traces/data/dl-address.sh @@ -0,0 +1,4 @@ +#!/bin/sh +FILE="proj1-traces.tar.gz" +curl -O "http://cseweb.ucsd.edu/classes/fa07/cse240a/$FILE" +tar xvzf "$FILE" diff --git a/go/appencryption/integrationtest/traces/data/dl-cache2k.sh b/go/appencryption/integrationtest/traces/data/dl-cache2k.sh new file mode 100755 index 000000000..c32a6a0cd --- /dev/null +++ b/go/appencryption/integrationtest/traces/data/dl-cache2k.sh @@ -0,0 +1,9 @@ +#!/bin/sh +set -e + +FILES="trace-cpp.trc.bin.gz trace-glimpse.trc.bin.gz trace-mt-db-20160419-busy.trc.bin.bz2 trace-multi2.trc.bin.gz trace-oltp.trc.bin.gz trace-sprite.trc.bin.gz" +for F in $FILES; do + if [ ! -f "$F" ]; then + curl -L -O "https://github.com/cache2k/cache2k-benchmark/raw/master/traces/src/main/resources/org/cache2k/benchmark/traces/$F" + fi +done diff --git a/go/appencryption/integrationtest/traces/data/dl-storage.sh b/go/appencryption/integrationtest/traces/data/dl-storage.sh new file mode 100755 index 000000000..aba198ffc --- /dev/null +++ b/go/appencryption/integrationtest/traces/data/dl-storage.sh @@ -0,0 +1,8 @@ +#!/bin/sh +set -e +FILES="WebSearch1.spc.bz2 Financial2.spc.bz2" +for F in $FILES; do + if [ ! -f "$F" ]; then + curl -O "http://skuld.cs.umass.edu/traces/storage/$F" + fi +done diff --git a/go/appencryption/integrationtest/traces/data/dl-wikipedia.sh b/go/appencryption/integrationtest/traces/data/dl-wikipedia.sh new file mode 100755 index 000000000..5ac5f59ed --- /dev/null +++ b/go/appencryption/integrationtest/traces/data/dl-wikipedia.sh @@ -0,0 +1,8 @@ +#!/bin/sh +set -e +FILES="wiki.1191201596.gz" +for F in $FILES; do + if [ ! -f "$F" ]; then + curl -O "http://www.wikibench.eu/wiki/2007-10/$F" + fi +done diff --git a/go/appencryption/integrationtest/traces/data/dl-youtube.sh b/go/appencryption/integrationtest/traces/data/dl-youtube.sh new file mode 100755 index 000000000..2b5334973 --- /dev/null +++ b/go/appencryption/integrationtest/traces/data/dl-youtube.sh @@ -0,0 +1,16 @@ +#!/bin/sh +set -e +FILE="youtube_traces.tgz" +if [ ! 
-f "$FILE" ]; then + curl -O "http://skuld.cs.umass.edu/traces/network/$FILE" +fi +tar xzf "$FILE" + +rm youtube.parsed.*.24.dat +rm youtube.parsed.*.S1.dat + +for FILE in youtube.parsed.*.dat; do + # YYMMDD + NAME="$(echo "$FILE" | sed -e 's/\([0-9]\{2\}\)\([0-9]\{2\}\)\([0-9]\{2\}\)\(\.dat\)/\3\1\2\4/')" + mv "$FILE" "$NAME" +done diff --git a/go/appencryption/integrationtest/traces/files.go b/go/appencryption/integrationtest/traces/files.go new file mode 100644 index 000000000..a5e44cdd6 --- /dev/null +++ b/go/appencryption/integrationtest/traces/files.go @@ -0,0 +1,167 @@ +package traces + +import ( + "compress/bzip2" + "compress/gzip" + "fmt" + "io" + "os" + "path/filepath" + "strings" +) + +type readSeekCloser interface { + io.ReadCloser + io.Seeker +} + +type gzipFile struct { + r *gzip.Reader + f *os.File +} + +func newGzipFile(f *os.File) *gzipFile { + r, err := gzip.NewReader(f) + if err != nil { + panic(err) + } + + return &gzipFile{ + r: r, + f: f, + } +} + +func (f *gzipFile) Read(p []byte) (int, error) { + return f.r.Read(p) +} + +func (f *gzipFile) Seek(offset int64, whence int) (int64, error) { + n, err := f.f.Seek(offset, whence) + if err != nil { + return n, err + } + + f.r.Reset(f.f) + + return n, nil +} + +func (f *gzipFile) Close() error { + err1 := f.r.Close() + + if err2 := f.f.Close(); err2 != nil { + return err2 + } + + return err1 +} + +type bzip2File struct { + r io.Reader + f *os.File +} + +func newBzip2File(f *os.File) *bzip2File { + return &bzip2File{ + r: bzip2.NewReader(f), + f: f, + } +} + +func (f *bzip2File) Read(p []byte) (int, error) { + return f.r.Read(p) +} + +func (f *bzip2File) Seek(offset int64, whence int) (int64, error) { + n, err := f.f.Seek(offset, whence) + if err != nil { + return n, err + } + + f.r = bzip2.NewReader(f.f) + + return n, nil +} + +func (f *bzip2File) Close() error { + return f.f.Close() +} + +type filesReader struct { + io.Reader + files []readSeekCloser +} + +func openFilesGlob(pattern string) (*filesReader, error) { + files, err := filepath.Glob(pattern) + if err != nil { + return nil, err + } + + if len(files) == 0 { + return nil, fmt.Errorf("%s not found", pattern) + } + + return openFiles(files...) +} + +func openFiles(files ...string) (*filesReader, error) { + r := &filesReader{} + r.files = make([]readSeekCloser, 0, len(files)) + readers := make([]io.Reader, 0, len(files)) + + for _, name := range files { + f, err := os.Open(name) + if err != nil { + r.Close() + return nil, err + } + + var rs readSeekCloser + if strings.HasSuffix(name, ".gz") { + rs = newGzipFile(f) + } else if strings.HasSuffix(name, ".bz2") { + rs = newBzip2File(f) + } else { + rs = f + } + + r.files = append(r.files, rs) + readers = append(readers, rs) + } + + r.Reader = io.MultiReader(readers...) + + return r, nil +} + +func (r *filesReader) Close() error { + var err error + + for _, f := range r.files { + e := f.Close() + if err != nil && e != nil { + err = e + } + } + + return err +} + +func (r *filesReader) Reset() error { + readers := make([]io.Reader, 0, len(r.files)) + + for _, f := range r.files { + _, err := f.Seek(0, 0) + if err != nil { + return err + } + + readers = append(readers, f) + } + + r.Reader = io.MultiReader(readers...) 
+ + return nil +} diff --git a/go/appencryption/integrationtest/traces/out/report-session-cache.png b/go/appencryption/integrationtest/traces/out/report-session-cache.png new file mode 100644 index 000000000..41d09620b Binary files /dev/null and b/go/appencryption/integrationtest/traces/out/report-session-cache.png differ diff --git a/go/appencryption/integrationtest/traces/report.go b/go/appencryption/integrationtest/traces/report.go new file mode 100644 index 000000000..ad729000c --- /dev/null +++ b/go/appencryption/integrationtest/traces/report.go @@ -0,0 +1,314 @@ +package traces + +import ( + "context" + "fmt" + "io" + "math/rand" + "time" + + "github.com/rcrowley/go-metrics" + + "github.com/godaddy/asherah/go/appencryption" + "github.com/godaddy/asherah/go/appencryption/internal" + "github.com/godaddy/asherah/go/appencryption/pkg/crypto/aead" + "github.com/godaddy/asherah/go/appencryption/pkg/kms" + "github.com/godaddy/asherah/go/appencryption/pkg/persistence" +) + +type Stats struct { + RequestCount uint64 + KMSOpCount uint64 + KMSEncryptCount uint64 + KMSDecryptCount uint64 + MetastoreOpCount uint64 + MetastoreLoadCount uint64 + MetastoreLoadLatestCount uint64 + MetastoreStoreCount uint64 + OpRate float64 +} + +type Reporter interface { + Report(Stats, options) +} + +type Provider interface { + Provide(ctx context.Context, keys chan<- interface{}) +} + +type reporter struct { + w io.Writer + headerPrinted bool +} + +func NewReporter(w io.Writer) Reporter { + return &reporter{w: w} +} + +func (r *reporter) Report(st Stats, opt options) { + if !r.headerPrinted { + fmt.Fprintf(r.w, "Requests,KMSOps,KMSEncrypts,KMSDecrypts,MetastoreOps,MetastoreLoads,MetastoreLoadLatests,MetastoreStores,OpRate,CacheSize\n") + r.headerPrinted = true + } + + fmt.Fprintf( + r.w, + "%d,%d,%d,%d,%d,%d,%d,%d,%.04f,%d\n", + st.RequestCount, + st.KMSOpCount, + st.KMSEncryptCount, + st.KMSDecryptCount, + st.MetastoreOpCount, + st.MetastoreLoadCount, + st.MetastoreLoadLatestCount, + st.MetastoreStoreCount, + st.OpRate, + opt.cacheSize) +} + +// trackedKMS is a KeyManagementService that tracks the number of encrypt and +// decrypt operations. +type trackedKMS struct { + appencryption.KeyManagementService + + decryptCounter metrics.Counter + encryptCounter metrics.Counter +} + +func newTrackedKMS(kms appencryption.KeyManagementService) *trackedKMS { + return &trackedKMS{ + KeyManagementService: kms, + decryptCounter: metrics.NewCounter(), + encryptCounter: metrics.NewCounter(), + } +} + +func (t *trackedKMS) DecryptKey(ctx context.Context, key []byte) ([]byte, error) { + t.decryptCounter.Inc(1) + return t.KeyManagementService.DecryptKey(ctx, key) +} + +func (t *trackedKMS) EncryptKey(ctx context.Context, key []byte) ([]byte, error) { + t.encryptCounter.Inc(1) + return t.KeyManagementService.EncryptKey(ctx, key) +} + +// delayedMetastore is a Metastore that delays all operations by a configurable +// amount of time. 
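+// It wraps an in-memory metastore, sleeps for the configured delay plus a random +// jitter on every Load, LoadLatest, and Store, and counts each operation so the +// benchmark can report metastore traffic.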
+type delayedMetastore struct { + m *persistence.MemoryMetastore + delay time.Duration + jitter time.Duration + + loadCounter metrics.Counter + loadLatestCounter metrics.Counter + storeCounter metrics.Counter +} + +func newDelayedMetastore(delay time.Duration, jitter time.Duration) *delayedMetastore { + return &delayedMetastore{ + m: persistence.NewMemoryMetastore(), + delay: delay, + jitter: jitter, + + loadCounter: metrics.NewCounter(), + loadLatestCounter: metrics.NewCounter(), + storeCounter: metrics.NewCounter(), + } +} + +func (d *delayedMetastore) delayWithJitter() { + ch := make(chan int) + go func() { + randJitter := int64(0) + if d.jitter > 0 { + randJitter = rand.Int63n(int64(d.jitter)) + } + + if d.delay > 0 { + time.Sleep(d.delay + time.Duration(randJitter)) + } + + ch <- 1 + }() + + <-ch +} + +func (d *delayedMetastore) Load(ctx context.Context, keyID string, created int64) (*appencryption.EnvelopeKeyRecord, error) { + d.loadCounter.Inc(1) + + d.delayWithJitter() + + return d.m.Load(ctx, keyID, created) +} + +func (d *delayedMetastore) LoadLatest(ctx context.Context, keyID string) (*appencryption.EnvelopeKeyRecord, error) { + d.loadLatestCounter.Inc(1) + + d.delayWithJitter() + + return d.m.LoadLatest(ctx, keyID) +} + +func (d *delayedMetastore) Store(ctx context.Context, keyID string, created int64, envelope *appencryption.EnvelopeKeyRecord) (bool, error) { + d.storeCounter.Inc(1) + + d.delayWithJitter() + + return d.m.Store(ctx, keyID, created, envelope) +} + +type options struct { + policy string + cacheSize int + reportInterval int + maxItems int +} + +var policies = []string{ + "session-legacy", + "session-slru", + "shared-slru", + + "shared-lru", + "shared-tinylfu", + "shared-lfu", +} + +const ( + product = "enclibrary" + service = "asherah" + staticKey = "thisIsAStaticMasterKeyForTesting" + payloadSizeBytes = 100 +) + +var c = aead.NewAES256GCM() + +//nolint:gocyclo +func benchmarkSessionFactory(p Provider, r Reporter, opt options) { + static, err := kms.NewStatic(staticKey, c) + if err != nil { + panic(err) + } + + km := newTrackedKMS(static) + config := getConfig(opt) + ms := newDelayedMetastore(5, 5) + + factory := appencryption.NewSessionFactory( + config, + ms, + km, + c, + ) + defer factory.Close() + + randomBytes := internal.GetRandBytes(payloadSizeBytes) + + keys := make(chan interface{}, 100) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go p.Provide(ctx, keys) + + stats := Stats{} + + for i := 0; ; { + if opt.maxItems > 0 && i >= opt.maxItems { + break + } + + k, ok := <-keys + if !ok { + break + } + + sess, err := factory.GetSession(fmt.Sprintf("partition-%v", k)) + if err != nil { + panic(err) + } + + _, err = sess.Encrypt(ctx, randomBytes) + sess.Close() + + if err != nil { + fmt.Printf("encrypt fail: i=%d, err=%v\n", i, err) + continue + } + + i++ + if opt.reportInterval > 0 && i%opt.reportInterval == 0 { + metastoreStats(&stats, ms, km, uint64(i)) + r.Report(stats, opt) + } + } + + if opt.reportInterval == 0 { + metastoreStats(&stats, ms, km, uint64(opt.maxItems)) + r.Report(stats, opt) + } +} + +func getConfig(opt options) *appencryption.Config { + policy := appencryption.NewCryptoPolicy( + // appencryption.WithRevokeCheckInterval(10 * time.Second), + ) + + policy.CreateDatePrecision = time.Minute + + switch opt.policy { + case "session-legacy": + policy.CacheSessions = true + policy.SessionCacheMaxSize = opt.cacheSize + case "session-slru": + policy.CacheSessions = true + policy.SessionCacheMaxSize = opt.cacheSize + 
policy.SessionCacheEvictionPolicy = "slru" + case "shared-slru": + policy.CacheSessions = false + policy.IntermediateKeyCacheMaxSize = opt.cacheSize + policy.IntermediateKeyCacheEvictionPolicy = "slru" + policy.SharedIntermediateKeyCache = true + case "shared-lru": + policy.CacheSessions = false + policy.IntermediateKeyCacheMaxSize = opt.cacheSize + policy.IntermediateKeyCacheEvictionPolicy = "lru" + policy.SharedIntermediateKeyCache = true + case "shared-tinylfu": + policy.CacheSessions = false + policy.IntermediateKeyCacheMaxSize = opt.cacheSize + policy.IntermediateKeyCacheEvictionPolicy = "tinylfu" + policy.SharedIntermediateKeyCache = true + case "shared-lfu": + policy.CacheSessions = false + policy.IntermediateKeyCacheMaxSize = opt.cacheSize + policy.IntermediateKeyCacheEvictionPolicy = "lfu" + policy.SharedIntermediateKeyCache = true + default: + panic(fmt.Sprintf("unknown policy: %s", opt.policy)) + } + + return &appencryption.Config{ + Policy: policy, + Product: product, + Service: service, + } +} + +// metastoreStats populates the cache stats for the appencryption metastore. +func metastoreStats(stats *Stats, ms *delayedMetastore, kms *trackedKMS, requests uint64) { + stats.RequestCount = requests + + stats.MetastoreLoadCount = uint64(ms.loadCounter.Count()) + stats.MetastoreLoadLatestCount = uint64(ms.loadLatestCounter.Count()) + stats.MetastoreStoreCount = uint64(ms.storeCounter.Count()) + stats.MetastoreOpCount = stats.MetastoreLoadCount + stats.MetastoreLoadLatestCount + stats.MetastoreStoreCount + + stats.KMSDecryptCount = uint64(kms.decryptCounter.Count()) + stats.KMSEncryptCount = uint64(kms.encryptCounter.Count()) + stats.KMSOpCount = stats.KMSDecryptCount + stats.KMSEncryptCount + + stats.OpRate = float64(stats.MetastoreOpCount+stats.KMSOpCount) / float64(stats.RequestCount) +} diff --git a/go/appencryption/integrationtest/traces/report.sh b/go/appencryption/integrationtest/traces/report.sh new file mode 100755 index 000000000..6d856ecc8 --- /dev/null +++ b/go/appencryption/integrationtest/traces/report.sh @@ -0,0 +1,29 @@ +#!/bin/bash +set -e + +report() { + NAME="$1" + TESTARGS="-p 1 -timeout=3h -run=$NAME" + go test -v $TESTARGS | tee "out/$NAME.txt" + + NAME=$(echo "$NAME" | tr '[:upper:]' '[:lower:]') + ./visualize-request.sh out/request_$NAME-*.txt + for OUTPUT in out.*; do + mv -v "$OUTPUT" "out/$NAME-requests.${OUTPUT#*.}" + done + ./visualize-size.sh out/size_$NAME-*.txt + for OUTPUT in out.*; do + mv -v "$OUTPUT" "out/$NAME-cachesize.${OUTPUT#*.}" + done +} + +# use first arg or default to a small subset +TRACES="$@" +if [ -z "$TRACES" ]; then + TRACES="Financial OLTP ORMBusy Zipf" +fi + +# TRACES="Multi2 ORMBusy ORMNight Glimpse OLTP Sprite Financial WebSearch Wikipedia YouTube Zipf" +for TRACE in $TRACES; do + report $TRACE +done diff --git a/go/appencryption/integrationtest/traces/report_test.go b/go/appencryption/integrationtest/traces/report_test.go new file mode 100644 index 000000000..a32fb006e --- /dev/null +++ b/go/appencryption/integrationtest/traces/report_test.go @@ -0,0 +1,55 @@ +package traces + +import ( + "io" + "os" + "path/filepath" + "testing" +) + +func testRequest(t *testing.T, newProvider func(io.Reader) Provider, opt options, traceFiles string, reportFile string) { + r, err := openFilesGlob(filepath.Join("data", traceFiles)) + if err != nil { + t.Skip(err) + } + defer r.Close() + provider := newProvider(r) + + w, err := os.Create(filepath.Join("out", reportFile)) + if err != nil { + t.Fatal(err) + } + defer w.Close() + reporter := 
NewReporter(w) + + benchmarkSessionFactory(provider, reporter, opt) +} + +func testSize(t *testing.T, newProvider func(io.Reader) Provider, opt options, traceFiles, reportFile string) { + r, err := openFilesGlob(filepath.Join("data", traceFiles)) + if err != nil { + t.Skip(err) + } + defer r.Close() + + w, err := os.Create(filepath.Join("out", reportFile)) + if err != nil { + t.Fatal(err) + } + defer w.Close() + + reporter := NewReporter(w) + + for i := 0; i < 5; i++ { + provider := newProvider(r) + + benchmarkSessionFactory(provider, reporter, opt) + + err = r.Reset() + if err != nil { + t.Fatal(err) + } + + opt.cacheSize += opt.cacheSize + } +} diff --git a/go/appencryption/integrationtest/traces/storage.go b/go/appencryption/integrationtest/traces/storage.go new file mode 100644 index 000000000..b1622ed48 --- /dev/null +++ b/go/appencryption/integrationtest/traces/storage.go @@ -0,0 +1,63 @@ +package traces + +import ( + "bufio" + "bytes" + "context" + "io" + "strconv" +) + +type storageProvider struct { + r *bufio.Reader +} + +// NewStorageProvider returns a Provider whose items are read from the +// Storage traces published by the University of Massachusetts +// (http://traces.cs.umass.edu/index.php/Storage/Storage). +func NewStorageProvider(r io.Reader) Provider { + return &storageProvider{ + r: bufio.NewReader(r), + } +} + +func (p *storageProvider) Provide(ctx context.Context, keys chan<- interface{}) { + defer close(keys) + + for { + b, err := p.r.ReadBytes('\n') + if err != nil { + return + } + + k := p.parse(b) + if k > 0 { + select { + case <-ctx.Done(): + return + case keys <- k: + } + } + } +} + +func (p *storageProvider) parse(b []byte) uint64 { + idx := bytes.IndexByte(b, ',') + if idx < 0 { + return 0 + } + + b = b[idx+1:] + + idx = bytes.IndexByte(b, ',') + if idx < 0 { + return 0 + } + + k, err := strconv.ParseUint(string(b[:idx]), 10, 64) + if err != nil { + return 0 + } + + return k +} diff --git a/go/appencryption/integrationtest/traces/storage_test.go b/go/appencryption/integrationtest/traces/storage_test.go new file mode 100644 index 000000000..17fbd683f --- /dev/null +++ b/go/appencryption/integrationtest/traces/storage_test.go @@ -0,0 +1,71 @@ +package traces + +import "testing" + +func TestRequestWebSearch(t *testing.T) { + for _, p := range policies { + p := p + t.Run(p, func(t *testing.T) { + t.Parallel() + opt := options{ + policy: p, + cacheSize: 256000, + reportInterval: 10000, + maxItems: 1000000, + } + testRequest(t, NewStorageProvider, opt, + "WebSearch*.spc.bz2", "request_websearch-"+p+".txt") + }) + } +} + +func TestRequestFinancial(t *testing.T) { + for _, p := range policies { + p := p + t.Run(p, func(t *testing.T) { + t.Parallel() + opt := options{ + policy: p, + cacheSize: 1000, + reportInterval: 30000, + maxItems: 3000000, + } + testRequest(t, NewStorageProvider, opt, + "Financial*.spc.bz2", "request_financial-"+p+".txt") + }) + } +} + +func TestSizeWebSearch(t *testing.T) { + for _, p := range policies { + p := p + opt := options{ + policy: p, + cacheSize: 25000, + maxItems: 1000000, + } + + t.Run(p, func(t *testing.T) { + t.Parallel() + testSize(t, NewStorageProvider, opt, + "WebSearch*.spc.bz2", "size_websearch-"+p+".txt") + }) + } +} + +func TestSizeFinancial(t *testing.T) { + for _, p := range policies { + p := p + opt := options{ + policy: p, + cacheSize: 250, + maxItems: 1000000, + } + + t.Run(p, func(t *testing.T) { + t.Parallel() + testSize(t, NewStorageProvider, opt, + "Financial*.spc.bz2", "size_financial-"+p+".txt") + }) + } +} diff --git
a/go/appencryption/integrationtest/traces/visualize-request.sh b/go/appencryption/integrationtest/traces/visualize-request.sh new file mode 100755 index 000000000..392edcaae --- /dev/null +++ b/go/appencryption/integrationtest/traces/visualize-request.sh @@ -0,0 +1,31 @@ +#!/bin/bash +if [ -z "$FORMAT" ]; then + #FORMAT='svg size 400,300 font "Helvetica,10"' + # FORMAT='png size 220,180 small noenhanced' + FORMAT='png size 400,300 small noenhanced' +fi +OUTPUT="out.${FORMAT%% *}" +PLOTARG="" + +for f in "$@"; do + if [ ! -z "$PLOTARG" ]; then + PLOTARG="$PLOTARG," + fi + NAME="$(basename "$f")" + NAME="${NAME%.*}" + NAME="${NAME#*_}" + PLOTARG="$PLOTARG '$f' every ::1 using 1:9 with lines title '$NAME'" +done + +ARG="set datafile separator ',';\ + set xlabel 'Requests';\ + set xtics rotate by 45 right;\ + set ylabel 'Op Rate' offset 1;\ + set yrange [0:];\ + set key bottom right;\ + set colors classic;\ + set terminal $FORMAT;\ + set output '$OUTPUT';\ + plot $PLOTARG" + +gnuplot -e "$ARG" diff --git a/go/appencryption/integrationtest/traces/visualize-size.sh b/go/appencryption/integrationtest/traces/visualize-size.sh new file mode 100755 index 000000000..2ca8dcd73 --- /dev/null +++ b/go/appencryption/integrationtest/traces/visualize-size.sh @@ -0,0 +1,31 @@ +#!/bin/bash +if [ -z "$FORMAT" ]; then + #FORMAT='svg size 400,300 font "Helvetica,10"' + # FORMAT='png size 220,180 small noenhanced' + FORMAT='png size 400,300 small noenhanced' +fi +OUTPUT="out.${FORMAT%% *}" +PLOTARG="" + +for f in "$@"; do + if [ ! -z "$PLOTARG" ]; then + PLOTARG="$PLOTARG," + fi + NAME="$(basename "$f")" # remove path + NAME="${NAME%.*}" # remove extension + NAME="${NAME#*_}" # remove prefix + PLOTARG="$PLOTARG '$f' every ::1 using 10:9:xtic(10) with lines title '$NAME'" # 10:9 is cache size:op rate +done + +ARG="set datafile separator ',';\ + set xlabel 'Cache Size';\ + set xtics rotate by 45 right;\ + set ylabel 'Op Rate' offset 1;\ + set yrange [0:];\ + set key bottom right;\ + set colors classic;\ + set terminal $FORMAT;\ + set output '$OUTPUT';\ + plot $PLOTARG" + +gnuplot -e "$ARG" diff --git a/go/appencryption/integrationtest/traces/wikipedia.go b/go/appencryption/integrationtest/traces/wikipedia.go new file mode 100644 index 000000000..f6a890447 --- /dev/null +++ b/go/appencryption/integrationtest/traces/wikipedia.go @@ -0,0 +1,62 @@ +package traces + +import ( + "bufio" + "bytes" + "context" + "io" +) + +type wikipediaProvider struct { + r *bufio.Reader +} + +func NewWikipediaProvider(r io.Reader) Provider { + return &wikipediaProvider{ + r: bufio.NewReader(r), + } +} + +func (p *wikipediaProvider) Provide(ctx context.Context, keys chan<- interface{}) { + defer close(keys) + + for { + b, err := p.r.ReadBytes('\n') + if err != nil { + return + } + + v := p.parse(b) + if v != "" { + select { + case <-ctx.Done(): + return + case keys <- v: + } + } + } +} + +func (p *wikipediaProvider) parse(b []byte) string { + // Get url + idx := bytes.Index(b, []byte("http://")) + if idx < 0 { + return "" + } + + b = b[idx+len("http://"):] + + // Get path + idx = bytes.IndexByte(b, '/') + if idx > 0 { + b = b[idx:] + } + + // Skip params + idx = bytes.IndexAny(b, "? 
") + if idx > 0 { + b = b[:idx] + } + + return string(b) +} diff --git a/go/appencryption/integrationtest/traces/wikipedia_test.go b/go/appencryption/integrationtest/traces/wikipedia_test.go new file mode 100644 index 000000000..a49c801c4 --- /dev/null +++ b/go/appencryption/integrationtest/traces/wikipedia_test.go @@ -0,0 +1,36 @@ +package traces + +import "testing" + +func TestRequestWikipedia(t *testing.T) { + for _, p := range policies { + p := p + t.Run(p, func(t *testing.T) { + t.Parallel() + opt := options{ + policy: p, + cacheSize: 512, + reportInterval: 10000, + maxItems: 1000000, + } + testRequest(t, NewWikipediaProvider, opt, + "wiki.*.gz", "request_wikipedia-"+p+".txt") + }) + } +} + +func TestSizeWikipedia(t *testing.T) { + for _, p := range policies { + p := p + t.Run(p, func(t *testing.T) { + t.Parallel() + opt := options{ + policy: p, + cacheSize: 250, + maxItems: 100000, + } + testSize(t, NewWikipediaProvider, opt, + "wiki.*.gz", "size_wikipedia-"+p+".txt") + }) + } +} diff --git a/go/appencryption/integrationtest/traces/youtube.go b/go/appencryption/integrationtest/traces/youtube.go new file mode 100644 index 000000000..72a1c975b --- /dev/null +++ b/go/appencryption/integrationtest/traces/youtube.go @@ -0,0 +1,55 @@ +package traces + +import ( + "bufio" + "bytes" + "context" + "io" +) + +type youtubeProvider struct { + r *bufio.Reader +} + +func NewYoutubeProvider(r io.Reader) Provider { + return &youtubeProvider{ + r: bufio.NewReader(r), + } +} + +func (p *youtubeProvider) Provide(ctx context.Context, keys chan<- interface{}) { + defer close(keys) + + for { + b, err := p.r.ReadBytes('\n') + if err != nil { + return + } + + v := p.parse(b) + if v != "" { + select { + case <-ctx.Done(): + return + case keys <- v: + } + } + } +} + +func (p *youtubeProvider) parse(b []byte) string { + // Get video id + idx := bytes.Index(b, []byte("GETVIDEO ")) + if idx < 0 { + return "" + } + + b = b[idx+len("GETVIDEO "):] + + idx = bytes.IndexAny(b, "& ") + if idx > 0 { + b = b[:idx] + } + + return string(b) +} diff --git a/go/appencryption/integrationtest/traces/youtube_test.go b/go/appencryption/integrationtest/traces/youtube_test.go new file mode 100644 index 000000000..16e8482f8 --- /dev/null +++ b/go/appencryption/integrationtest/traces/youtube_test.go @@ -0,0 +1,36 @@ +package traces + +import "testing" + +func TestRequestYouTube(t *testing.T) { + for _, p := range policies { + p := p + t.Run(p, func(t *testing.T) { + t.Parallel() + opt := options{ + policy: p, + cacheSize: 1000, + reportInterval: 2000, + maxItems: 200000, + } + testRequest(t, NewYoutubeProvider, opt, + "youtube.parsed.0803*.dat", "request_youtube-"+p+".txt") + }) + } +} + +func TestSizeYouTube(t *testing.T) { + for _, p := range policies { + p := p + t.Run(p, func(t *testing.T) { + t.Parallel() + opt := options{ + policy: p, + cacheSize: 250, + maxItems: 100000, + } + testSize(t, NewYoutubeProvider, opt, + "youtube.parsed.0803*.dat", "size_youtube-"+p+".txt") + }) + } +} diff --git a/go/appencryption/integrationtest/traces/zipf.go b/go/appencryption/integrationtest/traces/zipf.go new file mode 100644 index 000000000..b18dc9643 --- /dev/null +++ b/go/appencryption/integrationtest/traces/zipf.go @@ -0,0 +1,37 @@ +package traces + +import ( + "context" + "math/rand" +) + +type zipfProvider struct { + r *rand.Zipf + n int +} + +func NewZipfProvider(s float64, num int) Provider { + if s <= 1.0 || num <= 0 { + panic("invalid zipf parameters") + } + + r := rand.New(rand.NewSource(1)) + + return &zipfProvider{ + r: 
rand.NewZipf(r, s, 1.0, 1<<16-1), + n: num, + } +} + +func (p *zipfProvider) Provide(ctx context.Context, keys chan<- interface{}) { + defer close(keys) + + for i := 0; i < p.n; i++ { + v := p.r.Uint64() + select { + case <-ctx.Done(): + return + case keys <- v: + } + } +} diff --git a/go/appencryption/integrationtest/traces/zipf_test.go b/go/appencryption/integrationtest/traces/zipf_test.go new file mode 100644 index 000000000..6673072a4 --- /dev/null +++ b/go/appencryption/integrationtest/traces/zipf_test.go @@ -0,0 +1,70 @@ +package traces + +import ( + "os" + "path/filepath" + "testing" +) + +func TestRequestZipf(t *testing.T) { + for _, p := range policies { + p := p + t.Run(p, func(t *testing.T) { + t.Parallel() + testRequestZipf(t, p, "request_zipf-"+p+".txt") + }) + } +} + +func testRequestZipf(t *testing.T, policy, reportFile string) { + opt := options{ + policy: policy, + cacheSize: 1000, + reportInterval: 1000, + maxItems: 100000, + } + + provider := NewZipfProvider(1.01, opt.maxItems) + + w, err := os.Create(filepath.Join("out", reportFile)) + if err != nil { + t.Fatal(err) + } + defer w.Close() + reporter := NewReporter(w) + // benchmarkCache(provider, reporter, opt) + benchmarkSessionFactory(provider, reporter, opt) +} + +func TestSizeZipf(t *testing.T) { + for _, p := range policies { + p := p + t.Run(p, func(t *testing.T) { + t.Parallel() + testSizeZipf(t, p, "size_zipf-"+p+".txt") + }) + } +} + +func testSizeZipf(t *testing.T, policy, reportFile string) { + opt := options{ + cacheSize: 250, + policy: policy, + maxItems: 100000, + } + + w, err := os.Create(filepath.Join("out", reportFile)) + if err != nil { + t.Fatal(err) + } + defer w.Close() + + reporter := NewReporter(w) + + for i := 0; i < 5; i++ { + provider := NewZipfProvider(1.01, opt.maxItems) + // benchmarkCache(provider, reporter, opt) + benchmarkSessionFactory(provider, reporter, opt) + opt.cacheSize += opt.cacheSize + } +} diff --git a/go/appencryption/internal/key.go b/go/appencryption/internal/key.go index e6ec7f506..957e62dbd 100644 --- a/go/appencryption/internal/key.go +++ b/go/appencryption/internal/key.go @@ -4,6 +4,7 @@ import ( "fmt" "sync" "sync/atomic" + "time" "github.com/godaddy/asherah/go/securememory" ) @@ -43,6 +44,11 @@ func (k *CryptoKey) Close() { // Close destroys the underlying buffer for this key. func (k *CryptoKey) close() { + // k.secret is nil when the key is created for test. + if k.secret == nil { + return + } + k.secret.Close() } @@ -55,6 +61,16 @@ func (k *CryptoKey) String() string { return fmt.Sprintf("CryptoKey(%p){secret(%p)}", k, k.secret) } +// WithBytes implements BytesAccessor. +func (k *CryptoKey) WithBytes(action func([]byte) error) error { + return k.secret.WithBytes(action) +} + +// WithBytesFunc implements BytesFuncAccessor. +func (k *CryptoKey) WithBytesFunc(action func([]byte) ([]byte, error)) ([]byte, error) { + return k.secret.WithBytesFunc(action) +} + // NewCryptoKey creates a CryptoKey using the given key. Note that the underlying array will be wiped after the function // exits. func NewCryptoKey(factory securememory.SecretFactory, created int64, revoked bool, key []byte) (*CryptoKey, error) { @@ -104,16 +120,43 @@ func GenerateKey(factory securememory.SecretFactory, created int64, size int) (* }, nil } -// WithKey takes in a CryptoKey, makes the underlying bytes readable, and passes them to the function provided. -// A reference MUST not be stored to the provided bytes. The underlying array will be wiped after the function -// exits. 
-func WithKey(key *CryptoKey, action func([]byte) error) error { - return key.secret.WithBytes(action) +type BytesAccessor interface { + WithBytes(action func([]byte) error) error } -// WithKeyFunc takes in a CryptoKey, makes the underlying bytes readable, and passes them to the function provided. -// A reference MUST not be stored to the provided bytes. The underlying array will be wiped after the function -// exits. -func WithKeyFunc(key *CryptoKey, action func([]byte) ([]byte, error)) ([]byte, error) { - return key.secret.WithBytesFunc(action) +// WithKey takes in BytesAccessor, e.g., a CryptoKey, makes the underlying bytes readable, and passes them to the +// function provided. A reference MUST not be stored to the provided bytes. The underlying array will be wiped after +// the function exits. +func WithKey(key BytesAccessor, action func([]byte) error) error { + return key.WithBytes(action) +} + +type BytesFuncAccessor interface { + WithBytesFunc(action func([]byte) ([]byte, error)) ([]byte, error) +} + +// WithKeyFunc takes in a BytesFuncAccessor, e.g., a CryptoKey, makes the underlying bytes readable, and passes them to +// the function provided. A reference MUST not be stored to the provided bytes. The underlying array will be wiped after +// the function exits. +func WithKeyFunc(key BytesFuncAccessor, action func([]byte) ([]byte, error)) ([]byte, error) { + return key.WithBytesFunc(action) +} + +type Revokable interface { + // Revoked returns true if the key is revoked. + Revoked() bool + + // Created returns the time the CryptoKey was created as a Unix epoch in seconds. + Created() int64 +} + +// IsKeyInvalid checks if the key is revoked or expired. +func IsKeyInvalid(key Revokable, expireAfter time.Duration) bool { + return key.Revoked() || IsKeyExpired(key.Created(), expireAfter) +} + +// IsKeyExpired checks if the key's created timestamp is older than the +// allowed duration. +func IsKeyExpired(created int64, expireAfter time.Duration) bool { + return time.Now().After(time.Unix(created, 0).Add(expireAfter)) } diff --git a/go/appencryption/key_cache.go b/go/appencryption/key_cache.go new file mode 100644 index 000000000..0a65184b2 --- /dev/null +++ b/go/appencryption/key_cache.go @@ -0,0 +1,387 @@ +package appencryption + +import ( + "fmt" + "sync" + "time" + + "github.com/godaddy/asherah/go/appencryption/internal" + "github.com/godaddy/asherah/go/appencryption/pkg/cache" + "github.com/godaddy/asherah/go/appencryption/pkg/log" +) + +// cachedCryptoKey is a wrapper around a CryptoKey that tracks concurrent access. +type cachedCryptoKey struct { + *internal.CryptoKey + + rw sync.RWMutex // protects concurrent access to the key's reference count + refs int // number of references to this key +} + +// Close decrements the reference count for this key. If the reference count +// reaches zero, the underlying key is closed. +func (c *cachedCryptoKey) Close() { + c.rw.Lock() + defer c.rw.Unlock() + + c.refs-- + + if c.refs > 0 { + return + } + + log.Debugf("closing cached key: %s, refs=%d", c.CryptoKey, c.refs) + c.CryptoKey.Close() +} + +// increment the reference count for this key. +func (c *cachedCryptoKey) increment() { + c.rw.Lock() + defer c.rw.Unlock() + + c.refs++ +} + +// cacheEntry contains a key and the time it was loaded from the metastore. +type cacheEntry struct { + loadedAt time.Time + key *cachedCryptoKey +} + +// newCacheEntry returns a cacheEntry with the current time and key. 
+func newCacheEntry(k *internal.CryptoKey) cacheEntry { + return cacheEntry{ + loadedAt: time.Now(), + key: &cachedCryptoKey{ + CryptoKey: k, + + // initialize with a reference count of 1 to represent the + // reference held by the cache + refs: 1, + }, + } +} + +// cacheKey formats an id and create timestamp to a usable +// key for storage in a cache. +func cacheKey(id string, create int64) string { + return fmt.Sprintf("%s-%d", id, create) +} + +// keyCacher contains cached keys for reuse. +type keyCacher interface { + GetOrLoad(id KeyMeta, loader func(KeyMeta) (*internal.CryptoKey, error)) (*cachedCryptoKey, error) + GetOrLoadLatest(id string, loader func(KeyMeta) (*internal.CryptoKey, error)) (*cachedCryptoKey, error) + Close() error +} + +// Verify keyCache implements the cache interface. +var _ keyCacher = (*keyCache)(nil) + +// keyCache is used to persist session based keys and destroys them on a call to close. +type keyCache struct { + policy *CryptoPolicy + + keys cache.Interface[string, cacheEntry] + rw sync.RWMutex // protects concurrent access to the cache + + latest map[string]KeyMeta + + cacheType cacheKeyType +} + +// cacheKeyType is used to identify the type of key cache. +type cacheKeyType int + +// String returns a string representation of the cacheKeyType. +func (t cacheKeyType) String() string { + switch t { + case CacheTypeSystemKeys: + return "system" + case CacheTypeIntermediateKeys: + return "intermediate" + default: + return "unknown" + } +} + +const ( + // CacheTypeSystemKeys is used to cache system keys. + CacheTypeSystemKeys cacheKeyType = iota + // CacheTypeIntermediateKeys is used to cache intermediate keys. + CacheTypeIntermediateKeys +) + +// newKeyCache constructs a cache object that is ready to use. +func newKeyCache(t cacheKeyType, policy *CryptoPolicy) (c *keyCache) { + cacheMaxSize := DefaultKeyCacheMaxSize + cachePolicy := "" + + switch t { + case CacheTypeSystemKeys: + cacheMaxSize = policy.SystemKeyCacheMaxSize + cachePolicy = policy.SystemKeyCacheEvictionPolicy + case CacheTypeIntermediateKeys: + cacheMaxSize = policy.IntermediateKeyCacheMaxSize + cachePolicy = policy.IntermediateKeyCacheEvictionPolicy + } + + c = &keyCache{ + policy: policy, + latest: make(map[string]KeyMeta), + + cacheType: t, + } + + onEvict := func(key string, value cacheEntry) { + log.Debugf("[onEvict] closing key -- id: %s\n", key) + + value.key.Close() + } + + cb := cache.New[string, cacheEntry](cacheMaxSize) + + if cachePolicy != "" { + log.Debugf("setting cache policy to %s", cachePolicy) + + cb.WithPolicy(cache.CachePolicy(cachePolicy)) + } + + if cacheMaxSize < 100 { + log.Debugf("cache size is less than 100, setting synchronous eviction policy") + + cb.Synchronous() + } + + c.keys = cb.WithEvictFunc(onEvict).Build() + + return c +} + +// isReloadRequired returns true if the check interval has elapsed +// since the timestamp provided. +func isReloadRequired(entry cacheEntry, checkInterval time.Duration) bool { + if entry.key.Revoked() { + // this key is revoked so no need to reload it again. + return false + } + + return entry.loadedAt.Add(checkInterval).Before(time.Now()) +} + +// GetOrLoad returns a key from the cache if it's already been loaded. If the key +// is not present in the cache it will retrieve the key using the provided loader +// and store the key if an error is not returned. 
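+// The returned key's reference count is incremented before it is handed back, so +// callers must call Close on it when finished; the underlying key material is only +// destroyed once the cache and all other holders have released their references.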
+func (c *keyCache) GetOrLoad(id KeyMeta, loader func(KeyMeta) (*internal.CryptoKey, error)) (*cachedCryptoKey, error) { + c.rw.Lock() + defer c.rw.Unlock() + + if k, ok := c.getFresh(id); ok { + return tracked(k), nil + } + + k, err := c.load(id, loader) + if err != nil { + return nil, err + } + + return tracked(k), nil +} + +// tracked increments the reference count for the provided key, then returns it. +func tracked(key *cachedCryptoKey) *cachedCryptoKey { + key.increment() + return key +} + +// getFresh returns a key from the cache if present AND fresh. +// A cached value is considered stale if its time in cache +// has exceeded the RevokeCheckInterval. +// The second return value indicates the successful retrieval of a +// fresh key. +func (c *keyCache) getFresh(meta KeyMeta) (*cachedCryptoKey, bool) { + if e, ok := c.read(meta); ok && !isReloadRequired(e, c.policy.RevokeCheckInterval) { + return e.key, true + } else if ok { + log.Debugf("%s stale -- id: %s-%d\n", c, meta.ID, e.key.Created()) + return e.key, false + } + + return nil, false +} + +// load retrieves a key using the provided loader. If the key is present in the cache +// it will be updated with the latest revocation status and last loaded time. Otherwise +// a new cache entry will be created and stored in the cache. +// +// load maintains the latest entry for each distinct KeyMeta.ID which can be accessed using +// KeyMeta.Created == 0. +func (c *keyCache) load(meta KeyMeta, loader func(KeyMeta) (*internal.CryptoKey, error)) (*cachedCryptoKey, error) { + k, err := loader(meta) + if err != nil { + return nil, err + } + + e, ok := c.read(meta) + + switch { + case ok: + // existing key in cache. update revoked status and last loaded time and close key + // we just loaded since we don't need it + e.key.SetRevoked(k.Revoked()) + e.loadedAt = time.Now() + + k.Close() + default: + // first time loading this key into cache or we have an ID-only key with mismatched + // create timestamps + e = newCacheEntry(k) + } + + c.write(meta, e) + + return e.key, nil +} + +// read retrieves the entry from the cache matching the provided ID if present. The second +// return value indicates whether or not the key was present in the cache. +func (c *keyCache) read(meta KeyMeta) (cacheEntry, bool) { + id := cacheKey(meta.ID, meta.Created) + + if meta.IsLatest() { + if latest, ok := c.getLatestKeyMeta(meta.ID); ok { + id = cacheKey(latest.ID, latest.Created) + } + } + + e, ok := c.keys.Get(id) + if !ok { + log.Debugf("%s miss -- id: %s\n", c, id) + } + + return e, ok +} + +// getLatestKeyMeta returns the KeyMeta for the latest key for the provided ID. +// The second return value indicates whether or not the key was present in the cache. +func (c *keyCache) getLatestKeyMeta(id string) (KeyMeta, bool) { + latest, ok := c.latest[cacheKey(id, 0)] + + return latest, ok +} + +// mapLatestKeyMeta maps the provided latest KeyMeta to the provided ID. +func (c *keyCache) mapLatestKeyMeta(id string, latest KeyMeta) { + c.latest[cacheKey(id, 0)] = latest +} + +// write entry e to the cache using id as the key. 
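+// If meta refers to the latest key (Created == 0), or is newer than the currently +// mapped latest entry, the latest-key mapping for meta.ID is updated as well.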
+func (c *keyCache) write(meta KeyMeta, e cacheEntry) { + if meta.IsLatest() { + meta = KeyMeta{ID: meta.ID, Created: e.key.Created()} + + c.mapLatestKeyMeta(meta.ID, meta) + } else if latest, ok := c.getLatestKeyMeta(meta.ID); !ok || latest.Created < e.key.Created() { + c.mapLatestKeyMeta(meta.ID, meta) + } + + id := cacheKey(meta.ID, meta.Created) + + if existing, ok := c.keys.Get(id); ok { + log.Debugf("%s update -> old: %s, new: %s, id: %s\n", c, existing.key, e.key, id) + } + + log.Debugf("%s write -> key: %s, id: %s\n", c, e.key, id) + c.keys.Set(id, e) +} + +// GetOrLoadLatest returns the latest key from the cache matching the provided ID +// if it's already been loaded. If the key is not present in the cache it will +// retrieve the key using the provided KeyLoader and store the key if successful. +// In the event that the cached or loaded key is invalid (see [keyCache.IsInvalid]), +// the key will be reloaded and the cache updated. +func (c *keyCache) GetOrLoadLatest(id string, loader func(KeyMeta) (*internal.CryptoKey, error)) (*cachedCryptoKey, error) { + c.rw.Lock() + defer c.rw.Unlock() + + meta := KeyMeta{ID: id} + + key, ok := c.getFresh(meta) + if !ok { + log.Debugf("%s.GetOrLoadLatest get miss -- id: %s\n", c, id) + + var err error + key, err = c.load(meta, loader) + + if err != nil { + return nil, err + } + } + + if c.IsInvalid(key.CryptoKey) { + reloaded, err := loader(meta) + if err != nil { + return nil, err + } + + log.Debugf("%s.GetOrLoadLatest reload -- invalid: %s, new: %s, id: %s\n", c, key, reloaded, id) + + e := newCacheEntry(reloaded) + + // ensure we've got a cache entry with a fully qualified cache key + c.write(KeyMeta{ID: id, Created: reloaded.Created()}, e) + + return tracked(e.key), nil + } + + return tracked(key), nil +} + +// IsInvalid returns true if the provided key is no longer valid. +func (c *keyCache) IsInvalid(key *internal.CryptoKey) bool { + return internal.IsKeyInvalid(key, c.policy.ExpireKeyAfter) +} + +// Close frees all memory locked by the keys in this cache. +// It MUST be called after a session is complete to avoid +// running into MEMLOCK limits. +func (c *keyCache) Close() error { + log.Debugf("%s closing\n", c) + + return c.keys.Close() +} + +// String returns a string representation of this cache. +func (c *keyCache) String() string { + return fmt.Sprintf("keyCache(%p){type=%s,size=%d,cap=%d}", c, c.cacheType, c.keys.Len(), c.keys.Capacity()) +} + +// Verify neverCache implements the cache interface. +var _ keyCacher = (*neverCache)(nil) + +type neverCache struct{} + +// GetOrLoad always executes the provided function to load the value. It never actually caches. +func (neverCache) GetOrLoad(id KeyMeta, loader func(KeyMeta) (*internal.CryptoKey, error)) (*cachedCryptoKey, error) { + k, err := loader(id) + if err != nil { + return nil, err + } + + return &cachedCryptoKey{CryptoKey: k}, nil +} + +// GetOrLoadLatest always executes the provided function to load the latest value. It never actually caches. +func (neverCache) GetOrLoadLatest(id string, loader func(KeyMeta) (*internal.CryptoKey, error)) (*cachedCryptoKey, error) { + k, err := loader(KeyMeta{ID: id}) + if err != nil { + return nil, err + } + + return &cachedCryptoKey{CryptoKey: k}, nil +} + +// Close is a no-op function to satisfy the cache interface. 
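+// neverCache holds no references of its own; keys returned by GetOrLoad and +// GetOrLoadLatest are owned entirely by the caller, which closes them directly.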
+func (neverCache) Close() error { + return nil +} diff --git a/go/appencryption/key_cache_benchmark_test.go b/go/appencryption/key_cache_benchmark_test.go new file mode 100644 index 000000000..434e32b04 --- /dev/null +++ b/go/appencryption/key_cache_benchmark_test.go @@ -0,0 +1,415 @@ +package appencryption + +import ( + "flag" + "fmt" + "sync/atomic" + "testing" + "time" + + "github.com/godaddy/asherah/go/securememory/memguard" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + + "github.com/godaddy/asherah/go/appencryption/internal" + "github.com/godaddy/asherah/go/appencryption/pkg/log" +) + +var ( + secretFactory = new(memguard.SecretFactory) + created = time.Now().Unix() + enableDebug = flag.Bool("debug", false, "enable debug logging") +) + +func ConfigureLogging() { + if *enableDebug { + log.SetLogger(logger{}) + } +} + +func BenchmarkKeyCache_GetOrLoad_MultipleThreadsReadExistingKey(b *testing.B) { + ConfigureLogging() + + c := newKeyCache(CacheTypeIntermediateKeys, NewCryptoPolicy()) + + c.keys.Set(cacheKey(testKey, created), cacheEntry{ + key: &cachedCryptoKey{CryptoKey: internal.NewCryptoKeyForTest(created, false)}, + loadedAt: time.Now(), + }) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + key, err := c.GetOrLoad(KeyMeta{testKey, created}, func(_ KeyMeta) (key *internal.CryptoKey, e error) { + // The passed function is irrelevant because we'll always find the value in the cache + return nil, errors.New("loader should not be executed") + }) + + assert.NoError(b, err) + assert.Equal(b, created, key.Created()) + } + }) +} + +func BenchmarkKeyCache_GetOrLoad_MultipleThreadsWriteSameKey(b *testing.B) { + ConfigureLogging() + + c := newKeyCache(CacheTypeIntermediateKeys, NewCryptoPolicy()) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := c.GetOrLoad(KeyMeta{testKey, created}, func(_ KeyMeta) (key *internal.CryptoKey, e error) { + // Add a delay to simulate time spent in performing a metastore read + time.Sleep(5 * time.Millisecond) + return internal.NewCryptoKeyForTest(created, false), nil + }) + + assert.NoError(b, err) + + latest, _ := c.getLatestKeyMeta(testKey) + latestKey := cacheKey(latest.ID, latest.Created) + + assert.Equal(b, created, c.keys.GetOrPanic(latestKey).key.Created()) + } + }) +} + +type logger struct{} + +func (logger) Debugf(format string, v ...interface{}) { + fmt.Printf(format, v...) 
+} + +func BenchmarkKeyCache_GetOrLoad_MultipleThreadsWriteUniqueKeys(b *testing.B) { + ConfigureLogging() + + var ( + c = newKeyCache(CacheTypeIntermediateKeys, NewCryptoPolicy()) + i int64 + ) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + curr := atomic.AddInt64(&i, 1) - 1 + + loader := func(_ KeyMeta) (key *internal.CryptoKey, e error) { + // Add a delay to simulate time spent in performing a metastore read + return internal.NewCryptoKeyForTest(created, false), nil + } + + keyID := fmt.Sprintf("%s-%d", testKey, curr) + + _, err := c.GetOrLoad(KeyMeta{keyID, created}, loader) + assert.NoError(b, err) + + // ensure we have a "latest" entry for this key as well + latest, err := c.GetOrLoadLatest(keyID, loader) + assert.NoError(b, err) + assert.NotNil(b, latest) + } + }) + assert.NotNil(b, c.keys) + + expected := i + if expected > DefaultKeyCacheMaxSize { + expected = DefaultKeyCacheMaxSize + } + + assert.Equal(b, expected, int64(c.keys.Len())) +} + +func BenchmarkKeyCache_GetOrLoad_MultipleThreadsReadRevokedKey(b *testing.B) { + var ( + c = newKeyCache(CacheTypeIntermediateKeys, NewCryptoPolicy()) + created = time.Now().Add(-(time.Minute * 100)).Unix() + ) + + key, err := internal.NewCryptoKey(secretFactory, created, false, []byte("testing")) + + assert.NoError(b, err) + + cacheEntry := cacheEntry{ + key: &cachedCryptoKey{CryptoKey: key}, + loadedAt: time.Unix(created, 0), + } + + defer c.Close() + c.keys.Set(cacheKey(testKey, created), cacheEntry) + c.mapLatestKeyMeta(testKey, KeyMeta{testKey, created}) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := c.GetOrLoad(KeyMeta{testKey, created}, func(_ KeyMeta) (key *internal.CryptoKey, e error) { + return internal.NewCryptoKey(secretFactory, created, true, []byte("testing")) + }) + + assert.NoError(b, err) + + latest, _ := c.getLatestKeyMeta(testKey) + latestKey := cacheKey(latest.ID, latest.Created) + assert.Equal(b, created, c.keys.GetOrPanic(latestKey).key.Created()) + assert.True(b, c.keys.GetOrPanic(latestKey).key.Revoked()) + assert.True(b, c.keys.GetOrPanic(cacheKey(testKey, created)).key.Revoked()) + } + }) +} + +func BenchmarkKeyCache_GetOrLoad_MultipleThreadsRead_NeedReloadKey(b *testing.B) { + var ( + c = newKeyCache(CacheTypeIntermediateKeys, NewCryptoPolicy()) + created = time.Now().Add(-(time.Minute * 100)).Unix() + ) + + key, err := internal.NewCryptoKey(secretFactory, created, false, []byte("testing")) + + assert.NoError(b, err) + + cacheEntry := cacheEntry{ + key: &cachedCryptoKey{CryptoKey: key}, + loadedAt: time.Unix(created, 0), + } + + defer c.Close() + + c.keys.Set(cacheKey(testKey, created), cacheEntry) + c.mapLatestKeyMeta(testKey, KeyMeta{testKey, created}) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + k, err := c.GetOrLoad(KeyMeta{testKey, created}, func(_ KeyMeta) (*internal.CryptoKey, error) { + // Note: this function should only happen on first load (although could execute more than once currently), if it doesn't, then something is broken + return internal.NewCryptoKey(secretFactory, created, false, []byte("testing")) + }) + + if err != nil { + b.Error(err) + } + if created != k.Created() { + b.Error("created mismatch") + } + } + }) +} + +func BenchmarkKeyCache_GetOrLoad_MultipleThreadsReadUniqueKeys(b *testing.B) { + c := newKeyCache(CacheTypeIntermediateKeys, NewCryptoPolicy()) + + for i := 0; i < b.N && i < DefaultKeyCacheMaxSize; i++ { + keyID := fmt.Sprintf(testKey+"-%d", i) + meta := KeyMeta{ID: keyID, 
Created: created} + + c.mapLatestKeyMeta(meta.ID, meta) + c.keys.Set(cacheKey(meta.ID, meta.Created), cacheEntry{ + key: &cachedCryptoKey{CryptoKey: internal.NewCryptoKeyForTest(created, false), refs: 1}, + loadedAt: time.Now(), + }) + } + + i := atomic.Int64{} + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + curr := i.Add(1) - 1 + curr = curr % DefaultKeyCacheMaxSize + + id := fmt.Sprintf(testKey+"-%d", curr) + key, err := c.GetOrLoad(KeyMeta{id, created}, func(_ KeyMeta) (key *internal.CryptoKey, e error) { + // The passed function is irrelevant because we'll always find the value in the cache + return nil, errors.New(fmt.Sprintf("loader should not be executed for id=%s", id)) + }) + assert.NoError(b, err) + assert.Equal(b, created, key.Created()) + } + }) +} + +func BenchmarkKeyCache_GetOrLoadLatest_MultipleThreadsReadExistingKey(b *testing.B) { + c := newKeyCache(CacheTypeIntermediateKeys, NewCryptoPolicy()) + + c.mapLatestKeyMeta(testKey, KeyMeta{testKey, created}) + c.keys.Set(cacheKey(testKey, created), cacheEntry{ + key: &cachedCryptoKey{CryptoKey: internal.NewCryptoKeyForTest(created, false)}, + loadedAt: time.Now(), + }) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + key, err := c.GetOrLoadLatest(testKey, func(_ KeyMeta) (*internal.CryptoKey, error) { + // The passed function is irrelevant because we'll always find the value in the cache + return nil, nil + }) + assert.NoError(b, err) + assert.Equal(b, created, key.Created()) + } + }) +} + +func BenchmarkKeyCache_GetOrLoadLatest_MultipleThreadsWriteSameKey(b *testing.B) { + c := newKeyCache(CacheTypeIntermediateKeys, NewCryptoPolicy()) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := c.GetOrLoadLatest(testKey, func(_ KeyMeta) (*internal.CryptoKey, error) { + // Add a delay to simulate time spent in performing a metastore read + time.Sleep(5 * time.Millisecond) + return internal.NewCryptoKeyForTest(created, false), nil + }) + assert.NoError(b, err) + + latest, _ := c.getLatestKeyMeta(testKey) + latestKey := cacheKey(latest.ID, latest.Created) + assert.Equal(b, created, c.keys.GetOrPanic(latestKey).key.Created()) + } + }) +} + +func BenchmarkKeyCache_GetOrLoadLatest_MultipleThreadsWriteUniqueKey(b *testing.B) { + var ( + c = newKeyCache(CacheTypeIntermediateKeys, NewCryptoPolicy()) + i int64 + ) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + curr := atomic.AddInt64(&i, 1) - 1 + _, err := c.GetOrLoadLatest(cacheKey(testKey, curr), func(_ KeyMeta) (*internal.CryptoKey, error) { + return internal.NewCryptoKeyForTest(created, false), nil + }) + assert.NoError(b, err) + } + }) + assert.NotNil(b, c.keys) + + expected := i + if expected > DefaultKeyCacheMaxSize { + expected = DefaultKeyCacheMaxSize + } + + assert.Equal(b, expected, int64(c.keys.Len())) +} + +func BenchmarkKeyCache_GetOrLoadLatest_MultipleThreadsReadStaleRevokedKey(b *testing.B) { + ConfigureLogging() + + var ( + c = newKeyCache(CacheTypeIntermediateKeys, NewCryptoPolicy()) + created = time.Now().Add(-(time.Minute * 100)).Unix() + ) + + key, err := internal.NewCryptoKey(secretFactory, created, false, []byte("testing")) + cacheEntry := cacheEntry{ + key: &cachedCryptoKey{CryptoKey: key, refs: 1}, + loadedAt: time.Unix(created, 0), + } + + assert.NoError(b, err) + + defer c.Close() + + meta := KeyMeta{ID: testKey, Created: created} + c.mapLatestKeyMeta(testKey, meta) + c.keys.Set(cacheKey(meta.ID, meta.Created), cacheEntry) + + b.ResetTimer() + 
b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + key, err := c.GetOrLoadLatest(testKey, func(_ KeyMeta) (key *internal.CryptoKey, e error) { + return internal.NewCryptoKey(secretFactory, time.Now().Unix(), true, []byte("testing")) + }) + + assert.NoError(b, err) + assert.True(b, key.Revoked()) + assert.Greater(b, key.Created(), created) + } + }) +} + +func BenchmarkKeyCache_GetOrLoadLatest_MultipleThreadsReadRevokedKey(b *testing.B) { + ConfigureLogging() + + var ( + c = newKeyCache(CacheTypeIntermediateKeys, NewCryptoPolicy()) + created = time.Now().Unix() + ) + + key, err := internal.NewCryptoKey(secretFactory, created, true, []byte("testing")) + cacheEntry := cacheEntry{ + key: &cachedCryptoKey{CryptoKey: key, refs: 1}, + loadedAt: time.Unix(created, 0), + } + + assert.NoError(b, err) + + defer c.Close() + + meta := KeyMeta{ID: testKey, Created: created} + c.mapLatestKeyMeta(testKey, meta) + c.keys.Set(cacheKey(meta.ID, meta.Created), cacheEntry) + + count := atomic.Int64{} + reloadCount := atomic.Int64{} + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + count.Add(1) + + key, err := c.GetOrLoadLatest(testKey, func(_ KeyMeta) (key *internal.CryptoKey, e error) { + reloadCount.Add(1) + + return internal.NewCryptoKey(secretFactory, time.Now().Unix(), false, []byte("testing")) + }) + + assert.NoError(b, err) + assert.False(b, key.Revoked()) + } + }) +} + +func BenchmarkKeyCache_GetOrLoadLatest_MultipleThreadsReadUniqueKeys(b *testing.B) { + ConfigureLogging() + + c := newKeyCache(CacheTypeIntermediateKeys, NewCryptoPolicy()) + + for i := 0; i < b.N && i < DefaultKeyCacheMaxSize; i++ { + keyID := fmt.Sprintf(testKey+"-%d", i) + meta := KeyMeta{ID: keyID, Created: created} + c.mapLatestKeyMeta(keyID, meta) + c.keys.Set(cacheKey(meta.ID, meta.Created), cacheEntry{ + key: &cachedCryptoKey{CryptoKey: internal.NewCryptoKeyForTest(created, false)}, + loadedAt: time.Now(), + }) + } + + i := atomic.Int64{} + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + curr := i.Add(1) - 1 + curr = curr % DefaultKeyCacheMaxSize + + keyID := fmt.Sprintf(testKey+"-%d", curr) + + key, err := c.GetOrLoadLatest(keyID, func(_ KeyMeta) (key *internal.CryptoKey, e error) { + // The passed function is irrelevant because we'll always find the value in the cache + return nil, errors.New(fmt.Sprintf("loader should not be executed for id=%s", keyID)) + }) + if err != nil { + b.Error(err) + } + + assert.Equal(b, created, key.Created()) + + key.Close() + } + }) +} diff --git a/go/appencryption/cache_test.go b/go/appencryption/key_cache_test.go similarity index 56% rename from go/appencryption/cache_test.go rename to go/appencryption/key_cache_test.go index 875801aff..2a86dcf78 100644 --- a/go/appencryption/cache_test.go +++ b/go/appencryption/key_cache_test.go @@ -10,7 +10,6 @@ import ( "github.com/pkg/errors" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -30,7 +29,7 @@ type CacheTestSuite struct { func (suite *CacheTestSuite) SetupTest() { suite.policy = NewCryptoPolicy() - suite.keyCache = newKeyCache(suite.policy) + suite.keyCache = newKeyCache(CacheTypeIntermediateKeys, suite.policy) suite.created = time.Now().Unix() } @@ -46,13 +45,14 @@ func (suite *CacheTestSuite) Test_CacheKey() { } func (suite *CacheTestSuite) Test_NewKeyCache() { - cache := newKeyCache(NewCryptoPolicy()) + cache := newKeyCache(CacheTypeIntermediateKeys, NewCryptoPolicy()) defer 
cache.Close() assert.NotNil(suite.T(), cache) assert.IsType(suite.T(), new(keyCache), cache) assert.NotNil(suite.T(), cache.keys) assert.NotNil(suite.T(), cache.policy) + assert.Equal(suite.T(), DefaultKeyCacheMaxSize, cache.keys.Capacity()) } func (suite *CacheTestSuite) Test_IsReloadRequired_WithIntervalNotElapsed() { @@ -60,7 +60,7 @@ func (suite *CacheTestSuite) Test_IsReloadRequired_WithIntervalNotElapsed() { if assert.NoError(suite.T(), err) { entry := cacheEntry{ loadedAt: time.Now(), - key: key, + key: &cachedCryptoKey{CryptoKey: key}, } defer key.Close() @@ -74,7 +74,7 @@ func (suite *CacheTestSuite) Test_IsReloadRequired_WithIntervalElapsed() { if assert.NoError(suite.T(), err) { entry := cacheEntry{ loadedAt: time.Now().Add(-2 * time.Hour), - key: key, + key: &cachedCryptoKey{CryptoKey: key}, } defer key.Close() @@ -89,7 +89,7 @@ func (suite *CacheTestSuite) Test_IsReloadRequired_WithRevoked() { entry := cacheEntry{ // Note this loadedAt would normally require reload loadedAt: time.Now().Add(-2 * time.Hour), - key: key, + key: &cachedCryptoKey{CryptoKey: key}, } defer key.Close() @@ -99,16 +99,16 @@ func (suite *CacheTestSuite) Test_IsReloadRequired_WithRevoked() { } func (suite *CacheTestSuite) TestKeyCache_GetOrLoad_WithCachedKeyNoReloadRequired() { - _, err := suite.keyCache.GetOrLoad(KeyMeta{testKey, suite.created}, keyLoaderFunc(func() (key *internal.CryptoKey, e error) { + _, err := suite.keyCache.GetOrLoad(KeyMeta{testKey, suite.created}, func(_ KeyMeta) (key *internal.CryptoKey, e error) { cryptoKey, err := internal.NewCryptoKey(secretFactory, suite.created, false, []byte("blah")) return cryptoKey, err - })) + }) assert.NoError(suite.T(), err) - key, err := suite.keyCache.GetOrLoad(KeyMeta{testKey, suite.created}, keyLoaderFunc(func() (*internal.CryptoKey, error) { + key, err := suite.keyCache.GetOrLoad(KeyMeta{testKey, suite.created}, func(_ KeyMeta) (*internal.CryptoKey, error) { return nil, errors.New("should not be called") - })) + }) assert.NoError(suite.T(), err) assert.NotNil(suite.T(), key) @@ -116,84 +116,86 @@ func (suite *CacheTestSuite) TestKeyCache_GetOrLoad_WithCachedKeyNoReloadRequire } func (suite *CacheTestSuite) TestKeyCache_GetOrLoad_WithEmptyCache() { - key, err := suite.keyCache.GetOrLoad(KeyMeta{testKey, suite.created}, keyLoaderFunc(func() (*internal.CryptoKey, error) { - cryptoKey, err := internal.NewCryptoKey(secretFactory, suite.created, false, []byte("blah")) - if err != nil { - return nil, err - } - return cryptoKey, nil - })) + meta := KeyMeta{ID: testKey, Created: suite.created} + key, err := suite.keyCache.GetOrLoad(meta, func(_ KeyMeta) (*internal.CryptoKey, error) { + return internal.NewCryptoKey(secretFactory, suite.created, false, []byte("blah")) + }) assert.NoError(suite.T(), err) assert.NotNil(suite.T(), key) assert.Equal(suite.T(), suite.created, key.Created()) - assert.Equal(suite.T(), suite.created, suite.keyCache.keys[cacheKey(testKey, 0)].key.Created()) + + latestKey, _ := suite.keyCache.getLatestKeyMeta(testKey) + assert.Equal(suite.T(), latestKey, meta) } func (suite *CacheTestSuite) TestKeyCache_GetOrLoad_DoesNotSetKeyOnError() { - key, err := suite.keyCache.GetOrLoad(KeyMeta{}, keyLoaderFunc(func() (*internal.CryptoKey, error) { + key, err := suite.keyCache.GetOrLoad(KeyMeta{}, func(_ KeyMeta) (*internal.CryptoKey, error) { return new(internal.CryptoKey), errors.New("error") - })) + }) if assert.Error(suite.T(), err) { assert.Nil(suite.T(), key) - assert.Empty(suite.T(), suite.keyCache.keys) + assert.Zero(suite.T(), 
suite.keyCache.keys.Len()) } } func (suite *CacheTestSuite) TestKeyCache_GetOrLoad_WithOldCachedKeyLoadNewerUpdatesLatest() { olderCreated := time.Now().Add(-(time.Hour * 24)).Unix() - _, err := suite.keyCache.GetOrLoad(KeyMeta{testKey, olderCreated}, keyLoaderFunc(func() (key *internal.CryptoKey, e error) { + _, err := suite.keyCache.GetOrLoad(KeyMeta{testKey, olderCreated}, func(_ KeyMeta) (key *internal.CryptoKey, e error) { cryptoKey, err := internal.NewCryptoKey(secretFactory, olderCreated, false, []byte("blah")) if err != nil { return nil, err } return cryptoKey, nil - })) + }) assert.NoError(suite.T(), err) - key, err := suite.keyCache.GetOrLoad(KeyMeta{testKey, suite.created}, keyLoaderFunc(func() (*internal.CryptoKey, error) { + key, err := suite.keyCache.GetOrLoad(KeyMeta{testKey, suite.created}, func(_ KeyMeta) (*internal.CryptoKey, error) { cryptoKey, err2 := internal.NewCryptoKey(secretFactory, suite.created, false, []byte("newerblah")) if err2 != nil { return nil, err2 } return cryptoKey, nil - })) + }) assert.NoError(suite.T(), err) assert.NotNil(suite.T(), key) assert.Equal(suite.T(), suite.created, key.Created()) - assert.Equal(suite.T(), suite.created, suite.keyCache.keys[cacheKey(testKey, 0)].key.Created()) - assert.Equal(suite.T(), suite.created, suite.keyCache.keys[cacheKey(testKey, suite.created)].key.Created()) - assert.Equal(suite.T(), olderCreated, suite.keyCache.keys[cacheKey(testKey, olderCreated)].key.Created()) + + latestKey, _ := suite.keyCache.getLatestKeyMeta(testKey) + assert.Equal(suite.T(), latestKey, KeyMeta{ID: testKey, Created: key.Created()}) + + assert.Equal(suite.T(), suite.created, suite.keyCache.keys.GetOrPanic(cacheKey(testKey, suite.created)).key.Created()) + assert.Equal(suite.T(), olderCreated, suite.keyCache.keys.GetOrPanic(cacheKey(testKey, olderCreated)).key.Created()) } func (suite *CacheTestSuite) TestKeyCache_GetOrLoad_WithCachedKeyReloadRequiredAndNowRevoked() { key, err := internal.NewCryptoKey(secretFactory, suite.created, false, []byte("blah")) if assert.NoError(suite.T(), err) { entry := cacheEntry{ - key: key, + key: &cachedCryptoKey{CryptoKey: key}, loadedAt: time.Now().Add(-2 * suite.policy.RevokeCheckInterval), } - suite.keyCache.keys[cacheKey(testKey, suite.created)] = entry - suite.keyCache.keys[cacheKey(testKey, 0)] = entry + suite.keyCache.keys.Set(cacheKey(testKey, suite.created), entry) + suite.keyCache.keys.Set(cacheKey(testKey, 0), entry) revokedKey, e := internal.NewCryptoKey(secretFactory, suite.created, true, []byte("blah")) if assert.NoError(suite.T(), e) { - key, err := suite.keyCache.GetOrLoad(KeyMeta{testKey, suite.created}, keyLoaderFunc(func() (*internal.CryptoKey, error) { + key, err := suite.keyCache.GetOrLoad(KeyMeta{testKey, suite.created}, func(_ KeyMeta) (*internal.CryptoKey, error) { return revokedKey, nil - })) + }) assert.NoError(suite.T(), err) assert.NotNil(suite.T(), key) assert.Equal(suite.T(), suite.created, key.Created()) assert.True(suite.T(), key.Revoked()) - assert.True(suite.T(), suite.keyCache.keys[cacheKey(testKey, 0)].key.Revoked()) + assert.True(suite.T(), suite.keyCache.keys.GetOrPanic(cacheKey(testKey, 0)).key.Revoked()) // Verify we closed the new one we loaded and kept the cached one open assert.True(suite.T(), revokedKey.IsClosed()) - assert.False(suite.T(), suite.keyCache.keys[cacheKey(testKey, suite.created)].key.IsClosed()) + assert.False(suite.T(), suite.keyCache.keys.GetOrPanic(cacheKey(testKey, suite.created)).key.IsClosed()) } } } @@ -204,44 +206,44 @@ func (suite 
*CacheTestSuite) TestKeyCache_GetOrLoad_WithCachedKeyReloadRequiredB if assert.NoError(suite.T(), err) { entry := cacheEntry{ - key: key, + key: &cachedCryptoKey{CryptoKey: key}, loadedAt: time.Unix(created, 0), } - suite.keyCache.keys[cacheKey(testKey, created)] = entry - suite.keyCache.keys[cacheKey(testKey, 0)] = entry + suite.keyCache.keys.Set(cacheKey(testKey, created), entry) + suite.keyCache.keys.Set(cacheKey(testKey, 0), entry) reloadedKey, e := internal.NewCryptoKey(secretFactory, created, false, []byte("blah")) assert.NoError(suite.T(), e) - key, err := suite.keyCache.GetOrLoad(KeyMeta{testKey, created}, keyLoaderFunc(func() (*internal.CryptoKey, error) { + key, err := suite.keyCache.GetOrLoad(KeyMeta{testKey, created}, func(_ KeyMeta) (*internal.CryptoKey, error) { return reloadedKey, nil - })) + }) assert.NoError(suite.T(), err) assert.NotNil(suite.T(), key) assert.Equal(suite.T(), created, key.Created()) - assert.Greater(suite.T(), suite.keyCache.keys[cacheKey(testKey, created)].loadedAt.Unix(), created) + assert.Greater(suite.T(), suite.keyCache.keys.GetOrPanic(cacheKey(testKey, created)).loadedAt.Unix(), created) // Verify we closed the new one we loaded and kept the cached one open assert.True(suite.T(), reloadedKey.IsClosed()) - assert.False(suite.T(), suite.keyCache.keys[cacheKey(testKey, created)].key.IsClosed()) + assert.False(suite.T(), suite.keyCache.keys.GetOrPanic(cacheKey(testKey, created)).key.IsClosed()) } } func (suite *CacheTestSuite) TestKeyCache_GetOrLoadLatest_WithCachedKeyNoReloadRequired() { - _, err := suite.keyCache.GetOrLoad(KeyMeta{testKey, suite.created}, keyLoaderFunc(func() (key *internal.CryptoKey, e error) { + _, err := suite.keyCache.GetOrLoad(KeyMeta{testKey, suite.created}, func(_ KeyMeta) (key *internal.CryptoKey, e error) { cryptoKey, err := internal.NewCryptoKey(secretFactory, suite.created, false, []byte("blah")) if err != nil { return nil, err } return cryptoKey, nil - })) + }) assert.NoError(suite.T(), err) - key, err := suite.keyCache.GetOrLoadLatest(testKey, keyLoaderFunc(func() (*internal.CryptoKey, error) { + key, err := suite.keyCache.GetOrLoadLatest(testKey, func(_ KeyMeta) (*internal.CryptoKey, error) { return nil, errors.New("should not be called") - })) + }) assert.NoError(suite.T(), err) assert.NotNil(suite.T(), key) @@ -249,110 +251,95 @@ func (suite *CacheTestSuite) TestKeyCache_GetOrLoadLatest_WithCachedKeyNoReloadR } func (suite *CacheTestSuite) TestKeyCache_GetOrLoadLatest_WithEmptyCache() { - key, err := suite.keyCache.GetOrLoadLatest(testKey, keyLoaderFunc(func() (*internal.CryptoKey, error) { - cryptoKey, err := internal.NewCryptoKey(secretFactory, suite.created, false, []byte("blah")) - if err != nil { - return nil, err - } - return cryptoKey, nil - })) + key, err := suite.keyCache.GetOrLoadLatest(testKey, func(_ KeyMeta) (*internal.CryptoKey, error) { + return internal.NewCryptoKey(secretFactory, suite.created, false, []byte("blah")) + }) assert.NoError(suite.T(), err) assert.NotNil(suite.T(), key) assert.Equal(suite.T(), suite.created, key.Created()) - assert.Equal(suite.T(), suite.created, suite.keyCache.keys[cacheKey(testKey, 0)].key.Created()) + + latestKey, _ := suite.keyCache.getLatestKeyMeta(testKey) + assert.Equal(suite.T(), latestKey, KeyMeta{ID: testKey, Created: suite.created}) } func (suite *CacheTestSuite) TestKeyCache_GetOrLoadLatest_DoesNotSetKeyOnError() { - key, err := suite.keyCache.GetOrLoadLatest(testKey, keyLoaderFunc(func() (*internal.CryptoKey, error) { + key, err := 
suite.keyCache.GetOrLoadLatest(testKey, func(_ KeyMeta) (*internal.CryptoKey, error) { return new(internal.CryptoKey), errors.New("error") - })) + }) if assert.Error(suite.T(), err) { assert.Nil(suite.T(), key) - assert.Empty(suite.T(), suite.keyCache.keys) + assert.Zero(suite.T(), suite.keyCache.keys.Len()) } } func (suite *CacheTestSuite) TestKeyCache_GetOrLoadLatest_WithCachedKeyReloadRequiredAndNowRevoked() { key, err := internal.NewCryptoKey(secretFactory, suite.created, false, []byte("blah")) - if assert.NoError(suite.T(), err) { - entry := cacheEntry{ - key: key, - loadedAt: time.Now().Add(-2 * suite.policy.RevokeCheckInterval), - } + suite.Require().NoError(err) - suite.keyCache.keys[cacheKey(testKey, suite.created)] = entry - suite.keyCache.keys[cacheKey(testKey, 0)] = entry + entry := newCacheEntry(key) + entry.loadedAt = time.Now().Add(-2 * suite.policy.RevokeCheckInterval) - revokedKey, e := internal.NewCryptoKey(secretFactory, suite.created, true, []byte("blah")) - if assert.NoError(suite.T(), e) { - key, err := suite.keyCache.GetOrLoadLatest(testKey, keyLoaderFunc(func() (*internal.CryptoKey, error) { - return revokedKey, nil - })) - - assert.NoError(suite.T(), err) - assert.NotNil(suite.T(), key) - assert.Equal(suite.T(), suite.created, key.Created()) - assert.True(suite.T(), key.Revoked()) - assert.True(suite.T(), suite.keyCache.keys[cacheKey(testKey, 0)].key.Revoked()) - // Verify we closed the new one we loaded and kept the cached one open - assert.True(suite.T(), revokedKey.IsClosed()) - assert.False(suite.T(), suite.keyCache.keys[cacheKey(testKey, suite.created)].key.IsClosed()) - } - } -} + suite.keyCache.mapLatestKeyMeta(testKey, KeyMeta{ID: testKey, Created: suite.created}) + suite.keyCache.keys.Set(cacheKey(testKey, suite.created), entry) -type mockKeyReloader struct { - mock.Mock + revokedKey, e := internal.NewCryptoKey(secretFactory, suite.created, true, []byte("blah")) + suite.Require().NoError(e) - loader keyLoaderFunc -} + first := true + calls := 0 -func (r *mockKeyReloader) Load() (*internal.CryptoKey, error) { - args := r.Called() + // Because the entry's loadedAt is older than the revoke check interval, the key should be treated as "stale" + // which should trigger the following: + // 1. A cache miss is recorded because the key is no longer "fresh" + // 2. The key is loaded via the loader function below, which returns a revoked key on the first call + // 3. The cache, having received a revoked key, increments the reloaded count + // 4. 
The key is reloaded via the loader function, which returns a new key on subsequent calls + latest, err := suite.keyCache.GetOrLoadLatest(testKey, func(_ KeyMeta) (*internal.CryptoKey, error) { + calls++ - if r.loader != nil { - return r.loader() - } + if first { + first = false + return revokedKey, nil + } - return args.Get(0).(*internal.CryptoKey), args.Error(1) -} + return internal.NewCryptoKey(secretFactory, suite.created, false, []byte("blah")) + }) -func (r *mockKeyReloader) IsInvalid(key *internal.CryptoKey) bool { - args := r.Called(key.Created()) - return args.Bool(0) + assert.NoError(suite.T(), err) + assert.NotNil(suite.T(), latest) + assert.Equal(suite.T(), suite.created, latest.Created()) + assert.Equal(suite.T(), 2, calls) + assert.False(suite.T(), latest.Revoked()) + // Verify we closed the new one we loaded and kept the cached one open + assert.True(suite.T(), revokedKey.IsClosed()) + assert.False(suite.T(), suite.keyCache.keys.GetOrPanic(cacheKey(testKey, suite.created)).key.IsClosed()) } -func (suite *CacheTestSuite) TestKeyCache_GetOrLoadLatest_KeyReloader_WithCachedKeyAndInvalidKey() { +func (suite *CacheTestSuite) TestKeyCache_GetOrLoadLatest_WithCachedKeyAndInvalidKey() { orig, err := internal.NewCryptoKey(secretFactory, suite.created, true, []byte("blah")) require.NoError(suite.T(), err) entry := cacheEntry{ - key: orig, + key: &cachedCryptoKey{CryptoKey: orig}, loadedAt: time.Now(), } - suite.keyCache.keys[cacheKey(testKey, suite.created)] = entry - suite.keyCache.keys[cacheKey(testKey, 0)] = entry + suite.keyCache.mapLatestKeyMeta(testKey, KeyMeta{ID: testKey, Created: suite.created}) + suite.keyCache.keys.Set(cacheKey(testKey, suite.created), entry) newerCreated := time.Now().Add(1 * time.Second).Unix() require.Greater(suite.T(), newerCreated, suite.created) - reloader := &mockKeyReloader{ - loader: keyLoaderFunc(func() (*internal.CryptoKey, error) { - reloadedKey, e := internal.NewCryptoKey(secretFactory, newerCreated, false, []byte("blah")) - assert.NoError(suite.T(), e) + loader := func(_ KeyMeta) (*internal.CryptoKey, error) { + reloadedKey, e := internal.NewCryptoKey(secretFactory, newerCreated, false, []byte("blah")) + assert.NoError(suite.T(), e) - return reloadedKey, e - }), + return reloadedKey, e } - reloader.On("IsInvalid", orig.Created()).Return(true) - reloader.On("Load").Return().Once() - - key, err := suite.keyCache.GetOrLoadLatest(testKey, reloader) - reloader.AssertExpectations(suite.T()) + key, err := suite.keyCache.GetOrLoadLatest(testKey, loader) assert.NoError(suite.T(), err) assert.NotNil(suite.T(), key) @@ -362,80 +349,95 @@ func (suite *CacheTestSuite) TestKeyCache_GetOrLoadLatest_KeyReloader_WithCached assert.False(suite.T(), key.Revoked()) // cached key is still revoked - cached := suite.keyCache.keys[cacheKey(testKey, suite.created)] + cached := suite.keyCache.keys.GetOrPanic(cacheKey(testKey, suite.created)) assert.True(suite.T(), cached.key.Revoked(), fmt.Sprintf("%+v - created: %d", cached.key, cached.key.Created())) } -func (suite *CacheTestSuite) TestKeyCache_GetOrLoadLatest_KeyReloader_WithCachedKeyAndValidKey() { +func (suite *CacheTestSuite) TestKeyCache_GetOrLoadLatest_WithCachedKeyAndValidKey() { key, err := internal.NewCryptoKey(secretFactory, suite.created, false, []byte("blah")) if assert.NoError(suite.T(), err) { entry := cacheEntry{ - key: key, + key: &cachedCryptoKey{CryptoKey: key}, loadedAt: time.Now(), } - suite.keyCache.keys[cacheKey(testKey, suite.created)] = entry - suite.keyCache.keys[cacheKey(testKey, 0)] = entry - 
- reloader := new(mockKeyReloader) - reloader.On("IsInvalid", key.Created()).Return(false) + suite.keyCache.keys.Set(cacheKey(testKey, suite.created), entry) - key, err := suite.keyCache.GetOrLoadLatest(testKey, reloader) + key, err := suite.keyCache.GetOrLoadLatest(testKey, func(_ KeyMeta) (*internal.CryptoKey, error) { + return key, nil + }) assert.NoError(suite.T(), err) assert.NotNil(suite.T(), key) assert.Equal(suite.T(), suite.created, key.Created()) - assert.False(suite.T(), suite.keyCache.keys[cacheKey(testKey, 0)].key.Revoked()) - - reloader.AssertNotCalled(suite.T(), "Load", mock.Anything) - reloader.AssertExpectations(suite.T()) + assert.False(suite.T(), key.Revoked()) } } func (suite *CacheTestSuite) TestKeyCache_Close() { - cache := newKeyCache(NewCryptoPolicy()) + cache := newKeyCache(CacheTypeIntermediateKeys, NewCryptoPolicy()) - key, err := cache.GetOrLoadLatest(testKey, keyLoaderFunc(func() (*internal.CryptoKey, error) { + key, err := cache.GetOrLoadLatest(testKey, func(_ KeyMeta) (*internal.CryptoKey, error) { cryptoKey, err := internal.NewCryptoKey(secretFactory, suite.created, false, []byte("blah")) if err != nil { return nil, err } + return cryptoKey, nil - })) + }) assert.NoError(suite.T(), err) + key.Close() + assert.False(suite.T(), key.IsClosed(), "key should not be closed yet, as it is still in the cache") + err = cache.Close() assert.NoError(suite.T(), err) assert.True(suite.T(), key.IsClosed()) - assert.True(suite.T(), cache.keys[cacheKey(testKey, suite.created)].key.IsClosed()) - assert.True(suite.T(), cache.keys[cacheKey(testKey, 0)].key.IsClosed()) } -func (suite *CacheTestSuite) TestKeyCache_Close_MultipleCallsNoError() { - cache := newKeyCache(NewCryptoPolicy()) +func (suite *CacheTestSuite) TestKeyCache_Close_CacheThenKey() { + cache := newKeyCache(CacheTypeIntermediateKeys, NewCryptoPolicy()) - key, err := cache.GetOrLoadLatest(testKey, keyLoaderFunc(func() (*internal.CryptoKey, error) { - cryptoKey, err := internal.NewCryptoKey(secretFactory, time.Now().Unix(), false, []byte("blah")) + key, err := cache.GetOrLoadLatest(testKey, func(_ KeyMeta) (*internal.CryptoKey, error) { + cryptoKey, err := internal.NewCryptoKey(secretFactory, suite.created, false, []byte("blah")) if err != nil { return nil, err } + return cryptoKey, nil - })) + }) assert.NoError(suite.T(), err) err = cache.Close() - assert.NoError(suite.T(), err) + assert.False(suite.T(), key.IsClosed(), "key should not be closed yet, key reference still exists") + + key.Close() assert.True(suite.T(), key.IsClosed()) +} + +func (suite *CacheTestSuite) TestKeyCache_Close_MultipleCallsNoError() { + cache := newKeyCache(CacheTypeIntermediateKeys, NewCryptoPolicy()) + + key, err := cache.GetOrLoadLatest(testKey, func(_ KeyMeta) (*internal.CryptoKey, error) { + return internal.NewCryptoKey(secretFactory, time.Now().Unix(), false, []byte("blah")) + }) + assert.NoError(suite.T(), err) + + key.Close() err = cache.Close() assert.NoError(suite.T(), err) + assert.True(suite.T(), key.IsClosed()) + + err = cache.Close() + assert.NoError(suite.T(), err) } func (suite *CacheTestSuite) TestKeyCache_String() { - cache := newKeyCache(NewCryptoPolicy()) + cache := newKeyCache(CacheTypeIntermediateKeys, NewCryptoPolicy()) defer cache.Close() assert.Contains(suite.T(), cache.String(), "keyCache(") @@ -443,13 +445,9 @@ func (suite *CacheTestSuite) TestKeyCache_String() { func (suite *CacheTestSuite) TestNeverCache_GetOrLoad() { var cache neverCache - key, err := cache.GetOrLoad(KeyMeta{testKey, created}, 
keyLoaderFunc(func() (key *internal.CryptoKey, e error) { - cryptoKey, err := internal.NewCryptoKey(secretFactory, created, false, []byte("blah")) - if err != nil { - return nil, err - } - return cryptoKey, nil - })) + key, err := cache.GetOrLoad(KeyMeta{testKey, created}, func(_ KeyMeta) (key *internal.CryptoKey, e error) { + return internal.NewCryptoKey(secretFactory, created, false, []byte("blah")) + }) if assert.NoError(suite.T(), err) { // neverCache can't close keys we create @@ -462,13 +460,9 @@ func (suite *CacheTestSuite) TestNeverCache_GetOrLoad() { func (suite *CacheTestSuite) TestNeverCache_GetOrLoadLatest() { var cache neverCache - key, err := cache.GetOrLoadLatest(testKey, keyLoaderFunc(func() (key *internal.CryptoKey, e error) { - cryptoKey, err := internal.NewCryptoKey(secretFactory, created, false, []byte("blah")) - if err != nil { - return nil, err - } - return cryptoKey, nil - })) + key, err := cache.GetOrLoadLatest(testKey, func(_ KeyMeta) (key *internal.CryptoKey, e error) { + return internal.NewCryptoKey(secretFactory, created, false, []byte("blah")) + }) if assert.NoError(suite.T(), err) { // neverCache can't close keys we create @@ -487,27 +481,32 @@ func (suite *CacheTestSuite) TestNeverCache_Close() { assert.NoError(suite.T(), err) } -func (suite *CacheTestSuite) TestSharedKeyCache_GetOrLoad() { +func (suite *CacheTestSuite) TestKeyCache_GetOrLoad_Concurrent_100() { if testing.Short() { suite.T().Skip("too slow for testing.Short") } var ( - cache = newKeyCache(NewCryptoPolicy()) + cache = newKeyCache(CacheTypeIntermediateKeys, NewCryptoPolicy()) i = 0 wg sync.WaitGroup counter int32 ) - loadFunc := keyLoaderFunc(func() (*internal.CryptoKey, error) { - <-time.After(time.Nanosecond * time.Duration(rand.Intn(30))) + loadFunc := func(_ KeyMeta) (*internal.CryptoKey, error) { + <-time.After(time.Millisecond * time.Duration(rand.Intn(30))) atomic.AddInt32(&counter, 1) return new(internal.CryptoKey), nil - }) + } meta := KeyMeta{ID: "testing", Created: time.Now().Unix()} + _, err := cache.GetOrLoad(meta, loadFunc) + if err != nil { + suite.T().Error(err) + } + for ; i < 100; i++ { wg.Add(1) @@ -517,24 +516,22 @@ func (suite *CacheTestSuite) TestSharedKeyCache_GetOrLoad() { key, err := cache.GetOrLoad(meta, loadFunc) if key == nil { suite.T().Error("key == nil") - suite.T().Fail() } if err != nil { suite.T().Error(err) - suite.T().Fail() } }() } wg.Wait() - // This seems to be causing intermittent issues with go2xunit parsing - //d := time.Since(startTime) - // - //fmt.Printf("Finished %d loops in: %s (%f/s)", i, d, float64(i)/d.Seconds()) - assert.Equal(suite.T(), int32(1), counter) + assert.Equal(suite.T(), 1, cache.keys.Len()) + + // metrics := cache.GetMetrics() + // assert.Equal(suite.T(), int64(1), metrics.MissCount) + // assert.Equal(suite.T(), int64(100), metrics.HitCount) } func TestCacheTestSuite(t *testing.T) { diff --git a/go/appencryption/parameterized_test.go b/go/appencryption/parameterized_test.go index 8b419bef3..3cef965d0 100644 --- a/go/appencryption/parameterized_test.go +++ b/go/appencryption/parameterized_test.go @@ -187,25 +187,25 @@ func createRevokedKey(src *internal.CryptoKey, factory securememory.SecretFactor } func createSession(crypto AEAD, metastore Metastore, kms KeyManagementService, factory securememory.SecretFactory, - policy *CryptoPolicy, partition partition, ikCache cache, skCache cache) *Session { + policy *CryptoPolicy, partition partition, ikCache keyCacher, skCache keyCacher) *Session { return &Session{ encryption: 
&envelopeEncryption{ - partition: partition, - Metastore: metastore, - KMS: kms, - Policy: policy, - Crypto: crypto, - SecretFactory: factory, - systemKeys: skCache, - intermediateKeys: ikCache, + partition: partition, + Metastore: metastore, + KMS: kms, + Policy: policy, + Crypto: crypto, + SecretFactory: factory, + skCache: skCache, + ikCache: ikCache, }} } func createCache(partition partition, cacheIK, cacheSK string, intermediateKey, systemKey *internal.CryptoKey, - policy *CryptoPolicy) (cache, cache) { - var ikCache, skCache cache - skCache = newKeyCache(policy) - ikCache = newKeyCache(policy) + policy *CryptoPolicy) (keyCacher, keyCacher) { + var ikCache, skCache keyCacher + skCache = newKeyCache(CacheTypeSystemKeys, policy) + ikCache = newKeyCache(CacheTypeIntermediateKeys, policy) sk := systemKey ik := intermediateKey @@ -222,9 +222,9 @@ func createCache(partition partition, cacheIK, cacheSK string, intermediateKey, } // Preload the cache with the system keys - _, _ = skCache.GetOrLoad(*meta, keyLoaderFunc(func() (*internal.CryptoKey, error) { + _, _ = skCache.GetOrLoad(*meta, func(_ KeyMeta) (*internal.CryptoKey, error) { return sk, nil - })) + }) } if cacheIK != EMPTY { @@ -239,9 +239,9 @@ func createCache(partition partition, cacheIK, cacheSK string, intermediateKey, } // Preload the cache with the intermediate keys - _, _ = ikCache.GetOrLoad(*meta, keyLoaderFunc(func() (*internal.CryptoKey, error) { + _, _ = ikCache.GetOrLoad(*meta, func(_ KeyMeta) (*internal.CryptoKey, error) { return ik, nil - })) + }) } return ikCache, skCache diff --git a/go/appencryption/pkg/cache/cache.go b/go/appencryption/pkg/cache/cache.go new file mode 100644 index 000000000..0de3f158e --- /dev/null +++ b/go/appencryption/pkg/cache/cache.go @@ -0,0 +1,483 @@ +// Package cache provides a cache implementation with support for multiple +// eviction policies. +// +// Currently supported eviction policies: +// - LRU (least recently used) +// - LFU (least frequently used) +// - SLRU (segmented least recently used) +// - TinyLFU (tiny least frequently used) +// +// The cache is safe for concurrent access. +package cache + +import ( + "container/list" + "fmt" + "sync" + "time" + + "github.com/godaddy/asherah/go/appencryption/pkg/log" +) + +// Interface is intended to be a generic interface for cache implementations. +type Interface[K comparable, V any] interface { + Get(key K) (V, bool) + GetOrPanic(key K) V + Set(key K, value V) + Delete(key K) bool + Len() int + Capacity() int + Close() error +} + +// CachePolicy is an enum for the different eviction policies. +type CachePolicy string + +const ( + // LRU is the least recently used cache policy. + LRU CachePolicy = "lru" + // LFU is the least frequently used cache policy. + LFU CachePolicy = "lfu" + // SLRU is the segmented least recently used cache policy. + SLRU CachePolicy = "slru" + // TinyLFU is the tiny least frequently used cache policy. + TinyLFU CachePolicy = "tinylfu" +) + +// String returns the string representation of the eviction policy. +func (e CachePolicy) String() string { + return string(e) +} + +// EvictFunc is called when an item is evicted from the cache. The key and +// value of the evicted item are passed to the function. +type EvictFunc[K comparable, V any] func(key K, value V) + +// NopEvict is a no-op EvictFunc. +func NopEvict[K comparable, V any](K, V) {} + +// event is the cache event (evictItem or closeCache). +type event int + +const ( + // evictItem is sent on the events channel when an item is evicted from the cache. 
+ evictItem event = iota + // closeCache is sent on the events channel when the cache is closed. + closeCache +) + +type cacheItem[K comparable, V any] struct { + key K + value V + + parent *list.Element // Pointer to the frequencyParent + + expiration time.Time // Expiration time +} + +// cacheEvent is the event sent on the events channel. +type cacheEvent[K comparable, V any] struct { + event event + item *cacheItem[K, V] +} + +// policy is the generic interface for eviction policies. +type policy[K comparable, V any] interface { + // init initializes the policy with the given capacity. + init(int) + // capacity returns the capacity of the policy. + capacity() int + // close removes all items from the cache, sends a close event to the event + // processing goroutine, and waits for it to exit. + close() + // admit is called when an item is admitted to the cache. + admit(item *cacheItem[K, V]) + // access is called when an item is accessed. + access(item *cacheItem[K, V]) + // victim returns the victim item to be evicted. + victim() *cacheItem[K, V] + // remove is called when an item is removed from the cache. + remove(item *cacheItem[K, V]) +} + +// Clock is an interface for getting the current time. +type Clock interface { + Now() time.Time +} + +// realClock is the default Clock implementation. +type realClock struct{} + +// Now returns the current time. +func (c *realClock) Now() time.Time { + return time.Now() +} + +type builder[K comparable, V any] struct { + capacity int + policy policy[K, V] + evictFunc EvictFunc[K, V] + clock Clock + expiry time.Duration + isSync bool +} + +// New returns a new cache builder with the given capacity. Use the builder to +// set the eviction policy, eviction callback, and other options. Call Build() +// to create the cache. +func New[K comparable, V any](capacity int) *builder[K, V] { + return &builder[K, V]{ + capacity: capacity, + policy: new(lru[K, V]), + evictFunc: NopEvict[K, V], + clock: new(realClock), + } +} + +// WithEvictFunc sets the EvictFunc for the cache. +func (b *builder[K, V]) WithEvictFunc(fn EvictFunc[K, V]) *builder[K, V] { + b.evictFunc = fn + + return b +} + +// WithPolicy sets the eviction policy for the cache. The default policy is LRU. +func (b *builder[K, V]) WithPolicy(policy CachePolicy) *builder[K, V] { + switch policy { + case LRU: + b.policy = new(lru[K, V]) + case LFU: + b.policy = new(lfu[K, V]) + case SLRU: + b.policy = new(slru[K, V]) + case TinyLFU: + b.policy = new(tinyLFU[K, V]) + default: + panic(fmt.Sprintf("cache: unsupported policy \"%s\"", policy.String())) + } + + return b +} + +// LRU sets the cache eviction policy to LRU (least recently used). +func (b *builder[K, V]) LRU() *builder[K, V] { + return b.WithPolicy(LRU) +} + +// LFU sets the cache eviction policy to LFU (least frequently used). +func (b *builder[K, V]) LFU() *builder[K, V] { + return b.WithPolicy(LFU) +} + +// SLRU sets the cache eviction policy to SLRU (segmented least recently used). +func (b *builder[K, V]) SLRU() *builder[K, V] { + return b.WithPolicy(SLRU) +} + +// TinyLFU sets the cache eviction policy to TinyLFU (tiny least frequently used). +func (b *builder[K, V]) TinyLFU() *builder[K, V] { + return b.WithPolicy(TinyLFU) +} + +// WithClock sets the Clock for the cache. +func (b *builder[K, V]) WithClock(clock Clock) *builder[K, V] { + b.clock = clock + + return b +} + +// WithExpiry sets the expiry for the cache. 
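+// The zero value (the default) disables expiration. When an expiry is set, an
+// item's deadline is refreshed each time it is Set, and an expired item is
+// treated as missing on Get and evicted lazily at that point. For example,
+// New[string, int](10).WithExpiry(time.Minute).Build() produces a cache whose
+// entries are considered expired one minute after they were last written.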
+func (b *builder[K, V]) WithExpiry(expiry time.Duration) *builder[K, V] { + b.expiry = expiry + + return b +} + +// Synchronous sets the cache to use a synchronous eviction process. By +// default, the cache uses a concurrent eviction process which executes the +// eviction callback in a separate goroutine. +// Use this option to ensure eviction is processed inline, prior to adding +// a new item to the cache. +func (b *builder[K, V]) Synchronous() *builder[K, V] { + b.isSync = true + + return b +} + +// Build creates the cache. +func (b *builder[K, V]) Build() Interface[K, V] { + c := &cache[K, V]{ + byKey: make(map[K]*cacheItem[K, V]), + + policy: b.policy, + clock: b.clock, + expiry: b.expiry, + onEvictCallback: b.evictFunc, + isSync: b.isSync, + } + + c.policy.init(b.capacity) + + c.startup() + + return c +} + +// cache is the generic cache type. +type cache[K comparable, V any] struct { + byKey map[K]*cacheItem[K, V] // Hashmap containing *CacheItems for O(1) access + size int // Current number of items in the cache + events chan cacheEvent[K, V] // Channel to events when an item is evicted + policy policy[K, V] // Eviction policy + + mux sync.RWMutex // synchronize access to the cache + + closing bool + closeWG sync.WaitGroup + + // onEvictCallback is called when an item is evicted from the cache. The key, value, + // and frequency of the evicted item are passed to the function. Set to + // a custom function to handle evicted items. The default is a no-op. + onEvictCallback EvictFunc[K, V] + + // clock is used to get the current time. Set to a custom Clock to use a + // custom clock. The default is the real time clock. + clock Clock + + // expiry is the duration after which an item is considered expired. Set to + // a custom duration to use a custom expiry. The default is no expiry. + expiry time.Duration + + // isSync is true if the cache uses a synchronized eviction process. The default + // is false, which uses a concurrent eviction process. + isSync bool +} + +// processEvents processes events in a separate goroutine. +func (c *cache[K, V]) processEvents() { + defer c.closeWG.Done() + + for event := range c.events { + switch event.event { + case evictItem: + log.Debugf("%s executing evict callback for item: %v", c, event.item.key) + c.onEvictCallback(event.item.key, event.item.value) + case closeCache: + log.Debugf("%s closed, exiting event loop", c) + + return + } + } +} + +// Close the cache and remove all items. The cache cannot be used after it is +// closed. +func (c *cache[K, V]) Close() error { + c.mux.Lock() + defer c.mux.Unlock() + + // if the cache is already closed, do nothing + if c.closing { + return nil + } + + c.closing = true + + for c.size > 0 { + c.evict() + } + + c.shutdown() + + c.byKey = nil + + c.policy.close() + + return nil +} + +// startup starts the cache event processing goroutine. +func (c *cache[K, V]) startup() { + if c.isSync { + // no need to start the event processing goroutine + return + } + + c.events = make(chan cacheEvent[K, V]) + + c.closeWG.Add(1) + + go c.processEvents() +} + +// shutdown closes the events channel and waits for the event processing +// goroutine to exit. +func (c *cache[K, V]) shutdown() { + if c.isSync { + return + } + + c.events <- cacheEvent[K, V]{event: closeCache} + + c.closeWG.Wait() + + close(c.events) + + c.events = nil +} + +// Len returns the number of items in the cache. 
+func (c *cache[K, V]) Len() int { + c.mux.RLock() + defer c.mux.RUnlock() + + return c.size +} + +// Capacity returns the maximum number of items in the cache. +func (c *cache[K, V]) Capacity() int { + c.mux.RLock() + defer c.mux.RUnlock() + + return c.policy.capacity() +} + +// Set adds a value to the cache. If an item with the given key already exists, +// its value is updated. +func (c *cache[K, V]) Set(key K, value V) { + c.mux.Lock() + defer c.mux.Unlock() + + if c.closing { + return + } + + if item, ok := c.byKey[key]; ok { + item.value = value + + if c.expiry > 0 { + item.expiration = c.clock.Now().Add(c.expiry) + } + + c.policy.access(item) + + return + } + + // if the cache is full, evict an item + if c.size == c.policy.capacity() { + c.evict() + } + + item := &cacheItem[K, V]{ + key: key, + value: value, + } + + if c.expiry > 0 { + item.expiration = c.clock.Now().Add(c.expiry) + } + + c.byKey[key] = item + + c.size++ + + c.policy.admit(item) +} + +// Get returns a value from the cache. If an item with the given key does not +// exist, the second return value will be false. +func (c *cache[K, V]) Get(key K) (V, bool) { + c.mux.Lock() + defer c.mux.Unlock() + + if c.closing { + return c.zeroValue(), false + } + + item, ok := c.byKey[key] + if !ok { + return c.zeroValue(), false + } + + if c.expiry > 0 && item.expiration.Before(c.clock.Now()) { + c.evictItem(item) + return c.zeroValue(), false + } + + c.policy.access(item) + + return item.value, true +} + +// GetOrPanic returns the value for the given key. If the key does not exist, a +// panic is raised. +func (c *cache[K, V]) GetOrPanic(key K) V { + if item, ok := c.Get(key); ok { + return item + } + + panic(fmt.Sprintf("key does not exist: %v", key)) +} + +// Delete removes the given key from the cache. If the key does not exist, the +// return value is false. +func (c *cache[K, V]) Delete(key K) bool { + c.mux.Lock() + defer c.mux.Unlock() + + if c.closing { + return false + } + + item, ok := c.byKey[key] + if !ok { + return false + } + + delete(c.byKey, key) + + c.size-- + + c.policy.remove(item) + + return true +} + +// zeroValue returns the zero value for type V. +func (c *cache[K, V]) zeroValue() V { + var v V + return v +} + +// evict removes an item from the cache and sends an evict event or, if the +// cache uses a synchronized eviction process, calls the evict callback. +func (c *cache[K, V]) evict() { + item := c.policy.victim() + c.evictItem(item) +} + +// evictItem removes the given item from the cache and sends an evict event. +func (c *cache[K, V]) evictItem(item *cacheItem[K, V]) { + delete(c.byKey, item.key) + + c.size-- + + c.policy.remove(item) + + if c.isSync { + log.Debugf("%s executing evict callback for item (synchronous): %v", c, item.key) + + c.onEvictCallback(item.key, item.value) + + return + } + + log.Debugf("%s sending evict event for item: %v", c, item.key) + c.events <- cacheEvent[K, V]{event: evictItem, item: item} +} + +// String returns a string representation of this cache. 
+func (c *cache[K, V]) String() string { + return fmt.Sprintf("cache[%T, %T](%p)", *new(K), *new(V), c) +} diff --git a/go/appencryption/pkg/cache/cache_test.go b/go/appencryption/pkg/cache/cache_test.go new file mode 100644 index 000000000..bc7702600 --- /dev/null +++ b/go/appencryption/pkg/cache/cache_test.go @@ -0,0 +1,153 @@ +package cache_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/suite" + + "github.com/godaddy/asherah/go/appencryption/pkg/cache" +) + +type CacheSuite struct { + suite.Suite + clock *fakeClock + expiry time.Duration +} + +func TestCacheSuite(t *testing.T) { + suite.Run(t, new(CacheSuite)) +} + +// fakeClock is a fake clock that returns a static time. +type fakeClock struct { + now time.Time +} + +// Now returns the current time. +func (c *fakeClock) Now() time.Time { + return c.now +} + +// SetNow sets the current time. +func (c *fakeClock) SetNow(now time.Time) { + c.now = now +} + +func (suite *CacheSuite) SetupTest() { + suite.clock = &fakeClock{ + now: time.Now(), + } + + suite.expiry = time.Hour +} + +func (suite *CacheSuite) newCache() cache.Interface[int, string] { + cb := cache.New[int, string](2).WithClock(suite.clock).WithExpiry(suite.expiry) + + return cb.Build() +} + +func (suite *CacheSuite) TestBuild() { + c := suite.newCache() + + suite.Assert().Equal(0, c.Len()) + suite.Assert().Equal(2, c.Capacity()) +} + +func (suite *CacheSuite) TestClosing() { + c := suite.newCache() + + suite.Assert().NoError(c.Close()) + + // set/get do nothing after closing + c.Set(1, "one") + suite.Assert().Equal(0, c.Len()) + + // getting a value does nothing, returns false + _, ok := c.Get(1) + suite.Assert().False(ok) + + // delete does nothing + suite.Assert().False(c.Delete(1)) + + // closing again does nothing + suite.Assert().NoError(c.Close()) +} + +func (suite *CacheSuite) TestExpiry() { + c := suite.newCache() + + c.Set(1, "one") + c.Set(2, "two") + + one, ok := c.Get(1) + suite.Assert().Equal("one", one) + suite.Assert().True(ok) + + two, ok := c.Get(2) + suite.Assert().Equal("two", two) + suite.Assert().True(ok) + + // advance clock + suite.clock.SetNow(suite.clock.Now().Add(suite.expiry + time.Second)) + + // get should return false + _, ok = c.Get(1) + suite.Assert().False(ok) + + _, ok = c.Get(2) + suite.Assert().False(ok) +} + +func (suite *CacheSuite) TestSynchronousEviction() { + evicted := false + + cb := cache.New[int, string](2).Synchronous() + cb.WithEvictFunc(func(key int, value string) { + suite.Assert().Equal(1, key) + suite.Assert().Equal("one", value) + + evicted = true + }) + + c := cb.Build() + + c.Set(1, "one") + c.Set(2, "two") + c.Set(3, "three") + + suite.Assert().True(evicted) + + // 1 should be evicted + _, ok := c.Get(1) + suite.Assert().False(ok) + + _, ok = c.Get(2) + suite.Assert().True(ok) + + // 3 should still be there + three, ok := c.Get(3) + suite.Assert().Equal("three", three) + suite.Assert().True(ok) +} + +func (suite *CacheSuite) TestSynchronousClosing() { + c := cache.New[int, string](2).Synchronous().Build() + + suite.Assert().NoError(c.Close()) + + // set/get do nothing after closing + c.Set(1, "one") + suite.Assert().Equal(0, c.Len()) + + // getting a value does nothing, returns false + _, ok := c.Get(1) + suite.Assert().False(ok) + + // delete does nothing + suite.Assert().False(c.Delete(1)) + + // closing again does nothing + suite.Assert().NoError(c.Close()) +} diff --git a/go/appencryption/pkg/cache/internal/NOTICE b/go/appencryption/pkg/cache/internal/NOTICE new file mode 100644 index 
000000000..1ab653f38 --- /dev/null +++ b/go/appencryption/pkg/cache/internal/NOTICE @@ -0,0 +1,26 @@ +Copyright (c) 2016, Quoc-Viet Nguyen. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +3. Neither the names of the copyright holders nor the names of its contributors +may be used to endorse or promote products derived from this software without +specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/go/appencryption/pkg/cache/internal/doc.go b/go/appencryption/pkg/cache/internal/doc.go new file mode 100644 index 000000000..00bd2f3a7 --- /dev/null +++ b/go/appencryption/pkg/cache/internal/doc.go @@ -0,0 +1,7 @@ +// Package internal contains data structures used by cache implementations. +// +// These data structures are derived from the [Mango Cache] source code. +// See NOTICE for important copyright and licensing information. +// +// [Mango Cache]: https://github.com/goburrow/cache +package internal diff --git a/go/appencryption/pkg/cache/internal/filter.go b/go/appencryption/pkg/cache/internal/filter.go new file mode 100644 index 000000000..4eefbb671 --- /dev/null +++ b/go/appencryption/pkg/cache/internal/filter.go @@ -0,0 +1,103 @@ +package internal + +import ( + "math" +) + +// BloomFilter is Bloom Filter implementation used as a cache admission policy. +// See http://billmill.org/bloomfilter-tutorial/ +type BloomFilter struct { + numHashes uint32 // number of hashes per element + bitsMask uint32 // size of bit vector + bits []uint64 // filter bit vector +} + +// Init initializes bloomFilter with the given expected insertions ins and +// false positive probability fpp. +func (f *BloomFilter) Init(ins int, fpp float64) { + ln2 := math.Log(2.0) + factor := -math.Log(fpp) / (ln2 * ln2) + + numBits := nextPowerOfTwo(uint32(float64(ins) * factor)) + if numBits == 0 { + numBits = 1 + } + + f.bitsMask = numBits - 1 + + if ins == 0 { + f.numHashes = 1 + } else { + f.numHashes = uint32(ln2 * float64(numBits) / float64(ins)) + } + + if size := int(numBits+63) / 64; len(f.bits) != size { + f.bits = make([]uint64, size) + } else { + f.Reset() + } +} + +// nextPowerOfTwo returns the smallest power of two which is greater than or equal to i. 
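+// For example, nextPowerOfTwo(3) returns 4 and nextPowerOfTwo(8) returns 8.
+// Note that nextPowerOfTwo(0) wraps around to 0, which is why Init substitutes
+// 1 when the computed bit count is zero.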
+func nextPowerOfTwo(i uint32) uint32 { + n := i - 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n++ + + return n +} + +// Put inserts a hash value into the bloom filter. +// It returns true if the value may already in the filter. +func (f *BloomFilter) Put(h uint64) bool { + h1, h2 := uint32(h), uint32(h>>32) + + var o uint = 1 + for i := uint32(0); i < f.numHashes; i++ { + o &= f.set((h1 + (i * h2)) & f.bitsMask) + } + + return o == 1 +} + +// contains returns true if the given hash is may be in the filter. +func (f *BloomFilter) Contains(h uint64) bool { + h1, h2 := uint32(h), uint32(h>>32) + + var o uint = 1 + for i := uint32(0); i < f.numHashes; i++ { + o &= f.get((h1 + (i * h2)) & f.bitsMask) + } + + return o == 1 +} + +// set sets bit at index i and returns previous value. +func (f *BloomFilter) set(i uint32) uint { + idx, shift := i/64, i%64 + val := f.bits[idx] + mask := uint64(1) << shift + f.bits[idx] |= mask + + return uint((val & mask) >> shift) +} + +// get returns bit set at index i. +func (f *BloomFilter) get(i uint32) uint { + idx, shift := i/64, i%64 + val := f.bits[idx] + mask := uint64(1) << shift + + return uint((val & mask) >> shift) +} + +// Reset clears the bloom filter. +func (f *BloomFilter) Reset() { + for i := range f.bits { + f.bits[i] = 0 + } +} diff --git a/go/appencryption/pkg/cache/internal/filter_test.go b/go/appencryption/pkg/cache/internal/filter_test.go new file mode 100644 index 000000000..8d1e94d3c --- /dev/null +++ b/go/appencryption/pkg/cache/internal/filter_test.go @@ -0,0 +1,43 @@ +package internal_test + +import ( + "testing" + + "github.com/godaddy/asherah/go/appencryption/pkg/cache/internal" +) + +func TestBloomFilter(t *testing.T) { + const numIns = 100000 + + f := internal.BloomFilter{} + f.Init(numIns, 0.01) + + var i uint64 + for i = 0; i < numIns; i += 2 { + existed := f.Put(i) + if existed { + t.Fatalf("unexpected put(%d): %v, want: false", i, existed) + } + } + + for i = 0; i < numIns; i += 2 { + existed := f.Contains(i) + if !existed { + t.Fatalf("unexpected contains(%d): %v, want: true", i, existed) + } + } + + for i = 1; i < numIns; i += 2 { + existed := f.Contains(i) + if existed { + t.Fatalf("unexpected contains(%d): %v, want: false", i, existed) + } + } + + for i = 0; i < numIns; i += 2 { + existed := f.Put(i) + if !existed { + t.Fatalf("unexpected put(%d): %v, want: true", i, existed) + } + } +} diff --git a/go/appencryption/pkg/cache/internal/hash.go b/go/appencryption/pkg/cache/internal/hash.go new file mode 100644 index 000000000..fd0ce5e8b --- /dev/null +++ b/go/appencryption/pkg/cache/internal/hash.go @@ -0,0 +1,130 @@ +package internal + +import ( + "math" + "reflect" +) + +// Hash is an interface implemented by cache keys to +// override default hash function. +type Hash interface { + Sum64() uint64 +} + +// ComputeHash calculates hash value of the given key. 
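+// Keys implementing the Hash interface are hashed via their own Sum64 method;
+// integer, float, bool, and string keys use the inlined FNV-1a helpers below;
+// pointer-like kinds (pointers, funcs, slices, maps, channels) fall back to
+// hashing the pointer address; any other type currently hashes to 0.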
+// +//nolint:gocyclo +func ComputeHash(k interface{}) uint64 { + switch h := k.(type) { + case Hash: + return h.Sum64() + case int: + return hashU64(uint64(h)) + case int8: + return hashU32(uint32(h)) + case int16: + return hashU32(uint32(h)) + case int32: + return hashU32(uint32(h)) + case int64: + return hashU64(uint64(h)) + case uint: + return hashU64(uint64(h)) + case uint8: + return hashU32(uint32(h)) + case uint16: + return hashU32(uint32(h)) + case uint32: + return hashU32(h) + case uint64: + return hashU64(h) + case uintptr: + return hashU64(uint64(h)) + case float32: + return hashU32(math.Float32bits(h)) + case float64: + return hashU64(math.Float64bits(h)) + case bool: + if h { + return 1 + } + + return 0 + case string: + return hashString(h) + } + // TODO: complex64 and complex128 + if h, ok := hashPointer(k); ok { + return h + } + // TODO: use gob to encode k to bytes then hash. + return 0 +} + +const ( + fnvOffset uint64 = 14695981039346656037 + fnvPrime uint64 = 1099511628211 +) + +func hashU64(v uint64) uint64 { + // Inline code from hash/fnv to reduce memory allocations + h := fnvOffset + // for i := uint(0); i < 64; i += 8 { + // h ^= (v >> i) & 0xFF + // h *= fnvPrime + // } + h ^= (v >> 0) & 0xFF + h *= fnvPrime + h ^= (v >> 8) & 0xFF + h *= fnvPrime + h ^= (v >> 16) & 0xFF + h *= fnvPrime + h ^= (v >> 24) & 0xFF + h *= fnvPrime + h ^= (v >> 32) & 0xFF + h *= fnvPrime + h ^= (v >> 40) & 0xFF + h *= fnvPrime + h ^= (v >> 48) & 0xFF + h *= fnvPrime + h ^= (v >> 56) & 0xFF + h *= fnvPrime + + return h +} + +func hashU32(v uint32) uint64 { + h := fnvOffset + h ^= uint64(v>>0) & 0xFF + h *= fnvPrime + h ^= uint64(v>>8) & 0xFF + h *= fnvPrime + h ^= uint64(v>>16) & 0xFF + h *= fnvPrime + h ^= uint64(v>>24) & 0xFF + h *= fnvPrime + + return h +} + +// hashString calculates hash value using FNV-1a algorithm. 
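+// Note that ranging over the string yields runes, so multi-byte characters are
+// folded into the hash as code points rather than as individual bytes.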
+func hashString(data string) uint64 { + // Inline code from hash/fnv to reduce memory allocations + h := fnvOffset + for _, b := range data { + h ^= uint64(b) + h *= fnvPrime + } + + return h +} + +func hashPointer(k interface{}) (uint64, bool) { + v := reflect.ValueOf(k) + switch v.Kind() { + case reflect.Ptr, reflect.UnsafePointer, reflect.Func, reflect.Slice, reflect.Map, reflect.Chan: + return hashU64(uint64(v.Pointer())), true + default: + return 0, false + } +} diff --git a/go/appencryption/pkg/cache/internal/hash_test.go b/go/appencryption/pkg/cache/internal/hash_test.go new file mode 100644 index 000000000..89ae484e8 --- /dev/null +++ b/go/appencryption/pkg/cache/internal/hash_test.go @@ -0,0 +1,59 @@ +package internal_test + +import ( + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/godaddy/asherah/go/appencryption/pkg/cache/internal" +) + +type HashSuite struct { + suite.Suite +} + +func TestHashSuite(t *testing.T) { + suite.Run(t, new(HashSuite)) +} + +type hashable struct{} + +func (h hashable) Sum64() uint64 { + return 42 +} + +func (suite *HashSuite) TestComputeHash() { + tests := []struct { + input interface{} + expected uint64 + }{ + {input: -1, expected: 0x8cf51a8bfca3883d}, + {input: int8(-8), expected: 0xc49d767d487ba59e}, + {input: int16(-16), expected: 0xbff576369e732626}, + {input: int32(-32), expected: 0xfc0775b30ed9a536}, + {input: int64(-64), expected: 0xd1bdb52ab00c8d2}, + {input: uint(1), expected: 0x89cd31291d2aefa4}, + {input: uint8(8), expected: 0x4cfad6c24f7bf87d}, + {input: uint16(16), expected: 0x4cd037050129dd05}, + {input: uint32(32), expected: 0x4dcff574d71681d5}, + {input: uint64(64), expected: 0x6779ba74e3ecc205}, + {input: uintptr(uint64(64)), expected: 0x6779ba74e3ecc205}, + {input: float32(2.5), expected: 0x4cb8767f9d714215}, + {input: float64(2.5), expected: 0xa8ba2032280e4061}, + {input: true, expected: 1}, + {input: "1", expected: 0xaf63ac4c86019afc}, + {input: hashable{}, expected: 42}, + } + + for i, test := range tests { + i := i + suite.Assert().Equal(test.expected, internal.ComputeHash(test.input), "test %d", i) + } +} + +func (suite *HashSuite) TestComputeHashForPointer() { + input := make([]byte, 0) + + h := internal.ComputeHash(input) + suite.Assert().NotEqual(uint64(0), h) +} diff --git a/go/appencryption/pkg/cache/internal/sketch.go b/go/appencryption/pkg/cache/internal/sketch.go new file mode 100644 index 000000000..72ee2f859 --- /dev/null +++ b/go/appencryption/pkg/cache/internal/sketch.go @@ -0,0 +1,91 @@ +package internal + +const sketchDepth = 4 + +// CountMinSketch is an implementation of count-min sketch with 4-bit counters. +// See http://dimacs.rutgers.edu/~graham/pubs/papers/cmsoft.pdf +type CountMinSketch struct { + counters []uint64 + mask uint32 +} + +// init initialize count-min sketch with the given width. +func (c *CountMinSketch) Init(width int) { + // Need (width x 4 x 4) bits = width/4 x uint64 + size := nextPowerOfTwo(uint32(width)) >> 2 + if size < 1 { + size = 1 + } + + c.mask = size - 1 + if len(c.counters) == int(size) { + c.clear() + } else { + c.counters = make([]uint64, size) + } +} + +// Add increases counters associated with the given hash. +func (c *CountMinSketch) Add(h uint64) { + h1, h2 := uint32(h), uint32(h>>32) + + for i := uint32(0); i < sketchDepth; i++ { + idx, off := c.position(h1 + i*h2) + c.inc(idx, (16*i)+off) + } +} + +// Estimate returns minimum value of counters associated with the given hash. 
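+// Because each counter is only 4 bits wide, the returned estimate saturates at 15.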
+func (c *CountMinSketch) Estimate(h uint64) uint8 { + h1, h2 := uint32(h), uint32(h>>32) + + var min uint8 = 0xFF + + for i := uint32(0); i < sketchDepth; i++ { + idx, off := c.position(h1 + i*h2) + + count := c.val(idx, (16*i)+off) + if count < min { + min = count + } + } + + return min +} + +// Reset divides all counters by two. +func (c *CountMinSketch) Reset() { + for i, v := range c.counters { + if v != 0 { + c.counters[i] = (v >> 1) & 0x7777777777777777 + } + } +} + +func (c *CountMinSketch) position(h uint32) (idx uint32, off uint32) { + idx = (h >> 2) & c.mask + off = (h & 3) << 2 + + return +} + +// inc increases value at index idx. +func (c *CountMinSketch) inc(idx, off uint32) { + v := c.counters[idx] + + if count := uint8(v>>off) & 0x0F; count < 15 { + c.counters[idx] = v + (1 << off) + } +} + +// val returns value at index idx. +func (c *CountMinSketch) val(idx, off uint32) uint8 { + v := c.counters[idx] + return uint8(v>>off) & 0x0F +} + +func (c *CountMinSketch) clear() { + for i := range c.counters { + c.counters[i] = 0 + } +} diff --git a/go/appencryption/pkg/cache/internal/sketch_test.go b/go/appencryption/pkg/cache/internal/sketch_test.go new file mode 100644 index 000000000..ecae31db7 --- /dev/null +++ b/go/appencryption/pkg/cache/internal/sketch_test.go @@ -0,0 +1,67 @@ +package internal_test + +import ( + "testing" + + "github.com/godaddy/asherah/go/appencryption/pkg/cache/internal" +) + +func TestCountMinSketch(t *testing.T) { + const max = 15 + + cm := &internal.CountMinSketch{} + cm.Init(max) + + for i := 0; i < max; i++ { + // Increase value at i j times + for j := i; j > 0; j-- { + cm.Add(uint64(i)) + } + } + + for i := 0; i < max; i++ { + n := cm.Estimate(uint64(i)) + if int(n) != i { + t.Fatalf("unexpected estimate(%d): %d, want: %d", i, n, i) + } + } + + cm.Reset() + + for i := 0; i < max; i++ { + n := cm.Estimate(uint64(i)) + if int(n) != i/2 { + t.Fatalf("unexpected estimate(%d): %d, want: %d", i, n, i/2) + } + } + + cm.Reset() + + for i := 0; i < max; i++ { + n := cm.Estimate(uint64(i)) + if int(n) != i/4 { + t.Fatalf("unexpected estimate(%d): %d, want: %d", i, n, i/4) + } + } + + for i := 0; i < 100; i++ { + cm.Add(1) + } + + if n := cm.Estimate(1); n != 15 { + t.Fatalf("unexpected estimate(%d): %d, want: %d", 1, n, 15) + } +} + +func BenchmarkCountMinSketchReset(b *testing.B) { + cm := &internal.CountMinSketch{} + cm.Init(1<<15 - 1) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + cm.Add(0xCAFECAFECAFECAFE) + cm.Reset() + } +} diff --git a/go/appencryption/pkg/cache/lfu.go b/go/appencryption/pkg/cache/lfu.go new file mode 100644 index 000000000..d809cdca1 --- /dev/null +++ b/go/appencryption/pkg/cache/lfu.go @@ -0,0 +1,152 @@ +//nolint:forcetypeassert // we know the type of the value +package cache + +import ( + "container/list" +) + +type frequencyParent[K comparable, V any] struct { + entries map[*cacheItem[K, V]]*list.Element // entries in this frequency to pointer to access list + frequency int + byAccess *list.List // linked list of all entries in access order +} + +// lfu implements a cache policy as described in +// ["An O(1) algorithm for implementing the lfu cache eviction scheme"]. +// +// A cache utilizing this policy is safe for concurrent use and has a +// runtime complexity of O(1) for all operations. 
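+//
+// Internally, items are grouped into frequency buckets held in a linked list
+// ordered from least to most frequently used; each bucket keeps its own
+// access-ordered list of entries, and the victim is the entry that has been in
+// the lowest-frequency bucket the longest.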
+// +// ["An O(1) algorithm for implementing the lfu cache eviction scheme"]: https://arxiv.org/pdf/2110.11602.pdf +type lfu[K comparable, V any] struct { + cap int + frequencies *list.List // Linked list containing all frequencyParents in order of least frequently used +} + +// init initializes the LFU cache policy. +func (c *lfu[K, V]) init(capacity int) { + c.cap = capacity + c.frequencies = list.New() +} + +// capacity returns the capacity of the cache. +func (c *lfu[K, V]) capacity() int { + return c.cap +} + +// access is called when an item is accessed in the cache. It increments the +// frequency of the item. +func (c *lfu[K, V]) access(item *cacheItem[K, V]) { + c.increment(item) +} + +// admit is called when an item is added to the cache. It increments the +// frequency of the item. +func (c *lfu[K, V]) admit(item *cacheItem[K, V]) { + c.increment(item) +} + +// remove is called when an item is removed from the cache. It removes the item +// from the frequency. +func (c *lfu[K, V]) remove(item *cacheItem[K, V]) { + c.delete(item.parent, item) +} + +// victim returns the least frequently used item in the cache. +func (c *lfu[K, V]) victim() *cacheItem[K, V] { + if frequency := c.frequencies.Front(); frequency != nil { + elem := frequency.Value.(*frequencyParent[K, V]).byAccess.Front() + if elem != nil { + return elem.Value.(*cacheItem[K, V]) + } + } + + return nil +} + +// increment the frequency of the given item. If the frequency parent +// does not exist, it is created. +func (c *lfu[K, V]) increment(item *cacheItem[K, V]) { + current := item.parent + + // next will be this item's new parent + var next *list.Element + + // nextAmount will be the new frequency for this item + var nextAmount int + + if current == nil { + // the item has not yet been assigned a frequency so + // this is the first time it is being accessed + nextAmount = 1 + + // set next to the first frequency + next = c.frequencies.Front() + } else { + // increment the access frequency for the item + nextAmount = current.Value.(*frequencyParent[K, V]).frequency + 1 + + // set next to the next greater frequency + next = current.Next() + } + + // if the next frequency does not exist or the next frequency is not the + // next frequency amount, create a new frequency item and insert it + // after the current frequency + if next == nil || next.Value.(*frequencyParent[K, V]).frequency != nextAmount { + newFrequencyParent := &frequencyParent[K, V]{ + entries: make(map[*cacheItem[K, V]]*list.Element), + frequency: nextAmount, + byAccess: list.New(), + } + + if current == nil { + // current is nil so insert the new frequency item at the front + next = c.frequencies.PushFront(newFrequencyParent) + } else { + // otherwise insert the new frequency item after the current + next = c.frequencies.InsertAfter(newFrequencyParent, current) + } + } + + // set the item's parent to the next frequency + item.parent = next + + // add the item to the frequency's access list + nextAccess := next.Value.(*frequencyParent[K, V]).byAccess.PushBack(item) + + // add the item to the frequency's entries with a pointer to the access list + next.Value.(*frequencyParent[K, V]).entries[item] = nextAccess + + // if the item was previously assigned a frequency, remove it from the + // old frequency's entries + if current != nil { + c.delete(current, item) + } +} + +// delete removes the given item from the frequency and removes the frequency +// if it is empty. 
+func (c *lfu[K, V]) delete(frequency *list.Element, item *cacheItem[K, V]) { + frequencyParent := frequency.Value.(*frequencyParent[K, V]) + + // remove the item from the frequency's access list + frequencyParent.byAccess.Remove(frequencyParent.entries[item]) + + // remove the item from the frequency's entries + delete(frequencyParent.entries, item) + + if len(frequencyParent.entries) == 0 { + frequencyParent.entries = nil + frequencyParent.byAccess = nil + + c.frequencies.Remove(frequency) + } +} + +// close removes all items from the cache, sends a close event on the events +// channel, and waits for the cache to close. +func (c *lfu[K, V]) close() { + c.frequencies = nil + c.cap = 0 +} diff --git a/go/appencryption/pkg/cache/lfu_example_test.go b/go/appencryption/pkg/cache/lfu_example_test.go new file mode 100644 index 000000000..69a6e423c --- /dev/null +++ b/go/appencryption/pkg/cache/lfu_example_test.go @@ -0,0 +1,46 @@ +package cache_test + +import ( + "fmt" + + "github.com/godaddy/asherah/go/appencryption/pkg/cache" +) + +func ExampleNew() { + evictionMsg := make(chan string) + + // This callback is executed via a background goroutine whenever an + // item is evicted from the cache. We use a channel to synchronize + // the goroutine with this example function so we can verify the + // item that was evicted. + evict := func(key int, value string) { + evictionMsg <- fmt.Sprintln("evicted:", key, value) + } + + // Create a new LFU cache with a capacity of 3 items and an eviction callback. + cache := cache.New[int, string](3).LFU().WithEvictFunc(evict).Build() + + // Add some items to the cache. + cache.Set(1, "foo") + cache.Set(2, "bar") + cache.Set(3, "baz") + + // Get an item from the cache. + value, ok := cache.Get(1) + if ok { + fmt.Println("got:", value) + } + + // Set a new value for an existing key + cache.Set(2, "two") + + // Add another item to the cache which will evict the least frequently used + // item (3). + cache.Set(4, "qux") + + // Print the eviction message sent via the callback above. 
+ fmt.Print(<-evictionMsg) + // Output: + // got: foo + // evicted: 3 baz +} diff --git a/go/appencryption/pkg/cache/lfu_test.go b/go/appencryption/pkg/cache/lfu_test.go new file mode 100644 index 000000000..5869697fa --- /dev/null +++ b/go/appencryption/pkg/cache/lfu_test.go @@ -0,0 +1,160 @@ +package cache_test + +import ( + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/suite" + + "github.com/godaddy/asherah/go/appencryption/pkg/cache" +) + +type LFUSuite struct { + suite.Suite + cache cache.Interface[int, string] +} + +func TestLFUSuite(t *testing.T) { + suite.Run(t, new(LFUSuite)) +} + +func (suite *LFUSuite) SetupTest() { + suite.cache = cache.New[int, string](2).LFU().Build() +} + +func (suite *LFUSuite) TestNewLFU() { + suite.Assert().Equal(0, suite.cache.Len()) + suite.Assert().Equal(2, suite.cache.Capacity()) +} + +func (suite *LFUSuite) TestSet() { + suite.cache.Set(1, "one") + suite.Assert().Equal(1, suite.cache.Len()) + + suite.cache.Set(2, "two") + suite.Assert().Equal(2, suite.cache.Len()) + + suite.cache.Set(3, "three") + suite.Assert().Equal(2, suite.cache.Len()) +} + +func (suite *LFUSuite) TestGet() { + suite.cache.Set(1, "one") + suite.cache.Set(2, "two") + + one, ok := suite.cache.Get(1) + suite.Assert().Equal("one", one) + suite.Assert().True(ok) + + two, ok := suite.cache.Get(2) + suite.Assert().Equal("two", two) + suite.Assert().True(ok) + + val, ok := suite.cache.Get(3) + suite.Assert().False(ok) + suite.Assert().Equal("", val) +} + +func (suite *LFUSuite) TestGetOrPanic() { + suite.cache.Set(1, "one") + suite.cache.Set(2, "two") + + suite.Assert().Equal("one", suite.cache.GetOrPanic(1)) + suite.Assert().Equal("two", suite.cache.GetOrPanic(2)) + + suite.Assert().Panics(func() { suite.cache.GetOrPanic(3) }) +} + +func (suite *LFUSuite) TestDelete() { + suite.cache.Set(1, "one") + suite.cache.Set(2, "two") + + suite.Assert().Equal(2, suite.cache.Len()) + + // ensure the key is deleted and the size is decremented + ok := suite.cache.Delete(1) + suite.Assert().True(ok) + suite.Assert().Equal(1, suite.cache.Len()) + + // subsequent delete should return false + ok = suite.cache.Delete(1) + suite.Assert().False(ok) + + // ensure the key is no longer in the cache + one, ok := suite.cache.Get(1) + suite.Assert().Equal("", one) + suite.Assert().False(ok) + + suite.cache.Delete(2) + suite.Assert().Equal(0, suite.cache.Len()) +} + +func (suite *LFUSuite) TestEviction() { + suite.cache.Set(1, "one") + suite.cache.Set(2, "two") + + // access 1 to increase frequency + suite.cache.Set(1, "one") + + suite.cache.Set(3, "three") + + _, ok := suite.cache.Get(1) + suite.Assert().True(ok) + + // 2 should be evicted as it has the lowest frequency + _, ok = suite.cache.Get(2) + suite.Assert().False(ok) + + _, ok = suite.cache.Get(3) + suite.Assert().True(ok) +} + +func (suite *LFUSuite) TestClose() { + suite.cache.Set(1, "one") + suite.cache.Set(2, "two") + + suite.cache.Close() + + suite.Assert().Equal(0, suite.cache.Len()) + suite.Assert().Equal(0, suite.cache.Capacity()) +} + +func (suite *LFUSuite) TestWithEvictFunc() { + mux := sync.Mutex{} + evicted := map[int]int{} + + suite.cache = cache.New[int, string](100). + WithEvictFunc(func(key int, _ string) { + mux.Lock() + evicted[key] = 1 + mux.Unlock() + }). + LFU(). 
+ Build() + + // overfill the cache + for i := 0; i < 105; i++ { + suite.cache.Set(i, fmt.Sprintf("value-%d", i)) + } + + // wait for the background goroutine to evict items + suite.Assert().Eventually(func() bool { + mux.Lock() + defer mux.Unlock() + + return len(evicted) == 5 + }, 100*time.Millisecond, 10*time.Millisecond, "eviction callback was not called") + + // verify the first five items were evicted + for i := 0; i < 5; i++ { + suite.Assert().Contains(evicted, i) + } + + // close the cache and evict the remaining items + suite.cache.Close() + + suite.Assert().Equal(0, suite.cache.Len()) + suite.Assert().Equal(105, len(evicted)) +} diff --git a/go/appencryption/pkg/cache/lru.go b/go/appencryption/pkg/cache/lru.go new file mode 100644 index 000000000..301bd062f --- /dev/null +++ b/go/appencryption/pkg/cache/lru.go @@ -0,0 +1,168 @@ +//nolint:forcetypeassert // we know the type of the value +package cache + +import ( + "container/list" +) + +// lru is a least recently used cache policy implementation. +type lru[K comparable, V any] struct { + cap int + evictList *list.List +} + +// init initializes the LRU cache policy. +func (c *lru[K, V]) init(capacity int) { + c.cap = capacity + c.evictList = list.New() +} + +// capacity returns the capacity of the cache. +func (c *lru[K, V]) capacity() int { + return c.cap +} + +// len returns the number of items in the cache. +func (c *lru[K, V]) len() int { + return c.evictList.Len() +} + +// access is called when an item is accessed in the cache. It moves the item to +// the front of the eviction list. +func (c *lru[K, V]) access(item *cacheItem[K, V]) { + c.evictList.MoveToFront(item.parent) +} + +// admit is called when an item is added to the cache. It adds the item to the +// front of the eviction list. +func (c *lru[K, V]) admit(item *cacheItem[K, V]) { + item.parent = c.evictList.PushFront(item) +} + +// remove is called when an item is removed from the cache. It removes the item +// from the eviction list. +func (c *lru[K, V]) remove(item *cacheItem[K, V]) { + c.evictList.Remove(item.parent) +} + +// victim returns the least recently used item in the cache. +func (c *lru[K, V]) victim() *cacheItem[K, V] { + oldest := c.evictList.Back() + if oldest == nil { + return nil + } + + return oldest.Value.(*cacheItem[K, V]) +} + +// close implements the policy interface. +func (c *lru[K, V]) close() { + c.evictList = nil + c.cap = 0 +} + +const protectedRatio = 0.8 + +// slruItem is an item in the SLRU cache. +type slruItem[K comparable, V any] struct { + *cacheItem[K, V] + protected bool +} + +// slru is a Segmented LRU cache policy implementation. +type slru[K comparable, V any] struct { + cap int + + protectedCapacity int + protectedList *list.List + + probationCapacity int + probationList *list.List +} + +// init initializes the SLRU cache policy. +func (c *slru[K, V]) init(capacity int) { + c.cap = capacity + + c.protectedList = list.New() + c.probationList = list.New() + + c.protectedCapacity = int(float64(capacity) * protectedRatio) + c.probationCapacity = capacity - c.protectedCapacity +} + +// capacity returns the capacity of the cache. +func (c *slru[K, V]) capacity() int { + return c.cap +} + +// access is called when an item is accessed in the cache. It moves the item to +// the front of its respective eviction list. 
+func (c *slru[K, V]) access(item *cacheItem[K, V]) { + sitem := item.parent.Value.(*slruItem[K, V]) + if sitem.protected { + c.protectedList.MoveToFront(item.parent) + return + } + + // must be in probation list, promote to protected list + sitem.protected = true + + c.probationList.Remove(item.parent) + + item.parent = c.protectedList.PushFront(sitem) + + // if the protected list is too big, demote the oldest item to the probation list + if c.protectedList.Len() > c.protectedCapacity { + b := c.protectedList.Back() + c.protectedList.Remove(b) + + bitem := b.Value.(*slruItem[K, V]) + bitem.protected = false + + bitem.parent = c.probationList.PushFront(bitem) + } +} + +// admit is called when an item is added to the cache. It adds the item to the +// front of the probation list. +func (c *slru[K, V]) admit(item *cacheItem[K, V]) { + newItem := &slruItem[K, V]{ + cacheItem: item, + protected: false, + } + + item.parent = c.probationList.PushFront(newItem) +} + +// victim returns the least recently used item in the cache. +func (c *slru[K, V]) victim() *cacheItem[K, V] { + if c.probationList.Len() > 0 { + return c.probationList.Back().Value.(*slruItem[K, V]).cacheItem + } + + if c.protectedList.Len() > 0 { + return c.protectedList.Back().Value.(*slruItem[K, V]).cacheItem + } + + return nil +} + +// remove is called when an item is removed from the cache. It removes the item +// from the eviction list. +func (c *slru[K, V]) remove(item *cacheItem[K, V]) { + sitem := item.parent.Value.(*slruItem[K, V]) + if sitem.protected { + c.protectedList.Remove(item.parent) + return + } + + c.probationList.Remove(item.parent) +} + +// close implements the policy interface. +func (c *slru[K, V]) close() { + c.protectedList = nil + c.probationList = nil + c.cap = 0 +} diff --git a/go/appencryption/pkg/cache/lru_test.go b/go/appencryption/pkg/cache/lru_test.go new file mode 100644 index 000000000..654af8197 --- /dev/null +++ b/go/appencryption/pkg/cache/lru_test.go @@ -0,0 +1,284 @@ +package cache_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/godaddy/asherah/go/appencryption/pkg/cache" +) + +type LRUSuite struct { + suite.Suite + cache cache.Interface[int, string] +} + +func TestLRUSuite(t *testing.T) { + suite.Run(t, new(LRUSuite)) +} + +func (suite *LRUSuite) SetupTest() { + suite.cache = cache.New[int, string](10).Build() +} + +func (suite *LRUSuite) TestNewLRU() { + suite.Assert().Equal(0, suite.cache.Len()) + suite.Assert().Equal(10, suite.cache.Capacity()) +} + +func (suite *LRUSuite) TestSet() { + // fill to capacity + for i := 0; i < suite.cache.Capacity(); i++ { + suite.cache.Set(i, fmt.Sprintf("#%d", i)) + } + + // verify size + suite.Assert().Equal(suite.cache.Capacity(), suite.cache.Len()) +} + +func (suite *LRUSuite) TestGet() { + suite.cache.Set(1, "one") + suite.cache.Set(2, "two") + + one, ok := suite.cache.Get(1) + suite.Assert().Equal("one", one) + suite.Assert().True(ok) + + two, ok := suite.cache.Get(2) + suite.Assert().Equal("two", two) + suite.Assert().True(ok) + + val, ok := suite.cache.Get(3) + suite.Assert().False(ok) + suite.Assert().Equal("", val) +} + +func (suite *LRUSuite) TestGetOrPanic() { + suite.cache.Set(1, "one") + suite.cache.Set(2, "two") + + one := suite.cache.GetOrPanic(1) + suite.Assert().Equal("one", one) + + two := suite.cache.GetOrPanic(2) + suite.Assert().Equal("two", two) + + suite.Assert().Panics(func() { suite.cache.GetOrPanic(3) }) +} + +func (suite *LRUSuite) TestDelete() { + suite.cache.Set(1, "one") + 
suite.cache.Set(2, "two") + + suite.Assert().Equal(2, suite.cache.Len()) + + suite.cache.Delete(1) + suite.Assert().Equal(1, suite.cache.Len()) + + suite.cache.Delete(2) + suite.Assert().Equal(0, suite.cache.Len()) +} + +func (suite *LRUSuite) TestClose() { + suite.cache.Set(1, "one") + suite.cache.Set(2, "two") + + suite.Assert().Equal(2, suite.cache.Len()) + + suite.cache.Close() + suite.Assert().Equal(0, suite.cache.Len()) +} + +func (suite *LRUSuite) TestEviction() { + // fill the cache to capacity + for i := 0; i < suite.cache.Capacity(); i++ { + suite.cache.Set(i, fmt.Sprintf("#%d", i)) + } + + // access the first item to make it the most recently used + suite.cache.Get(0) + + // add a new item to the cache + suite.cache.Set(10, "#10") + + // the least recently used item should have been evicted + _, ok := suite.cache.Get(1) + suite.Assert().False(ok) + + // the most recently used item should still be in the cache + _, ok = suite.cache.Get(0) + suite.Assert().True(ok) + + suite.Assert().Equal(10, suite.cache.Len()) +} + +func (suite *LRUSuite) TestWithEvictFunc() { + done := make(chan struct{}) + + evicted := false + cache := cache.New[int, string](1).WithEvictFunc(func(key int, value string) { + evicted = true + + suite.Assert().Equal(1, key) + suite.Assert().Equal("one", value) + + close(done) + }).Build() + + cache.Set(1, "one") + cache.Set(2, "two") + + <-done + + suite.Assert().True(evicted) + suite.Assert().Equal(1, cache.Len()) +} + +type SLRUSuite struct { + suite.Suite + cache cache.Interface[int, string] +} + +func TestSLRUSuite(t *testing.T) { + suite.Run(t, new(SLRUSuite)) +} + +func (suite *SLRUSuite) SetupTest() { + suite.cache = cache.New[int, string](10).SLRU().Build() +} + +func (suite *SLRUSuite) TestNewSLRU() { + suite.Assert().Equal(0, suite.cache.Len()) + suite.Assert().Equal(10, suite.cache.Capacity()) +} + +func (suite *SLRUSuite) TestSet() { + suite.cache.Set(1, "one") + suite.Assert().Equal(1, suite.cache.Len()) + + suite.cache.Set(2, "two") + suite.Assert().Equal(2, suite.cache.Len()) + + suite.cache.Set(3, "three") + suite.Assert().Equal(3, suite.cache.Len()) +} + +func (suite *SLRUSuite) TestGet() { + suite.cache.Set(1, "one") + suite.cache.Set(2, "two") + + one, ok := suite.cache.Get(1) + suite.Assert().Equal("one", one) + suite.Assert().True(ok) + + two, ok := suite.cache.Get(2) + suite.Assert().Equal("two", two) + suite.Assert().True(ok) + + val, ok := suite.cache.Get(3) + suite.Assert().False(ok) + suite.Assert().Equal("", val) +} + +func (suite *SLRUSuite) TestGetOrPanic() { + suite.cache.Set(1, "one") + suite.cache.Set(2, "two") + + one := suite.cache.GetOrPanic(1) + suite.Assert().Equal("one", one) + + two := suite.cache.GetOrPanic(2) + suite.Assert().Equal("two", two) + + suite.Assert().Panics(func() { suite.cache.GetOrPanic(3) }) +} + +func (suite *SLRUSuite) TestDelete() { + suite.cache.Set(1, "one") + suite.cache.Set(2, "two") + + suite.Assert().Equal(2, suite.cache.Len()) + + suite.cache.Delete(1) + suite.Assert().Equal(1, suite.cache.Len()) + + suite.cache.Delete(2) + suite.Assert().Equal(0, suite.cache.Len()) +} + +func (suite *SLRUSuite) TestClose() { + suite.cache.Set(1, "one") + suite.cache.Set(2, "two") + + // throw in a get for good measure + suite.cache.Get(1) + + suite.Assert().Equal(2, suite.cache.Len()) + + suite.cache.Close() + suite.Assert().Equal(0, suite.cache.Len()) +} + +func (suite *SLRUSuite) TestCloseEmpty() { + suite.cache.Close() +} + +func (suite *SLRUSuite) TestEviction() { + // fill the cache to capacity + for i := 0; i 
< suite.cache.Capacity(); i++ { + suite.cache.Set(i, fmt.Sprintf("#%d", i)) + } + + // access the first item to make it the most recently used + suite.cache.Get(0) + + // add a new item to the cache + suite.cache.Set(10, "#10") + + // the least recently used item should have been evicted + _, ok := suite.cache.Get(1) + suite.Assert().False(ok) + + // the most recently used item should still be in the cache + _, ok = suite.cache.Get(0) + suite.Assert().True(ok) + + // verify other items are still in the cache + for i := 2; i < suite.cache.Capacity(); i++ { + _, ok := suite.cache.Get(i) + suite.Assert().True(ok) + } + + suite.Assert().Equal(10, suite.cache.Len()) +} + +func (suite *SLRUSuite) TestWithEvictFunc() { + done := make(chan struct{}) + + evicted := false + cache := cache.New[int, string](10).SLRU().WithEvictFunc(func(key int, value string) { + evicted = true + + suite.Assert().Equal(1, key) + suite.Assert().Equal("#1", value) + + close(done) + }).Build() + + // fill the cache to capacity + for i := 0; i < cache.Capacity(); i++ { + cache.Set(i, fmt.Sprintf("#%d", i)) + } + + // access the first item to make it the most recently used + cache.Get(0) + + // add a new item to the cache + cache.Set(10, "#10") + + <-done + + suite.Assert().True(evicted) + suite.Assert().Equal(10, cache.Len()) +} diff --git a/go/appencryption/pkg/cache/tlfu.go b/go/appencryption/pkg/cache/tlfu.go new file mode 100644 index 000000000..3cb87d6ab --- /dev/null +++ b/go/appencryption/pkg/cache/tlfu.go @@ -0,0 +1,202 @@ +package cache + +import ( + "github.com/godaddy/asherah/go/appencryption/pkg/cache/internal" +) + +const ( + samplesMultiplier = 8 + insertionsMultiplier = 2 + countersMultiplier = 1 + falsePositiveProbability = 0.1 + admissionRatio = 0.01 +) + +// tinyLFUEntry is an entry in the tinyLFU cache. +type tinyLFUEntry[K comparable, V any] struct { + hash uint64 + parent policy[K, V] +} + +// tinyLFU is a tiny LFU cache policy implementation derived from +// [Mango Cache] and based on the algorithm described in the paper +// ["TinyLFU: A Highly Efficient Cache Admission Policy"] by Gil Einziger, +// Roy Friedman, and Ben Manes. +// +// [Mango Cache]: https://github.com/goburrow/cache +// ["TinyLFU: A Highly Efficient Cache Admission Policy"]: https://arxiv.org/pdf/1512.00727v2.pdf +type tinyLFU[K comparable, V any] struct { + cap int + + filter internal.BloomFilter // 1bit counter + counter internal.CountMinSketch // 4bit counter + + additions int + samples int + + lru lru[K, V] + slru slru[K, V] + + keys map[K]tinyLFUEntry[K, V] // Hashmap containing *tinyLFUEntry for O(1) access +} + +// init initializes the tinyLFU cache policy. +func (c *tinyLFU[K, V]) init(capacity int) { + c.cap = capacity + + c.keys = make(map[K]tinyLFUEntry[K, V]) + + c.samples = capacity * samplesMultiplier + + c.filter.Init(capacity*insertionsMultiplier, falsePositiveProbability) + c.counter.Init(capacity * countersMultiplier) + + // The admission window is a fixed percentage of the cache capacity. + // The LRU is the first part of the admission window, and the SLRU is + // the second part. + // + // Note that for small cache sizes the admission window may be 0, in which + // case the SLRU is the entire cache and the doorkeeper is not used. + lruCap := int(float64(capacity) * admissionRatio) + c.lru.init(lruCap) + + slruCap := capacity - lruCap + c.slru.init(slruCap) +} + +// capacity returns the capacity of the cache. 
+func (c *tinyLFU[K, V]) capacity() int { + return c.cap +} + +// access is called when an item is accessed in the cache. It increments the +// frequency of the item. +func (c *tinyLFU[K, V]) access(item *cacheItem[K, V]) { + c.increment(item) + + c.keys[item.key].parent.access(item) +} + +// admit is called when an item is added to the cache. It increments the +// frequency of the item. +func (c *tinyLFU[K, V]) admit(item *cacheItem[K, V]) { + if c.bypassed() { + c.slru.admit(item) + return + } + + c.increment(item) + + // If there's room in the admission window, add it to the LRU + if c.lru.len() < c.lru.cap { + c.admitTo(item, &c.lru) + + return + } + + victim := c.lru.victim() + + // Otherwise, promote the victim from the LRU to the SLRU + c.lru.remove(victim) + c.admitTo(victim, &c.slru) + + // then add the new item to the LRU + c.admitTo(item, &c.lru) +} + +// bypassed returns true if the doorkeeper is not in use. +func (c *tinyLFU[K, V]) bypassed() bool { + return c.lru.cap == 0 +} + +// admitTo adds the item to the provided eviction list. +func (c *tinyLFU[K, V]) admitTo(item *cacheItem[K, V], list policy[K, V]) { + list.admit(item) + + c.keys[item.key] = tinyLFUEntry[K, V]{ + hash: internal.ComputeHash(item.key), + parent: list, + } +} + +// victim returns the victim item to be evicted. +func (c *tinyLFU[K, V]) victim() *cacheItem[K, V] { + candidate := c.lru.victim() + + // If the LRU is empty, just return the SLRU victim. + // This is the case when the cache is closing and + // the items are being purged. + if candidate == nil { + return c.slru.victim() + } + + victim := c.slru.victim() + + // If the SLRU is empty, just return the LRU victim. + if victim == nil { + return candidate + } + + // we have both a candidate and a victim + // ...may the best item win! + candidateFreq := c.estimate(c.keys[candidate.key].hash) + victimFreq := c.estimate(c.keys[victim.key].hash) + + // If the candidate is more frequently accessed than the victim, + // remove the candidate from the LRU and add it to the SLRU. + if candidateFreq > victimFreq { + c.lru.remove(candidate) + + c.admitTo(candidate, &c.slru) + + return victim + } + + return candidate +} + +// estimate returns the estimated frequency of the item. +func (c *tinyLFU[K, V]) estimate(h uint64) uint8 { + freq := c.counter.Estimate(h) + if c.filter.Contains(h) { + freq++ + } + + return freq +} + +// remove is called when an item is removed from the cache. It removes the item +// from the appropriate eviction list. +func (c *tinyLFU[K, V]) remove(item *cacheItem[K, V]) { + c.keys[item.key].parent.remove(item) +} + +// increment increments the frequency of the item. +func (c *tinyLFU[K, V]) increment(item *cacheItem[K, V]) { + if c.bypassed() { + return + } + + c.additions++ + + if c.additions >= c.samples { + c.filter.Reset() + c.counter.Reset() + + c.additions = 0 + } + + k := c.keys[item.key] + + if c.filter.Put(k.hash) { + c.counter.Add(k.hash) + } +} + +// close removes all items from the cache. 
+func (c *tinyLFU[K, V]) close() { + c.lru.close() + c.slru.close() + + c.cap = 0 +} diff --git a/go/appencryption/pkg/cache/tlfu_test.go b/go/appencryption/pkg/cache/tlfu_test.go new file mode 100644 index 000000000..ec95bc04d --- /dev/null +++ b/go/appencryption/pkg/cache/tlfu_test.go @@ -0,0 +1,134 @@ +package cache_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/godaddy/asherah/go/appencryption/pkg/cache" +) + +type TinyLFUSuite struct { + suite.Suite + cache cache.Interface[int, string] +} + +func TestTinyLFUSuite(t *testing.T) { + // t.SkipNow() + suite.Run(t, new(TinyLFUSuite)) +} + +func (suite *TinyLFUSuite) SetupTest() { + suite.cache = cache.New[int, string](100).TinyLFU().Build() +} + +func (suite *TinyLFUSuite) TestNewTinyLFU() { + suite.Assert().Equal(0, suite.cache.Len()) + suite.Assert().Equal(100, suite.cache.Capacity()) +} + +func (suite *TinyLFUSuite) TestSet() { + // fill cache + for i := 0; i < suite.cache.Capacity(); i++ { + suite.cache.Set(i, fmt.Sprintf("%d", i)) + } + + suite.Assert().Equal(suite.cache.Capacity(), suite.cache.Len()) + + // add one more + suite.cache.Set(100, "one hundred") + suite.Assert().Equal(suite.cache.Capacity(), suite.cache.Len()) +} + +func (suite *TinyLFUSuite) TestGet() { + suite.cache.Set(1, "one") + suite.cache.Set(2, "two") + + one, ok := suite.cache.Get(1) + suite.Assert().Equal("one", one) + suite.Assert().True(ok) + + two, ok := suite.cache.Get(2) + suite.Assert().Equal("two", two) + suite.Assert().True(ok) + + val, ok := suite.cache.Get(3) + suite.Assert().False(ok) + suite.Assert().Equal("", val) +} + +func (suite *TinyLFUSuite) TestGetOrPanic() { + suite.cache.Set(1, "one") + suite.cache.Set(2, "two") + + one := suite.cache.GetOrPanic(1) + suite.Assert().Equal("one", one) + + two := suite.cache.GetOrPanic(2) + suite.Assert().Equal("two", two) + + suite.Assert().Panics(func() { suite.cache.GetOrPanic(3) }) +} + +func (suite *TinyLFUSuite) TestDelete() { + suite.cache.Set(1, "one") + suite.cache.Set(2, "two") + + suite.Assert().True(suite.cache.Delete(1)) + suite.Assert().Equal(1, suite.cache.Len()) + + suite.Assert().False(suite.cache.Delete(3)) + suite.Assert().Equal(1, suite.cache.Len()) +} + +func (suite *TinyLFUSuite) TestEvict() { + // fill the cache to capacity + for i := 0; i < suite.cache.Capacity(); i++ { + suite.cache.Set(i, fmt.Sprintf("#%d", i)) + } + + // access half of the items + for i := 0; i < suite.cache.Capacity()/2; i++ { + _, ok := suite.cache.Get(i) + suite.Assert().True(ok) + } + + // add one more + suite.cache.Set(999, "nine ninety nine") + + // access the new item + _, ok := suite.cache.Get(999) + suite.Assert().True(ok) + + // verify the cache is at capacity + suite.Assert().Equal(suite.cache.Capacity(), suite.cache.Len()) + + // overwrite half of the items + for i := 0; i < suite.cache.Capacity(); i++ { + key := i + 1000 + suite.cache.Set(key, fmt.Sprintf("##%d", key)) + } + + // verify 999 is still in the cache + _, ok = suite.cache.Get(999) + suite.Assert().True(ok, "item 999 should be in the cache") + + // verify all of the previously accessed items are still in the cache + for i := 0; i < suite.cache.Capacity()/2; i++ { + _, ok := suite.cache.Get(i) + suite.Assert().True(ok, "item %d should be in the cache", i) + } +} + +func (suite *TinyLFUSuite) TestClose() { + suite.cache.Set(1, "one") + suite.cache.Set(2, "two") + + suite.Assert().Equal(2, suite.cache.Len()) + + suite.cache.Close() + + suite.Assert().Equal(0, suite.cache.Len()) + 
suite.Assert().Equal(0, suite.cache.Capacity()) +} diff --git a/go/appencryption/policy.go b/go/appencryption/policy.go index fd39aaae9..420c27430 100644 --- a/go/appencryption/policy.go +++ b/go/appencryption/policy.go @@ -9,6 +9,7 @@ const ( DefaultExpireAfter = time.Hour * 24 * 90 // 90 days DefaultRevokedCheckInterval = time.Minute * 60 DefaultCreateDatePrecision = time.Minute + DefaultKeyCacheMaxSize = 1000 DefaultSessionCacheMaxSize = 1000 DefaultSessionCacheDuration = time.Hour * 2 DefaultSessionCacheEngine = "default" @@ -27,8 +28,26 @@ type CryptoPolicy struct { CreateDatePrecision time.Duration // CacheIntermediateKeys determines whether Intermediate Keys will be cached. CacheIntermediateKeys bool + // IntermediateKeyCacheMaxSize controls the maximum size of the cache if intermediate key caching is enabled. + IntermediateKeyCacheMaxSize int + // IntermediateKeyCacheEvictionPolicy controls the eviction policy to use for the shared cache. + // Supported values are "lru", "lfu", "slru", and "tinylfu". Default is "lru". + IntermediateKeyCacheEvictionPolicy string + // SharedIntermediateKeyCache determines whether Intermediate Keys will use a single shared cache. If enabled, + // Intermediate Keys will share a single cache across all sessions for a given factory. + // This option is useful if you have a large number of sessions and want to reduce the memory footprint of the + // cache. + // + // This option is ignored if CacheIntermediateKeys is disabled. + SharedIntermediateKeyCache bool // CacheSystemKeys determines whether System Keys will be cached. CacheSystemKeys bool + // SystemKeyCacheMaxSize controls the maximum size of the cache if system key caching is enabled. If + // SharedKeyCache is enabled, this value will determine the maximum size of the shared cache. + SystemKeyCacheMaxSize int + // SystemKeyCacheEvictionPolicy controls the eviction policy to use for the shared cache. + // Supported values are "lru", "lfu", "slru", and "tinylfu". Default is "lru". + SystemKeyCacheEvictionPolicy string // CacheSessions determines whether sessions will be cached. CacheSessions bool // SessionCacheMaxSize controls the maximum size of the cache if session caching is enabled. @@ -36,12 +55,9 @@ type CryptoPolicy struct { // SessionCacheDuration controls the amount of time a session will remain cached without being accessed // if session caching is enabled. SessionCacheDuration time.Duration - // WithSessionCacheEngine determines the underlying cache implemenataion in use by the session cache - // if session caching is enabled. - // - // Deprecated: multiple cache implementations are no longer supported and this option will be removed - // in a future release. - SessionCacheEngine string + // SessionCacheEvictionPolicy controls the eviction policy to use for the shared cache. + // Supported values are "lru", "lfu", "slru", and "tinylfu". Default is "slru". + SessionCacheEvictionPolicy string } // PolicyOption is used to configure a CryptoPolicy. @@ -69,6 +85,15 @@ func WithNoCache() PolicyOption { } } +// WithSharedIntermediateKeyCache enables a shared cache for Intermediate Keys with the provided capacity. The shared +// cache will be used by all sessions for a given factory. +func WithSharedIntermediateKeyCache(capacity int) PolicyOption { + return func(policy *CryptoPolicy) { + policy.SharedIntermediateKeyCache = true + policy.IntermediateKeyCacheMaxSize = capacity + } +} + // WithSessionCache enables session caching. 
When used all sessions for a given partition will share underlying // System and Intermediate Key caches. func WithSessionCache() PolicyOption { @@ -92,29 +117,20 @@ func WithSessionCacheDuration(d time.Duration) PolicyOption { } } -// WithSessionCacheEngine determines the underlying cache implemenataion in use by the session cache -// if session caching is enabled. -// -// Deprecated: multiple cache implementations are no longer supported and this option will be removed -// in a future release. -func WithSessionCacheEngine(engine string) PolicyOption { - return func(policy *CryptoPolicy) { - policy.SessionCacheEngine = engine - } -} - // NewCryptoPolicy returns a new CryptoPolicy with default values. func NewCryptoPolicy(opts ...PolicyOption) *CryptoPolicy { policy := &CryptoPolicy{ - ExpireKeyAfter: DefaultExpireAfter, - RevokeCheckInterval: DefaultRevokedCheckInterval, - CreateDatePrecision: DefaultCreateDatePrecision, - CacheSystemKeys: true, - CacheIntermediateKeys: true, - CacheSessions: false, - SessionCacheMaxSize: DefaultSessionCacheMaxSize, - SessionCacheDuration: DefaultSessionCacheDuration, - SessionCacheEngine: DefaultSessionCacheEngine, + ExpireKeyAfter: DefaultExpireAfter, + RevokeCheckInterval: DefaultRevokedCheckInterval, + CreateDatePrecision: DefaultCreateDatePrecision, + CacheSystemKeys: true, + CacheIntermediateKeys: true, + IntermediateKeyCacheMaxSize: DefaultKeyCacheMaxSize, + SystemKeyCacheMaxSize: DefaultKeyCacheMaxSize, + SharedIntermediateKeyCache: false, + CacheSessions: false, + SessionCacheMaxSize: DefaultSessionCacheMaxSize, + SessionCacheDuration: DefaultSessionCacheDuration, } for _, opt := range opts { @@ -124,12 +140,6 @@ func NewCryptoPolicy(opts ...PolicyOption) *CryptoPolicy { return policy } -// isKeyExpired checks if the key's created timestamp is older than the -// allowed number of days. -func isKeyExpired(created int64, expireAfter time.Duration) bool { - return time.Now().After(time.Unix(created, 0).Add(expireAfter)) -} - // newKeyTimestamp returns a unix timestamp in seconds truncated to the provided Duration. 
func newKeyTimestamp(truncate time.Duration) int64 { if truncate > 0 { diff --git a/go/appencryption/policy_test.go b/go/appencryption/policy_test.go index a3495a2e0..3c788c316 100644 --- a/go/appencryption/policy_test.go +++ b/go/appencryption/policy_test.go @@ -17,10 +17,12 @@ func Test_NewCryptoPolicy_WithDefaults(t *testing.T) { assert.Equal(t, DefaultCreateDatePrecision, p.CreateDatePrecision) assert.True(t, p.CacheSystemKeys) assert.True(t, p.CacheIntermediateKeys) + assert.Equal(t, DefaultKeyCacheMaxSize, p.SystemKeyCacheMaxSize) + assert.Equal(t, DefaultKeyCacheMaxSize, p.IntermediateKeyCacheMaxSize) + assert.False(t, p.SharedIntermediateKeyCache) assert.False(t, p.CacheSessions) assert.Equal(t, DefaultSessionCacheMaxSize, p.SessionCacheMaxSize) assert.Equal(t, DefaultSessionCacheDuration, p.SessionCacheDuration) - assert.Equal(t, DefaultSessionCacheEngine, p.SessionCacheEngine) } func Test_NewCryptoPolicy_WithOptions(t *testing.T) { @@ -28,7 +30,6 @@ func Test_NewCryptoPolicy_WithOptions(t *testing.T) { expireAfterDuration := time.Second * 100 sessionCacheMaxSize := 42 sessionCacheDuration := time.Second * 42 - sessionCacheEngine := "deprecated" policy := NewCryptoPolicy( WithRevokeCheckInterval(revokeCheckInterval), @@ -37,7 +38,6 @@ func Test_NewCryptoPolicy_WithOptions(t *testing.T) { WithSessionCache(), WithSessionCacheMaxSize(sessionCacheMaxSize), WithSessionCacheDuration(sessionCacheDuration), - WithSessionCacheEngine(sessionCacheEngine), ) assert.Equal(t, revokeCheckInterval, policy.RevokeCheckInterval) @@ -47,7 +47,33 @@ func Test_NewCryptoPolicy_WithOptions(t *testing.T) { assert.True(t, policy.CacheSessions) assert.Equal(t, sessionCacheMaxSize, policy.SessionCacheMaxSize) assert.Equal(t, sessionCacheDuration, policy.SessionCacheDuration) - assert.Equal(t, sessionCacheEngine, policy.SessionCacheEngine) +} + +func Test_NewCryptoPolicy_WithOptions_SharedIntermediateKeyCache(t *testing.T) { + revokeCheckInterval := time.Second * 156 + expireAfterDuration := time.Second * 100 + keyCacheMaxSize := 10 + sessionCacheMaxSize := 42 + sessionCacheDuration := time.Second * 42 + + policy := NewCryptoPolicy( + WithRevokeCheckInterval(revokeCheckInterval), + WithExpireAfterDuration(expireAfterDuration), + WithSharedIntermediateKeyCache(keyCacheMaxSize), + WithSessionCache(), + WithSessionCacheMaxSize(sessionCacheMaxSize), + WithSessionCacheDuration(sessionCacheDuration), + ) + + assert.Equal(t, revokeCheckInterval, policy.RevokeCheckInterval) + assert.Equal(t, expireAfterDuration, policy.ExpireKeyAfter) + assert.True(t, policy.CacheSystemKeys) + assert.True(t, policy.CacheIntermediateKeys) + assert.True(t, policy.SharedIntermediateKeyCache) + assert.Equal(t, keyCacheMaxSize, policy.IntermediateKeyCacheMaxSize) + assert.True(t, policy.CacheSessions) + assert.Equal(t, sessionCacheMaxSize, policy.SessionCacheMaxSize) + assert.Equal(t, sessionCacheDuration, policy.SessionCacheDuration) } func Test_IsKeyExpired(t *testing.T) { @@ -78,7 +104,7 @@ func Test_IsKeyExpired(t *testing.T) { key := internal.NewCryptoKeyForTest(tt.CreatedAt.Unix(), false) - verify.Equal(tt.Expect, isKeyExpired(key.Created(), time.Hour*24*time.Duration(tt.ExpireAfterDays))) + verify.Equal(tt.Expect, internal.IsKeyExpired(key.Created(), time.Hour*24*time.Duration(tt.ExpireAfterDays))) }) } } diff --git a/go/appencryption/session.go b/go/appencryption/session.go index e711f5aee..70dd83e1b 100644 --- a/go/appencryption/session.go +++ b/go/appencryption/session.go @@ -14,13 +14,14 @@ import ( // SessionFactory is used to 
create new encryption sessions and manage // the lifetime of the intermediate keys. type SessionFactory struct { - sessionCache sessionCache - systemKeys cache - Config *Config - Metastore Metastore - Crypto AEAD - KMS KeyManagementService - SecretFactory securememory.SecretFactory + sessionCache sessionCache + systemKeys keyCacher + intermediateKeys keyCacher // only used if shared key cache is enabled + Config *Config + Metastore Metastore + Crypto AEAD + KMS KeyManagementService + SecretFactory securememory.SecretFactory } // FactoryOption is used to configure additional options in a SessionFactory. @@ -48,21 +49,28 @@ func NewSessionFactory(config *Config, store Metastore, kms KeyManagementService config.Policy = NewCryptoPolicy() } - var skCache cache + var skCache keyCacher if config.Policy.CacheSystemKeys { - skCache = newKeyCache(config.Policy) + skCache = newKeyCache(CacheTypeSystemKeys, config.Policy) log.Debugf("new skCache: %v\n", skCache) } else { skCache = new(neverCache) } + var ikCache keyCacher + if config.Policy.SharedIntermediateKeyCache { + ikCache = newKeyCache(CacheTypeIntermediateKeys, config.Policy) + log.Debugf("new shared ikCache: %v\n", ikCache) + } + factory := &SessionFactory{ - systemKeys: skCache, - Config: config, - Metastore: store, - Crypto: crypto, - KMS: kms, - SecretFactory: new(memguard.SecretFactory), + systemKeys: skCache, + intermediateKeys: ikCache, + Config: config, + Metastore: store, + Crypto: crypto, + KMS: kms, + SecretFactory: new(memguard.SecretFactory), } if config.Policy.CacheSessions { @@ -85,6 +93,10 @@ func (f *SessionFactory) Close() error { f.sessionCache.Close() } + if f.Config.Policy.SharedIntermediateKeyCache { + f.intermediateKeys.Close() + } + return f.systemKeys.Close() } @@ -102,17 +114,29 @@ func (f *SessionFactory) GetSession(id string) (*Session, error) { } func newSession(f *SessionFactory, id string) (*Session, error) { + skCache := f.systemKeys + + var ikCache keyCacher + if f.Config.Policy.SharedIntermediateKeyCache { + ikCache = f.intermediateKeys + } else { + ikCache = f.newIKCache() + } + s := &Session{ encryption: &envelopeEncryption{ - partition: f.newPartition(id), - Metastore: f.Metastore, - KMS: f.KMS, - Policy: f.Config.Policy, - Crypto: f.Crypto, - SecretFactory: f.SecretFactory, - systemKeys: f.systemKeys, - intermediateKeys: f.newIKCache(), + partition: f.newPartition(id), + Metastore: f.Metastore, + KMS: f.KMS, + Policy: f.Config.Policy, + Crypto: f.Crypto, + SecretFactory: f.SecretFactory, + skCache: skCache, + ikCache: ikCache, }, + + ikCache: ikCache, + skCache: skCache, } log.Debugf("[newSession] for id %s. Session(%p){Encryption(%p)}", id, s, s.encryption) @@ -128,9 +152,9 @@ func (f *SessionFactory) newPartition(id string) partition { return newPartition(id, f.Config.Service, f.Config.Product) } -func (f *SessionFactory) newIKCache() cache { +func (f *SessionFactory) newIKCache() keyCacher { if f.Config.Policy.CacheIntermediateKeys { - return newKeyCache(f.Config.Policy) + return newKeyCache(CacheTypeIntermediateKeys, f.Config.Policy) } return new(neverCache) @@ -139,6 +163,9 @@ func (f *SessionFactory) newIKCache() cache { // Session is used to encrypt and decrypt data related to a specific partition ID. 
type Session struct { encryption Encryption + + ikCache keyCacher + skCache keyCacher } // Encrypt encrypts a provided slice of bytes and returns a DataRowRecord, which contains required diff --git a/go/appencryption/session_cache.go b/go/appencryption/session_cache.go index 443450b42..ae7e285eb 100644 --- a/go/appencryption/session_cache.go +++ b/go/appencryption/session_cache.go @@ -4,8 +4,7 @@ import ( "sync" "time" - mango "github.com/goburrow/cache" - + "github.com/godaddy/asherah/go/appencryption/pkg/cache" "github.com/godaddy/asherah/go/appencryption/pkg/log" ) @@ -15,198 +14,10 @@ type sessionCache interface { Close() } -// cacheStash is a temporary staging ground for the session cache. -type cacheStash struct { - tmp map[string]*Session - mux sync.RWMutex - events chan stashEvent -} - -type event uint8 - -const ( - stashClose event = iota - stashRemove -) - -type stashEvent struct { - id string - event event -} - -func (c *cacheStash) process() { - for e := range c.events { - switch e.event { - case stashRemove: - c.mux.Lock() - delete(c.tmp, e.id) - c.mux.Unlock() - case stashClose: - close(c.events) - - return - } - } -} - -func (c *cacheStash) add(id string, s *Session) { - c.mux.Lock() - c.tmp[id] = s - c.mux.Unlock() -} - -func (c *cacheStash) get(id string) (s *Session, ok bool) { - c.mux.RLock() - s, ok = c.tmp[id] - c.mux.RUnlock() - - return s, ok -} - -func (c *cacheStash) remove(id string) { - c.events <- stashEvent{ - id: id, - event: stashRemove, - } -} - -func (c *cacheStash) close() { - c.events <- stashEvent{ - event: stashClose, - } -} - -func (c *cacheStash) len() int { - c.mux.RLock() - defer c.mux.RUnlock() - - return len(c.tmp) -} - -func newCacheStash() *cacheStash { - return &cacheStash{ - tmp: make(map[string]*Session), - events: make(chan stashEvent), - } -} - -// mangoCache is a sessionCache implementation based on goburrow's -// Mango cache (https://github.com/goburrow/cache). -type mangoCache struct { - inner mango.LoadingCache - loader sessionLoaderFunc - - // mu protects the inner queue - mu sync.Mutex - - stash *cacheStash -} - -func (m *mangoCache) Get(id string) (*Session, error) { - m.mu.Lock() - defer m.mu.Unlock() - - sess, err := m.getOrAdd(id) - if err != nil { - return nil, err - } - - incrementSharedSessionUsage(sess) - - return sess, nil -} - -func (m *mangoCache) getOrAdd(id string) (*Session, error) { - // (fast path) if it's cached return it immediately - if val, ok := m.inner.GetIfPresent(id); ok { - sess := sessionOrPanic(val) - - m.stash.remove(id) - - return sess, nil - } - - // check the stash first to prevent mango from reloading a value currently in queue to be cached. - if sess, ok := m.stash.get(id); ok { - return sess, nil - } - - // m.inner.Get will add a new item via the loader on cache miss. However, newly loaded keys are added to - // the cache asynchronously, so we'll need to add it to the stash down below. - val, err := m.inner.Get(id) - if err != nil { - return nil, err - } - - sess := sessionOrPanic(val) - - // if we're here then mango has loaded a new cache value (session), so we'll add it to the tmp cache for now to - // allow mango an opportunity to actually cache the value. 
- m.stash.add(id, sess) - - return sess, nil -} - -func sessionOrPanic(val mango.Value) *Session { - sess, ok := val.(*Session) - if !ok { - panic("unexpected value") - } - - return sess -} - func incrementSharedSessionUsage(s *Session) { s.encryption.(*sharedEncryption).incrementUsage() } -func (m *mangoCache) Count() int { - s := &mango.Stats{} - m.inner.Stats(s) - - return int(s.LoadSuccessCount - s.EvictionCount) -} - -func (m *mangoCache) Close() { - if log.DebugEnabled() { - s := &mango.Stats{} - m.inner.Stats(s) - log.Debugf("session cache stash len = %d\n", m.stash.len()) - log.Debugf("%v\n", s) - } - - m.inner.Close() - m.stash.close() -} - -func mangoRemovalListener(m *mangoCache, k mango.Key, v mango.Value) { - m.stash.remove(k.(string)) - - go v.(*Session).encryption.(*sharedEncryption).Remove() -} - -func newMangoCache(sessionLoader sessionLoaderFunc, policy *CryptoPolicy) *mangoCache { - cache := &mangoCache{ - loader: sessionLoader, - stash: newCacheStash(), - } - - cache.inner = mango.NewLoadingCache( - func(k mango.Key) (mango.Value, error) { - return sessionLoader(k.(string)) - }, - mango.WithMaximumSize(policy.SessionCacheMaxSize), - mango.WithExpireAfterAccess(policy.SessionCacheDuration), - mango.WithRemovalListener(func(k mango.Key, v mango.Value) { - mangoRemovalListener(cache, k, v) - }), - ) - - go cache.stash.process() - - return cache -} - // sharedEncryption is used to track the number of concurrent users to ensure sessions remain // cached while in use. type sharedEncryption struct { @@ -250,38 +61,110 @@ func (s *sharedEncryption) Remove() { // sessionLoaderFunc retrieves a Session corresponding to the given partition ID. type sessionLoaderFunc func(id string) (*Session, error) -// newSessionCache returns a new SessionCache with the configured cache implementation +// sessionInjectEncryption is used to inject e into s and is primarily used for testing. +func sessionInjectEncryption(s *Session, e Encryption) { + log.Debugf("injecting Encryption(%p) into Session(%p)", e, s) + + s.encryption = e +} + +// newSessionCacheWithCache returns a new SessionCache with the provided cache implementation // using the provided SessionLoaderFunc and CryptoPolicy. -func newSessionCache(loader sessionLoaderFunc, policy *CryptoPolicy) sessionCache { - wrapper := func(id string) (*Session, error) { - s, err := loader(id) - if err != nil { - return nil, err - } - - _, ok := s.encryption.(*sharedEncryption) - if !ok { - mu := new(sync.Mutex) - orig := s.encryption - wrapped := &sharedEncryption{ - Encryption: orig, - mu: mu, - cond: sync.NewCond(mu), - created: time.Now(), +func newSessionCacheWithCache(loader sessionLoaderFunc, policy *CryptoPolicy, cache cache.Interface[string, *Session]) sessionCache { + return &cacheWrapper{ + loader: func(id string) (*Session, error) { + log.Debugf("loading session for id: %s", id) + + s, err := loader(id) + if err != nil { + return nil, err } - sessionInjectEncryption(s, wrapped) - } + _, ok := s.encryption.(*sharedEncryption) + if !ok { + mu := new(sync.Mutex) + orig := s.encryption + wrapped := &sharedEncryption{ + Encryption: orig, + mu: mu, + cond: sync.NewCond(mu), + created: time.Now(), + } + + sessionInjectEncryption(s, wrapped) + } - return s, nil + return s, nil + }, + policy: policy, + cache: cache, } +} + +// cacheWrapper is a wrapper around a cache.Interface[string, *Session] that implements the +// sessionCache interface. 
+type cacheWrapper struct { + loader sessionLoaderFunc + policy *CryptoPolicy + cache cache.Interface[string, *Session] - return newMangoCache(wrapper, policy) + mu sync.Mutex } -// sessionInjectEncryption is used to inject e into s and is primarily used for testing. -func sessionInjectEncryption(s *Session, e Encryption) { - log.Debugf("injecting Encryption(%p) into Session(%p)", e, s) +func (c *cacheWrapper) Get(id string) (*Session, error) { + c.mu.Lock() + defer c.mu.Unlock() - s.encryption = e + val, err := c.getOrAdd(id) + if err != nil { + return nil, err + } + + incrementSharedSessionUsage(val) + + return val, nil +} + +func (c *cacheWrapper) getOrAdd(id string) (*Session, error) { + if val, ok := c.cache.Get(id); ok { + return val, nil + } + + val, err := c.loader(id) + if err != nil { + return nil, err + } + + c.cache.Set(id, val) + + return val, nil +} + +func (c *cacheWrapper) Count() int { + return c.cache.Len() +} + +func (c *cacheWrapper) Close() { + log.Debugf("closing session cache") + + c.cache.Close() +} + +func newSessionCache(loader sessionLoaderFunc, policy *CryptoPolicy) sessionCache { + cb := cache.New[string, *Session](policy.SessionCacheMaxSize) + cb.WithEvictFunc(func(k string, v *Session) { + go v.encryption.(*sharedEncryption).Remove() + }) + + if policy.SessionCacheDuration > 0 { + cb.WithExpiry(policy.SessionCacheDuration) + } + + if policy.SessionCacheEvictionPolicy == "" { + policy.SessionCacheEvictionPolicy = "slru" + } + + cb.WithPolicy(cache.CachePolicy(policy.SessionCacheEvictionPolicy)) + + return newSessionCacheWithCache(loader, policy, cb.Build()) } diff --git a/go/appencryption/session_cache_test.go b/go/appencryption/session_cache_test.go index 8e1bd1faa..758155b8f 100644 --- a/go/appencryption/session_cache_test.go +++ b/go/appencryption/session_cache_test.go @@ -100,6 +100,8 @@ func TestNewSessionCache(t *testing.T) { defer cache.Close() require.NotNil(t, cache) + + assert.Equal(t, cache.(*cacheWrapper).policy.SessionCacheEvictionPolicy, "slru") } func TestSessionCacheGetUsesLoader(t *testing.T) { @@ -227,15 +229,20 @@ func TestSessionCacheMaxCount(t *testing.T) { func TestSessionCacheDuration(t *testing.T) { ttl := time.Millisecond * 100 - // can't use more than 16 sessions here as that is the max drain - // for the mango cache implementation totalSessions := 16 b := newSessionBucket() policy := NewCryptoPolicy() policy.SessionCacheDuration = ttl - cache := newSessionCache(b.load, policy) + loaded := 0 + + loader := func(id string) (*Session, error) { + loaded++ + return b.load(id) + } + + cache := newSessionCache(loader, policy) require.NotNil(t, cache) defer cache.Close() @@ -244,18 +251,23 @@ func TestSessionCacheDuration(t *testing.T) { cache.Get(strconv.Itoa(i)) } + expectedCount := totalSessions + + // assert we have a load for each session + require.Equal(t, expectedCount, loaded) + // ensure the ttl has elapsed time.Sleep(ttl + time.Millisecond*50) - expectedCount := 0 + assert.Eventually(t, func() bool { + for i := 0; i < totalSessions; i++ { + cache.Get(strconv.Itoa(i)) + } - // mango cache implementation only reaps expired entries following a write, so we'll write a new - // cache entry and ensure it's the only one left - _, _ = cache.Get("99") // IDs 0-15 were created above - expectedCount = 1 + // now that the ttl has elapsed, we should have loaded the sessions again + // and the total loaded should be greater than the expected count + return loaded > expectedCount - assert.Eventually(t, func() bool { - return cache.Count() == 
expectedCount }, time.Second*10, time.Millisecond*10) } @@ -283,7 +295,7 @@ func TestSessionCacheCloseWithDebugLogging(t *testing.T) { // assert additional debug info was written to log assert.NotEqual(t, 0, l.Len()) - assert.Contains(t, l.String(), "session cache stash len = 0") + assert.Contains(t, l.String(), "closing session cache") log.SetLogger(nil) } @@ -385,51 +397,3 @@ func TestSharedSessionCloseDoesNotCloseUnderlyingSession(t *testing.T) { // shared sessions aren't actually closed until evicted from the cache assert.False(t, b.IsClosed(s1)) } - -func TestCacheStash(t *testing.T) { - id := "stashed item" - stash := newCacheStash() - - complete := make(chan bool) - - go func() { - stash.process() - - complete <- true - }() - - // stash is empty - s, ok := stash.get(id) - assert.Nil(t, s) - assert.False(t, ok) - assert.Equal(t, 0, stash.len()) - - // create a new session and stash it - sess := new(Session) - stash.add(id, sess) - - // stash now contains the session we just added - s, ok = stash.get(id) - assert.Equal(t, sess, s) - assert.True(t, ok) - assert.Equal(t, 1, stash.len()) - - // now remove the stashed session - stash.remove(id) - - // remove events are queued asynchronously - assert.Eventually(t, func() bool { - _, ok := stash.get(id) - - return !ok - }, 500*time.Millisecond, 10*time.Millisecond) - - // and verify it's gone - s, ok = stash.get(id) - assert.Nil(t, s) - assert.False(t, ok) - assert.Equal(t, 0, stash.len()) - - stash.close() - assert.True(t, <-complete) -} diff --git a/go/appencryption/session_test.go b/go/appencryption/session_test.go index 344fdd676..b0d6147f8 100644 --- a/go/appencryption/session_test.go +++ b/go/appencryption/session_test.go @@ -80,7 +80,7 @@ type MockCache struct { mock.Mock } -func (c *MockCache) GetOrLoad(id KeyMeta, loader keyLoader) (*internal.CryptoKey, error) { +func (c *MockCache) GetOrLoad(id KeyMeta, loader func(KeyMeta) (*internal.CryptoKey, error)) (*cachedCryptoKey, error) { var ( ret = c.Called(id, loader) key *internal.CryptoKey @@ -90,10 +90,10 @@ func (c *MockCache) GetOrLoad(id KeyMeta, loader keyLoader) (*internal.CryptoKey key = b.(*internal.CryptoKey) } - return key, ret.Error(1) + return &cachedCryptoKey{CryptoKey: key}, ret.Error(1) } -func (c *MockCache) GetOrLoadLatest(id string, loader keyLoader) (*internal.CryptoKey, error) { +func (c *MockCache) GetOrLoadLatest(id string, loader func(KeyMeta) (*internal.CryptoKey, error)) (*cachedCryptoKey, error) { var ( ret = c.Called(id, loader) key *internal.CryptoKey @@ -103,7 +103,7 @@ func (c *MockCache) GetOrLoadLatest(id string, loader keyLoader) (*internal.Cryp key = b.(*internal.CryptoKey) } - return key, ret.Error(1) + return &cachedCryptoKey{CryptoKey: key}, ret.Error(1) } func (c *MockCache) Close() error { @@ -121,9 +121,9 @@ func TestNewSessionFactory(t *testing.T) { } func TestNewSessionFactory_WithSessionCache(t *testing.T) { - policy := &CryptoPolicy{ - CacheSessions: true, - } + policy := NewCryptoPolicy() + policy.CacheSessions = true + factory := NewSessionFactory(&Config{ Policy: policy, }, nil, nil, nil) @@ -169,7 +169,7 @@ func TestSessionFactory_GetSession(t *testing.T) { sess, err := sessionFactory.GetSession("testing") if assert.NoError(t, err) { assert.NotNil(t, sess.encryption) - ik := sess.encryption.(*envelopeEncryption).intermediateKeys + ik := sess.encryption.(*envelopeEncryption).ikCache assert.IsType(t, new(neverCache), ik) } } @@ -184,7 +184,7 @@ func TestSessionFactory_GetSession_CanCacheIntermediateKeys(t *testing.T) { sess, err := 
sessionFactory.GetSession("testing") if assert.NoError(t, err) { assert.NotNil(t, sess.encryption) - ik := sess.encryption.(*envelopeEncryption).intermediateKeys + ik := sess.encryption.(*envelopeEncryption).ikCache assert.IsType(t, new(keyCache), ik) } }
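The lfu policy completed earlier in this diff keeps its entries grouped into frequency buckets held in a container/list, which is why its delete method prunes the enclosing frequency node once that node's entry map empties. The following standalone sketch shows the frequency-bucket idea in miniature; it is not the package's implementation, and lfuSketch, bucket, and their fields are illustrative names only.

package main

import (
	"container/list"
	"fmt"
)

// bucket holds every key currently at the same access frequency.
type bucket struct {
	freq int
	keys map[string]struct{}
}

// lfuSketch is a minimal frequency-bucket LFU, for illustration only.
type lfuSketch struct {
	cap     int
	values  map[string]string
	byKey   map[string]*list.Element // key -> the bucket element it lives in
	buckets *list.List               // buckets in ascending frequency order
}

func newLFUSketch(capacity int) *lfuSketch {
	return &lfuSketch{
		cap:     capacity,
		values:  make(map[string]string),
		byKey:   make(map[string]*list.Element),
		buckets: list.New(),
	}
}

// touch moves key into the next-higher frequency bucket, creating that bucket
// if needed and dropping the old one once it empties (the same cleanup the
// policy's delete method performs).
func (c *lfuSketch) touch(key string) {
	cur := c.byKey[key]
	b := cur.Value.(*bucket)

	next := cur.Next()
	if next == nil || next.Value.(*bucket).freq != b.freq+1 {
		next = c.buckets.InsertAfter(&bucket{freq: b.freq + 1, keys: map[string]struct{}{}}, cur)
	}

	next.Value.(*bucket).keys[key] = struct{}{}
	c.byKey[key] = next

	delete(b.keys, key)
	if len(b.keys) == 0 {
		c.buckets.Remove(cur)
	}
}

func (c *lfuSketch) Set(key, value string) {
	if _, ok := c.values[key]; ok {
		c.values[key] = value
		c.touch(key)

		return
	}

	// At capacity: evict one key from the lowest-frequency bucket.
	if len(c.values) == c.cap {
		front := c.buckets.Front()
		for victim := range front.Value.(*bucket).keys {
			delete(front.Value.(*bucket).keys, victim)
			delete(c.values, victim)
			delete(c.byKey, victim)

			break
		}

		if len(front.Value.(*bucket).keys) == 0 {
			c.buckets.Remove(front)
		}
	}

	// Admit the new key at frequency 1.
	c.values[key] = value

	front := c.buckets.Front()
	if front == nil || front.Value.(*bucket).freq != 1 {
		front = c.buckets.PushFront(&bucket{freq: 1, keys: map[string]struct{}{}})
	}

	front.Value.(*bucket).keys[key] = struct{}{}
	c.byKey[key] = front
}

func main() {
	c := newLFUSketch(2)
	c.Set("a", "1")
	c.Set("b", "2")
	c.Set("a", "1") // bumps a's frequency to 2
	c.Set("c", "3") // evicts b, the least frequently used key
	fmt.Println(c.values)
}

Promoting a key on access and trimming empty buckets is what keeps both admission and victim selection constant time in this style of LFU.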
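ExampleNew and the TestWithEvictFunc cases synchronize with the eviction callback because, as the example's comments note, the callback runs on a background goroutine rather than inline with Set. Callers that inspect eviction results need the same care. Below is a minimal sketch against the builder API exercised by the tests; the channel-based wait and the chosen capacity are assumptions of this sketch, not an API guarantee.

package main

import (
	"fmt"
	"sync"

	"github.com/godaddy/asherah/go/appencryption/pkg/cache"
)

func main() {
	var (
		mu      sync.Mutex
		evicted []int
		done    = make(chan struct{})
	)

	c := cache.New[int, string](2).
		WithEvictFunc(func(key int, _ string) {
			// Runs on a background goroutine; guard any shared state.
			mu.Lock()
			evicted = append(evicted, key)
			mu.Unlock()

			close(done)
		}).
		Build()

	c.Set(1, "one")
	c.Set(2, "two")
	c.Set(3, "three") // capacity exceeded; one entry is evicted asynchronously

	<-done // wait for the callback before inspecting results

	mu.Lock()
	fmt.Println("evicted:", evicted)
	mu.Unlock()
}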
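The slru policy splits its capacity between a probation segment and a protected segment using protectedRatio (0.8), so a capacity of 10 yields 8 protected and 2 probationary slots. New entries are admitted to probation, a later access promotes them to the protected segment, and overflow from the protected segment is demoted back to probation. A short usage sketch mirroring the SLRUSuite eviction test; the behavioral comments restate my reading of the policy code in this diff rather than documented guarantees.

package main

import (
	"fmt"

	"github.com/godaddy/asherah/go/appencryption/pkg/cache"
)

func main() {
	// Capacity 10 splits into 8 protected and 2 probation slots.
	c := cache.New[int, string](10).SLRU().Build()

	// New entries are admitted to the probation segment.
	for i := 0; i < 10; i++ {
		c.Set(i, fmt.Sprintf("#%d", i))
	}

	// Accessing an entry promotes it to the protected segment, making it a
	// poor eviction candidate.
	c.Get(0)

	// Adding one more item selects its victim from the probation segment first.
	c.Set(10, "#10")

	_, stillCached := c.Get(0)
	fmt.Println("key 0 cached:", stillCached) // true: promoted before the eviction
}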
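For tinyLFU, the constants at the top of tlfu.go size the admission machinery relative to capacity. With the capacity of 100 used by TinyLFUSuite, the admission-window LRU gets a single slot, the main SLRU gets the remaining 99, the doorkeeper bloom filter is sized for 200 insertions at a 10% false-positive rate, the count-min sketch gets 100 counters, and both sketches reset after 800 recorded increments. The arithmetic, restated as a runnable sketch that simply mirrors the init logic:

package main

import "fmt"

// Sizing constants as declared in tlfu.go.
const (
	samplesMultiplier        = 8
	insertionsMultiplier     = 2
	countersMultiplier       = 1
	falsePositiveProbability = 0.1
	admissionRatio           = 0.01
)

func main() {
	capacity := 100

	lruCap := int(float64(capacity) * admissionRatio) // admission window
	slruCap := capacity - lruCap                      // main segmented LRU

	fmt.Println("window LRU capacity:  ", lruCap)                        // 1
	fmt.Println("main SLRU capacity:   ", slruCap)                       // 99
	fmt.Println("doorkeeper insertions:", capacity*insertionsMultiplier) // 200
	fmt.Println("false-positive rate:  ", falsePositiveProbability)      // 0.1
	fmt.Println("count-min counters:   ", capacity*countersMultiplier)   // 100
	fmt.Println("sketch reset sample:  ", capacity*samplesMultiplier)    // 800
}

Note that for very small capacities the window can round down to zero, in which case the doorkeeper is bypassed and the SLRU holds the entire cache, as described in the init comment.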
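On the appencryption side, the new CryptoPolicy fields expose key-cache sizing and eviction-policy selection, and WithSharedIntermediateKeyCache opts a factory into one intermediate-key cache shared by all of its sessions. A configuration sketch using only the options and fields added in this diff; the specific values are arbitrary.

package main

import (
	"fmt"
	"time"

	"github.com/godaddy/asherah/go/appencryption"
)

func main() {
	policy := appencryption.NewCryptoPolicy(
		// Share one intermediate key cache (capacity 100) across all
		// sessions created by the factory.
		appencryption.WithSharedIntermediateKeyCache(100),
		appencryption.WithSessionCache(),
		appencryption.WithSessionCacheMaxSize(200),
		appencryption.WithSessionCacheDuration(2*time.Hour),
	)

	// Eviction policies can be selected per cache; the field documentation
	// lists "lru", "lfu", "slru", and "tinylfu" as supported values.
	policy.SystemKeyCacheEvictionPolicy = "tinylfu"
	policy.IntermediateKeyCacheEvictionPolicy = "slru"

	fmt.Println(policy.SharedIntermediateKeyCache, policy.IntermediateKeyCacheMaxSize)
}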
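newSessionCache now assembles its backing store with the same generic builder, wiring the eviction callback to retire shared sessions in the background and defaulting SessionCacheEvictionPolicy to "slru". That combination of an eviction callback, an expiry, and a named policy is available to any caller of the cache package; here is a minimal sketch with a placeholder session type and placeholder durations.

package main

import (
	"fmt"
	"time"

	"github.com/godaddy/asherah/go/appencryption/pkg/cache"
)

// session is a stand-in value type for this sketch.
type session struct{ id string }

func main() {
	b := cache.New[string, *session](1000)

	// Evicted values are handed to the callback on a background goroutine,
	// mirroring how newSessionCache schedules sharedEncryption.Remove.
	b.WithEvictFunc(func(id string, _ *session) {
		fmt.Println("evicted session:", id)
	})

	// Entries idle longer than the expiry become eligible for removal.
	b.WithExpiry(2 * time.Hour)

	// Select the eviction policy by name, as newSessionCache does.
	b.WithPolicy(cache.CachePolicy("slru"))

	c := b.Build()

	c.Set("partition-1", &session{id: "partition-1"})
	if s, ok := c.Get("partition-1"); ok {
		fmt.Println("cached:", s.id)
	}
}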