From c2365513267f0c5a6eb9eeb6ad4c975185ab1dad Mon Sep 17 00:00:00 2001 From: Leonard Cohnen Date: Tue, 2 Jan 2024 17:51:03 +0100 Subject: [PATCH] snp: clear kds cache daily --- coordinator/intercom.go | 8 +- coordinator/main.go | 4 + go.mod | 2 +- internal/attestation/snp/cachedClient.go | 31 ++++--- internal/attestation/snp/cachedClient_test.go | 88 +++++++++++++++++++ internal/attestation/snp/validator.go | 5 +- internal/memstore/memstore.go | 6 ++ 7 files changed, 128 insertions(+), 16 deletions(-) create mode 100644 internal/attestation/snp/cachedClient_test.go diff --git a/coordinator/intercom.go b/coordinator/intercom.go index 708a1ea547..9b1cac0189 100644 --- a/coordinator/intercom.go +++ b/coordinator/intercom.go @@ -13,12 +13,14 @@ import ( "github.com/edgelesssys/nunki/internal/intercom" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + "k8s.io/utils/clock" ) type intercomServer struct { grpc *grpc.Server certGet certGetter caChainGetter certChainGetter + ticker clock.Ticker logger *slog.Logger intercom.UnimplementedIntercomServer @@ -29,7 +31,8 @@ type certGetter interface { } func newIntercomServer(meshAuth *meshAuthority, caGetter certChainGetter, log *slog.Logger) (*intercomServer, error) { - validator := snp.NewValidatorWithCallbacks(meshAuth, log, meshAuth) + ticker := clock.RealClock{}.NewTicker(24 * time.Hour) + validator := snp.NewValidatorWithCallbacks(meshAuth, ticker, log, meshAuth) credentials := atlscredentials.New(atls.NoIssuer, []atls.Validator{validator}) grpcServer := grpc.NewServer( grpc.Creds(credentials), @@ -39,6 +42,7 @@ func newIntercomServer(meshAuth *meshAuthority, caGetter certChainGetter, log *s grpc: grpcServer, certGet: meshAuth, caChainGetter: caGetter, + ticker: ticker, logger: log.WithGroup("intercom"), } intercom.RegisterIntercomServer(s.grpc, s) @@ -50,6 +54,8 @@ func (i *intercomServer) Serve(endpoint string) error { if err != nil { return fmt.Errorf("failed to listen: %v", err) } + + defer i.ticker.Stop() return i.grpc.Serve(lis) } diff --git a/coordinator/main.go b/coordinator/main.go index 98cad86403..7fbaeaf6af 100644 --- a/coordinator/main.go +++ b/coordinator/main.go @@ -6,10 +6,12 @@ import ( "log/slog" "net" "os" + "time" "github.com/edgelesssys/nunki/internal/ca" "github.com/edgelesssys/nunki/internal/coordapi" "github.com/edgelesssys/nunki/internal/intercom" + "k8s.io/utils/clock" ) func main() { @@ -56,6 +58,8 @@ func run() (retErr error) { } }() + ticker := clock.RealClock{}.NewTicker(24 * time.Hour) + defer ticker.Stop() intercomS, err := newIntercomServer(meshAuth, caInstance, logger) if err != nil { return fmt.Errorf("creating intercom server: %v", err) diff --git a/go.mod b/go.mod index 757c677099..a7b2b183b4 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( gopkg.in/yaml.v2 v2.4.0 k8s.io/api v0.28.4 k8s.io/apimachinery v0.28.4 + k8s.io/utils v0.0.0-20230406110748-d93618cff8a2 ) require ( @@ -39,7 +40,6 @@ require ( gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect k8s.io/klog/v2 v2.100.1 // indirect - k8s.io/utils v0.0.0-20230406110748-d93618cff8a2 // indirect sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect sigs.k8s.io/structured-merge-diff/v4 v4.2.3 // indirect sigs.k8s.io/yaml v1.3.0 // indirect diff --git a/internal/attestation/snp/cachedClient.go b/internal/attestation/snp/cachedClient.go index eab4ce5e5a..0ff8ee9a95 100644 --- a/internal/attestation/snp/cachedClient.go +++ b/internal/attestation/snp/cachedClient.go @@ -5,28 +5,41 @@ import ( "github.com/edgelesssys/nunki/internal/memstore" "github.com/google/go-sev-guest/verify/trust" + "k8s.io/utils/clock" ) type cachedKDSHTTPClient struct { trust.HTTPSGetter logger *slog.Logger - cache *memstore.Store[string, cacheEntry] + gcTicker clock.Ticker + cache *memstore.Store[string, []byte] } -func NewCachedKDSHTTPClient(log *slog.Logger) *cachedKDSHTTPClient { +func NewCachedKDSHTTPClient(ticker clock.Ticker, log *slog.Logger) *cachedKDSHTTPClient { trust.DefaultHTTPSGetter() - return &cachedKDSHTTPClient{ + + c := &cachedKDSHTTPClient{ HTTPSGetter: trust.DefaultHTTPSGetter(), logger: log.WithGroup("cached-kds-http-client"), - cache: memstore.New[string, cacheEntry](), + cache: memstore.New[string, []byte](), + gcTicker: ticker, } + + return c } func (c *cachedKDSHTTPClient) Get(url string) ([]byte, error) { + select { + case <-c.gcTicker.C(): + c.logger.Debug("Garbage collecting") + c.cache.Clear() + default: + } + if cached, ok := c.cache.Get(url); ok { c.logger.Debug("Get cached", "url", url) - return cached.data, nil + return cached, nil } c.logger.Debug("Get not cached", "url", url) @@ -34,12 +47,6 @@ func (c *cachedKDSHTTPClient) Get(url string) ([]byte, error) { if err != nil { return nil, err } - c.cache.Set(url, cacheEntry{ - data: res, - }) + c.cache.Set(url, res) return res, nil } - -type cacheEntry struct { - data []byte -} diff --git a/internal/attestation/snp/cachedClient_test.go b/internal/attestation/snp/cachedClient_test.go new file mode 100644 index 0000000000..976da50320 --- /dev/null +++ b/internal/attestation/snp/cachedClient_test.go @@ -0,0 +1,88 @@ +package snp + +import ( + "log/slog" + "testing" + "time" + + "github.com/edgelesssys/nunki/internal/memstore" + "github.com/stretchr/testify/assert" + "go.uber.org/goleak" + testingclock "k8s.io/utils/clock/testing" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestCachedKDSHTTPClient(t *testing.T) { + t.Run("Get", func(t *testing.T) { + assert := assert.New(t) + + fakeGetter := &fakeHTTPSGetter{ + content: map[string][]byte{ + "foo": []byte("bar"), + }, + hits: map[string]int{}, + } + stepTime := 5 * time.Minute + testClock := testingclock.NewFakeClock(time.Now()) + ticker := testClock.NewTicker(stepTime) + client := &cachedKDSHTTPClient{ + HTTPSGetter: fakeGetter, + gcTicker: ticker, + cache: memstore.New[string, []byte](), + logger: slog.Default(), + } + + res, err := client.Get("foo") + assert.NoError(err) + assert.Equal([]byte("bar"), res) + assert.Equal(1, fakeGetter.hits["foo"]) + + // Expect a second call to return the cached value and not increase the hit counter. + res, err = client.Get("foo") + assert.NoError(err) + assert.Equal([]byte("bar"), res) + assert.Equal(1, fakeGetter.hits["foo"]) + + // After the step time, the cache should be invalidated and hit the backend again. + testClock.Step(stepTime) + res, err = client.Get("foo") + assert.NoError(err) + assert.Equal([]byte("bar"), res) + assert.Equal(2, fakeGetter.hits["foo"]) + }) + t.Run("Get error", func(t *testing.T) { + fakeGetter := &fakeHTTPSGetter{ + getErr: assert.AnError, + content: map[string][]byte{}, + hits: map[string]int{}, + } + testClock := testingclock.NewFakeClock(time.Now()) + ticker := testClock.NewTicker(5 * time.Minute) + client := &cachedKDSHTTPClient{ + HTTPSGetter: fakeGetter, + gcTicker: ticker, + cache: memstore.New[string, []byte](), + logger: slog.Default(), + } + + assert := assert.New(t) + + _, err := client.Get("foo") + assert.Error(err) + assert.Equal(1, fakeGetter.hits["foo"]) + }) +} + +type fakeHTTPSGetter struct { + content map[string][]byte + hits map[string]int + getErr error +} + +func (f *fakeHTTPSGetter) Get(url string) ([]byte, error) { + f.hits[url]++ + return f.content[url], f.getErr +} diff --git a/internal/attestation/snp/validator.go b/internal/attestation/snp/validator.go index de43acb0b4..deb9ab2bc2 100644 --- a/internal/attestation/snp/validator.go +++ b/internal/attestation/snp/validator.go @@ -19,6 +19,7 @@ import ( "github.com/google/go-sev-guest/validate" "github.com/google/go-sev-guest/verify" "github.com/google/go-sev-guest/verify/trust" + "k8s.io/utils/clock" ) type Validator struct { @@ -51,11 +52,11 @@ func NewValidator(optsGen validateOptsGenerator, log *slog.Logger) *Validator { } } -func NewValidatorWithCallbacks(optsGen validateOptsGenerator, log *slog.Logger, callbacks ...validateCallbacker) *Validator { +func NewValidatorWithCallbacks(optsGen validateOptsGenerator, ticker clock.Ticker, log *slog.Logger, callbacks ...validateCallbacker) *Validator { return &Validator{ validateOptsGen: optsGen, callbackers: callbacks, - kdsGetter: NewCachedKDSHTTPClient(log), + kdsGetter: NewCachedKDSHTTPClient(ticker, log), logger: log.WithGroup("snp-validator"), } } diff --git a/internal/memstore/memstore.go b/internal/memstore/memstore.go index dafcc72182..8734fe9b89 100644 --- a/internal/memstore/memstore.go +++ b/internal/memstore/memstore.go @@ -35,3 +35,9 @@ func (s *Store[keyT, valueT]) GetAll() []valueT { } return values } + +func (s *Store[keyT, valueT]) Clear() { + s.mux.Lock() + defer s.mux.Unlock() + clear(s.m) +}