From 5a0bba9e0807ed5e3cb823cd3143af89e1aa9548 Mon Sep 17 00:00:00 2001 From: jmxnzo Date: Tue, 3 Dec 2024 17:20:34 +0100 Subject: [PATCH] kds-cache: add fallback cache for CRLs on request failure --- .../attestation/certcache/cached_client.go | 30 +++-- .../certcache/cached_client_test.go | 125 +++++++++++------- 2 files changed, 97 insertions(+), 58 deletions(-) diff --git a/internal/attestation/certcache/cached_client.go b/internal/attestation/certcache/cached_client.go index 0262fcaf11..10e6898d66 100644 --- a/internal/attestation/certcache/cached_client.go +++ b/internal/attestation/certcache/cached_client.go @@ -44,25 +44,33 @@ func (c *CachedHTTPSGetter) Get(url string) ([]byte, error) { default: } - // Don't cache CRLs. Unlike VCEKs, these can change over time and the KDS - // doesn't rate-limit requests to these. - canCache := !crlURL.MatchString(url) - - if canCache { + if crlURL.MatchString(url) { + // For CRLs always query. When request failure, fallback to cache. + c.logger.Debug("Request CRL", "url", url) + res, err := c.HTTPSGetter.Get(url) + if err == nil { + c.cache.Set(url, res) + return res, nil + } + c.logger.Warn("Could not reach KDS", "error", err) if cached, ok := c.cache.Get(url); ok { - c.logger.Debug("Get cached", "url", url) + c.logger.Debug("CRL request failed, fallback to cached CRL", "url", url) return cached, nil } + c.logger.Debug("CRL request failed and CRL was not found in cache", "url", url) + return nil, err } - - c.logger.Debug("Get not cached", "url", url) + // For VCEK get cache first and request if not present + if cached, ok := c.cache.Get(url); ok { + c.logger.Debug("Get cached VCEK", "url", url) + return cached, nil + } + c.logger.Debug("Request VCEK, missing in cache", "url", url) res, err := c.HTTPSGetter.Get(url) if err != nil { return nil, err } - if canCache { - c.cache.Set(url, res) - } + c.cache.Set(url, res) return res, nil } diff --git a/internal/attestation/certcache/cached_client_test.go b/internal/attestation/certcache/cached_client_test.go index 8515eae071..0132c54a35 100644 --- a/internal/attestation/certcache/cached_client_test.go +++ b/internal/attestation/certcache/cached_client_test.go @@ -4,6 +4,7 @@ package certcache import ( + "errors" "log/slog" "sync" "testing" @@ -12,6 +13,7 @@ import ( "github.com/edgelesssys/contrast/internal/memstore" "github.com/stretchr/testify/assert" "go.uber.org/goleak" + "k8s.io/utils/clock" testingclock "k8s.io/utils/clock/testing" ) @@ -19,25 +21,16 @@ func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } +const crlURLMatch string = "https://kdsintf.amd.com/vcek/v1/test/crl" + func TestMemcachedHTTPSGetter(t *testing.T) { - t.Run("Get", func(t *testing.T) { - assert := assert.New(t) + stepTime := 5 * time.Minute + testClock := testingclock.NewFakeClock(time.Now()) + ticker := testClock.NewTicker(stepTime) - fakeGetter := &fakeHTTPSGetter{ - content: map[string][]byte{ - "foo": []byte("bar"), - }, - hits: map[string]int{}, - } - stepTime := 5 * time.Minute - testClock := testingclock.NewFakeClock(time.Now()) - ticker := testClock.NewTicker(stepTime) - client := &CachedHTTPSGetter{ - HTTPSGetter: fakeGetter, - gcTicker: ticker, - cache: memstore.New[string, []byte](), - logger: slog.Default(), - } + t.Run("Get VCEK by request and from cache", func(t *testing.T) { + assert := assert.New(t) + fakeGetter, client := getFakeHTTPSGetters(ticker) res, err := client.Get("foo") assert.NoError(err) @@ -57,46 +50,66 @@ func TestMemcachedHTTPSGetter(t *testing.T) { assert.Equal([]byte("bar"), res) assert.Equal(2, fakeGetter.hits["foo"]) }) - t.Run("Get error", func(t *testing.T) { - fakeGetter := &fakeHTTPSGetter{ - getErr: assert.AnError, - content: map[string][]byte{}, - hits: map[string]int{}, - } - testClock := testingclock.NewFakeClock(time.Now()) - ticker := testClock.NewTicker(5 * time.Minute) - client := &CachedHTTPSGetter{ - HTTPSGetter: fakeGetter, - gcTicker: ticker, - cache: memstore.New[string, []byte](), - logger: slog.Default(), - } + t.Run("VCEK request fails and VCEK not in cache", func(t *testing.T) { assert := assert.New(t) + fakeGetter, client := getFakeHTTPSGetters(ticker) + + // Simulate a request failure by returning an error + fakeGetter.getErr = errors.New("VCEK request failure") _, err := client.Get("foo") assert.Error(err) assert.Equal(1, fakeGetter.hits["foo"]) }) - t.Run("Concurrent access", func(t *testing.T) { + + t.Run("Check CRLs are still requested after caching", func(t *testing.T) { assert := assert.New(t) + fakeGetter, client := getFakeHTTPSGetters(ticker) - fakeGetter := &fakeHTTPSGetter{ - content: map[string][]byte{ - "foo": []byte("bar"), - }, - hits: map[string]int{}, - } + res, err := client.Get(crlURLMatch) + assert.NoError(err) + assert.Equal([]byte("bar"), res) + assert.Equal(1, fakeGetter.hits[crlURLMatch]) + + // Even after the CRL is cached, the CRL should be requested(hit counter increase). + res, err = client.Get(crlURLMatch) + assert.NoError(err) + assert.Equal([]byte("bar"), res) + assert.Equal(2, fakeGetter.hits[crlURLMatch]) + }) + + t.Run("Check CRLs can be loaded by cache when request fails", func(t *testing.T) { + assert := assert.New(t) + fakeGetter, client := getFakeHTTPSGetters(ticker) + + // Preload CRL into the cache + client.cache.Set(crlURLMatch, []byte("bar")) + fakeGetter.getErr = errors.New("CRL request failure") + + // The CRL should be loaded from the cache and client.Get() won't result in an error + res, err := client.Get(crlURLMatch) + assert.NoError(err) + assert.Equal([]byte("bar"), res) + assert.Equal(1, fakeGetter.hits[crlURLMatch]) + }) + + t.Run("CRL request fails and CRL not in cache", func(t *testing.T) { + assert := assert.New(t) + fakeGetter, client := getFakeHTTPSGetters(ticker) + + fakeGetter.getErr = errors.New("CRL request failure") + + // No CRL cache and request failure results in error + _, err := client.Get(crlURLMatch) + assert.Error(err) + assert.Equal(1, fakeGetter.hits[crlURLMatch]) + }) + + t.Run("Concurrent access", func(t *testing.T) { + assert := assert.New(t) + _, client := getFakeHTTPSGetters(ticker) numGets := 5 - stepTime := 5 * time.Minute - testClock := testingclock.NewFakeClock(time.Now()) - ticker := testClock.NewTicker(stepTime) - client := &CachedHTTPSGetter{ - HTTPSGetter: fakeGetter, - gcTicker: ticker, - cache: memstore.New[string, []byte](), - logger: slog.Default(), - } var wg sync.WaitGroup getFunc := func() { @@ -124,6 +137,24 @@ type fakeHTTPSGetter struct { hits map[string]int } +// Returns the fakeHTTPSGetter for test assertions and its wrapper CachedHTTPSGetter. +func getFakeHTTPSGetters(ticker clock.Ticker) (*fakeHTTPSGetter, *CachedHTTPSGetter) { + fakeGetter := &fakeHTTPSGetter{ + content: map[string][]byte{ + "foo": []byte("bar"), + crlURLMatch: []byte("bar"), + }, + hits: map[string]int{}, + } + + return fakeGetter, &CachedHTTPSGetter{ + HTTPSGetter: fakeGetter, + gcTicker: ticker, + cache: memstore.New[string, []byte](), + logger: slog.Default(), + } +} + func (f *fakeHTTPSGetter) Get(url string) ([]byte, error) { f.hitsMux.Lock() defer f.hitsMux.Unlock()