Skip to content

Commit

Permalink
kds-cache: add fallback cache for CRLs on request failure
Browse files Browse the repository at this point in the history
  • Loading branch information
jmxnzo committed Dec 4, 2024
1 parent 65bae63 commit 5a0bba9
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 58 deletions.
30 changes: 19 additions & 11 deletions internal/attestation/certcache/cached_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,33 @@ func (c *CachedHTTPSGetter) Get(url string) ([]byte, error) {
default:
}

// Don't cache CRLs. Unlike VCEKs, these can change over time and the KDS
// doesn't rate-limit requests to these.
canCache := !crlURL.MatchString(url)

if canCache {
if crlURL.MatchString(url) {
// For CRLs always query. When request failure, fallback to cache.
c.logger.Debug("Request CRL", "url", url)
res, err := c.HTTPSGetter.Get(url)
if err == nil {
c.cache.Set(url, res)
return res, nil
}
c.logger.Warn("Could not reach KDS", "error", err)
if cached, ok := c.cache.Get(url); ok {
c.logger.Debug("Get cached", "url", url)
c.logger.Debug("CRL request failed, fallback to cached CRL", "url", url)
return cached, nil
}
c.logger.Debug("CRL request failed and CRL was not found in cache", "url", url)
return nil, err
}

c.logger.Debug("Get not cached", "url", url)
// For VCEK get cache first and request if not present
if cached, ok := c.cache.Get(url); ok {
c.logger.Debug("Get cached VCEK", "url", url)
return cached, nil
}
c.logger.Debug("Request VCEK, missing in cache", "url", url)
res, err := c.HTTPSGetter.Get(url)
if err != nil {
return nil, err
}
if canCache {
c.cache.Set(url, res)
}
c.cache.Set(url, res)
return res, nil
}

Expand Down
125 changes: 78 additions & 47 deletions internal/attestation/certcache/cached_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package certcache

import (
"errors"
"log/slog"
"sync"
"testing"
Expand All @@ -12,32 +13,24 @@ import (
"github.com/edgelesssys/contrast/internal/memstore"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
"k8s.io/utils/clock"
testingclock "k8s.io/utils/clock/testing"
)

func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}

const crlURLMatch string = "https://kdsintf.amd.com/vcek/v1/test/crl"

func TestMemcachedHTTPSGetter(t *testing.T) {
t.Run("Get", func(t *testing.T) {
assert := assert.New(t)
stepTime := 5 * time.Minute
testClock := testingclock.NewFakeClock(time.Now())
ticker := testClock.NewTicker(stepTime)

fakeGetter := &fakeHTTPSGetter{
content: map[string][]byte{
"foo": []byte("bar"),
},
hits: map[string]int{},
}
stepTime := 5 * time.Minute
testClock := testingclock.NewFakeClock(time.Now())
ticker := testClock.NewTicker(stepTime)
client := &CachedHTTPSGetter{
HTTPSGetter: fakeGetter,
gcTicker: ticker,
cache: memstore.New[string, []byte](),
logger: slog.Default(),
}
t.Run("Get VCEK by request and from cache", func(t *testing.T) {
assert := assert.New(t)
fakeGetter, client := getFakeHTTPSGetters(ticker)

res, err := client.Get("foo")
assert.NoError(err)
Expand All @@ -57,46 +50,66 @@ func TestMemcachedHTTPSGetter(t *testing.T) {
assert.Equal([]byte("bar"), res)
assert.Equal(2, fakeGetter.hits["foo"])
})
t.Run("Get error", func(t *testing.T) {
fakeGetter := &fakeHTTPSGetter{
getErr: assert.AnError,
content: map[string][]byte{},
hits: map[string]int{},
}
testClock := testingclock.NewFakeClock(time.Now())
ticker := testClock.NewTicker(5 * time.Minute)
client := &CachedHTTPSGetter{
HTTPSGetter: fakeGetter,
gcTicker: ticker,
cache: memstore.New[string, []byte](),
logger: slog.Default(),
}

t.Run("VCEK request fails and VCEK not in cache", func(t *testing.T) {
assert := assert.New(t)
fakeGetter, client := getFakeHTTPSGetters(ticker)

// Simulate a request failure by returning an error
fakeGetter.getErr = errors.New("VCEK request failure")

_, err := client.Get("foo")
assert.Error(err)
assert.Equal(1, fakeGetter.hits["foo"])
})
t.Run("Concurrent access", func(t *testing.T) {

t.Run("Check CRLs are still requested after caching", func(t *testing.T) {
assert := assert.New(t)
fakeGetter, client := getFakeHTTPSGetters(ticker)

fakeGetter := &fakeHTTPSGetter{
content: map[string][]byte{
"foo": []byte("bar"),
},
hits: map[string]int{},
}
res, err := client.Get(crlURLMatch)
assert.NoError(err)
assert.Equal([]byte("bar"), res)
assert.Equal(1, fakeGetter.hits[crlURLMatch])

// Even after the CRL is cached, the CRL should be requested(hit counter increase).
res, err = client.Get(crlURLMatch)
assert.NoError(err)
assert.Equal([]byte("bar"), res)
assert.Equal(2, fakeGetter.hits[crlURLMatch])
})

t.Run("Check CRLs can be loaded by cache when request fails", func(t *testing.T) {
assert := assert.New(t)
fakeGetter, client := getFakeHTTPSGetters(ticker)

// Preload CRL into the cache
client.cache.Set(crlURLMatch, []byte("bar"))
fakeGetter.getErr = errors.New("CRL request failure")

// The CRL should be loaded from the cache and client.Get() won't result in an error
res, err := client.Get(crlURLMatch)
assert.NoError(err)
assert.Equal([]byte("bar"), res)
assert.Equal(1, fakeGetter.hits[crlURLMatch])
})

t.Run("CRL request fails and CRL not in cache", func(t *testing.T) {
assert := assert.New(t)
fakeGetter, client := getFakeHTTPSGetters(ticker)

fakeGetter.getErr = errors.New("CRL request failure")

// No CRL cache and request failure results in error
_, err := client.Get(crlURLMatch)
assert.Error(err)
assert.Equal(1, fakeGetter.hits[crlURLMatch])
})

t.Run("Concurrent access", func(t *testing.T) {
assert := assert.New(t)
_, client := getFakeHTTPSGetters(ticker)
numGets := 5
stepTime := 5 * time.Minute
testClock := testingclock.NewFakeClock(time.Now())
ticker := testClock.NewTicker(stepTime)
client := &CachedHTTPSGetter{
HTTPSGetter: fakeGetter,
gcTicker: ticker,
cache: memstore.New[string, []byte](),
logger: slog.Default(),
}

var wg sync.WaitGroup
getFunc := func() {
Expand Down Expand Up @@ -124,6 +137,24 @@ type fakeHTTPSGetter struct {
hits map[string]int
}

// Returns the fakeHTTPSGetter for test assertions and its wrapper CachedHTTPSGetter.
func getFakeHTTPSGetters(ticker clock.Ticker) (*fakeHTTPSGetter, *CachedHTTPSGetter) {
fakeGetter := &fakeHTTPSGetter{
content: map[string][]byte{
"foo": []byte("bar"),
crlURLMatch: []byte("bar"),
},
hits: map[string]int{},
}

return fakeGetter, &CachedHTTPSGetter{
HTTPSGetter: fakeGetter,
gcTicker: ticker,
cache: memstore.New[string, []byte](),
logger: slog.Default(),
}
}

func (f *fakeHTTPSGetter) Get(url string) ([]byte, error) {
f.hitsMux.Lock()
defer f.hitsMux.Unlock()
Expand Down

0 comments on commit 5a0bba9

Please sign in to comment.