diff --git a/coordinator/intercom.go b/coordinator/intercom.go index 3b34459988..01581d7831 100644 --- a/coordinator/intercom.go +++ b/coordinator/intercom.go @@ -15,12 +15,14 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/status" + "k8s.io/utils/clock" ) type intercomServer struct { grpc *grpc.Server certGet certGetter caChainGetter certChainGetter + ticker clock.Ticker logger *slog.Logger intercom.UnimplementedIntercomServer @@ -31,7 +33,8 @@ type certGetter interface { } func newIntercomServer(meshAuth *meshAuthority, caGetter certChainGetter, log *slog.Logger) *intercomServer { - validator := snp.NewValidatorWithCallbacks(meshAuth, log, meshAuth) + ticker := clock.RealClock{}.NewTicker(24 * time.Hour) + validator := snp.NewValidatorWithCallbacks(meshAuth, ticker, log, meshAuth) credentials := atlscredentials.New(atls.NoIssuer, []atls.Validator{validator}) grpcServer := grpc.NewServer( grpc.Creds(credentials), @@ -41,6 +44,7 @@ func newIntercomServer(meshAuth *meshAuthority, caGetter certChainGetter, log *s grpc: grpcServer, certGet: meshAuth, caChainGetter: caGetter, + ticker: ticker, logger: log.WithGroup("intercom"), } intercom.RegisterIntercomServer(s.grpc, s) @@ -52,6 +56,8 @@ func (i *intercomServer) Serve(endpoint string) error { if err != nil { return fmt.Errorf("failed to listen: %w", err) } + + defer i.ticker.Stop() return i.grpc.Serve(lis) } diff --git a/go.mod b/go.mod index f87e97d184..ccb9bb1418 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( gopkg.in/yaml.v3 v3.0.1 k8s.io/api v0.29.0 k8s.io/apimachinery v0.29.0 + k8s.io/utils v0.0.0-20230726121419-3b25d923346b ) require ( @@ -39,7 +40,6 @@ require ( gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect k8s.io/klog/v2 v2.110.1 // indirect - k8s.io/utils v0.0.0-20240102154912-e7106e64919e // indirect sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect sigs.k8s.io/yaml v1.4.0 // indirect diff --git a/go.sum b/go.sum index 284a02a052..c2d31ae995 100644 --- a/go.sum +++ b/go.sum @@ -120,8 +120,8 @@ k8s.io/apimachinery v0.29.0 h1:+ACVktwyicPz0oc6MTMLwa2Pw3ouLAfAon1wPLtG48o= k8s.io/apimachinery v0.29.0/go.mod h1:eVBxQ/cwiJxH58eK/jd/vAk4mrxmVlnpBH5J2GbMeis= k8s.io/klog/v2 v2.110.1 h1:U/Af64HJf7FcwMcXyKm2RPM22WZzyR7OSpYj5tg3cL0= k8s.io/klog/v2 v2.110.1/go.mod h1:YGtd1984u+GgbuZ7e08/yBuAfKLSO0+uR1Fhi6ExXjo= -k8s.io/utils v0.0.0-20240102154912-e7106e64919e h1:eQ/4ljkx21sObifjzXwlPKpdGLrCfRziVtos3ofG/sQ= -k8s.io/utils v0.0.0-20240102154912-e7106e64919e/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= +k8s.io/utils v0.0.0-20230726121419-3b25d923346b h1:sgn3ZU783SCgtaSJjpcVVlRqd6GSnlTLKgpAAttJvpI= +k8s.io/utils v0.0.0-20230726121419-3b25d923346b/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo= sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0= sigs.k8s.io/structured-merge-diff/v4 v4.4.1 h1:150L+0vs/8DA78h1u02ooW1/fFq/Lwr+sGiqlzvrtq4= diff --git a/internal/attestation/snp/cachedClient.go b/internal/attestation/snp/cached_client.go similarity index 59% rename from internal/attestation/snp/cachedClient.go rename to internal/attestation/snp/cached_client.go index 2695d77170..301ee1c17e 100644 --- a/internal/attestation/snp/cachedClient.go +++ b/internal/attestation/snp/cached_client.go @@ -5,28 +5,41 @@ import ( "github.com/edgelesssys/nunki/internal/memstore" "github.com/google/go-sev-guest/verify/trust" + "k8s.io/utils/clock" ) type cachedKDSHTTPClient struct { trust.HTTPSGetter logger *slog.Logger - cache *memstore.Store[string, cacheEntry] + gcTicker clock.Ticker + cache *memstore.Store[string, []byte] } -func newCachedKDSHTTPClient(log *slog.Logger) *cachedKDSHTTPClient { +func newCachedKDSHTTPClient(ticker clock.Ticker, log *slog.Logger) *cachedKDSHTTPClient { trust.DefaultHTTPSGetter() - return &cachedKDSHTTPClient{ + + c := &cachedKDSHTTPClient{ HTTPSGetter: trust.DefaultHTTPSGetter(), logger: log.WithGroup("cached-kds-http-client"), - cache: memstore.New[string, cacheEntry](), + cache: memstore.New[string, []byte](), + gcTicker: ticker, } + + return c } func (c *cachedKDSHTTPClient) Get(url string) ([]byte, error) { + select { + case <-c.gcTicker.C(): + c.logger.Debug("Garbage collecting") + c.cache.Clear() + default: + } + if cached, ok := c.cache.Get(url); ok { c.logger.Debug("Get cached", "url", url) - return cached.data, nil + return cached, nil } c.logger.Debug("Get not cached", "url", url) @@ -34,12 +47,6 @@ func (c *cachedKDSHTTPClient) Get(url string) ([]byte, error) { if err != nil { return nil, err } - c.cache.Set(url, cacheEntry{ - data: res, - }) + c.cache.Set(url, res) return res, nil } - -type cacheEntry struct { - data []byte -} diff --git a/internal/attestation/snp/cached_client_test.go b/internal/attestation/snp/cached_client_test.go new file mode 100644 index 0000000000..cd8d261d8a --- /dev/null +++ b/internal/attestation/snp/cached_client_test.go @@ -0,0 +1,132 @@ +package snp + +import ( + "log/slog" + "sync" + "testing" + "time" + + "github.com/edgelesssys/nunki/internal/memstore" + "github.com/stretchr/testify/assert" + "go.uber.org/goleak" + testingclock "k8s.io/utils/clock/testing" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestCachedKDSHTTPClient(t *testing.T) { + t.Run("Get", func(t *testing.T) { + assert := assert.New(t) + + fakeGetter := &fakeHTTPSGetter{ + content: map[string][]byte{ + "foo": []byte("bar"), + }, + hits: map[string]int{}, + } + stepTime := 5 * time.Minute + testClock := testingclock.NewFakeClock(time.Now()) + ticker := testClock.NewTicker(stepTime) + client := &cachedKDSHTTPClient{ + HTTPSGetter: fakeGetter, + gcTicker: ticker, + cache: memstore.New[string, []byte](), + logger: slog.Default(), + } + + res, err := client.Get("foo") + assert.NoError(err) + assert.Equal([]byte("bar"), res) + assert.Equal(1, fakeGetter.hits["foo"]) + + // Expect a second call to return the cached value and not increase the hit counter. + res, err = client.Get("foo") + assert.NoError(err) + assert.Equal([]byte("bar"), res) + assert.Equal(1, fakeGetter.hits["foo"]) + + // After the step time, the cache should be invalidated and hit the backend again. + testClock.Step(stepTime) + res, err = client.Get("foo") + assert.NoError(err) + assert.Equal([]byte("bar"), res) + assert.Equal(2, fakeGetter.hits["foo"]) + }) + t.Run("Get error", func(t *testing.T) { + fakeGetter := &fakeHTTPSGetter{ + getErr: assert.AnError, + content: map[string][]byte{}, + hits: map[string]int{}, + } + testClock := testingclock.NewFakeClock(time.Now()) + ticker := testClock.NewTicker(5 * time.Minute) + client := &cachedKDSHTTPClient{ + HTTPSGetter: fakeGetter, + gcTicker: ticker, + cache: memstore.New[string, []byte](), + logger: slog.Default(), + } + + assert := assert.New(t) + + _, err := client.Get("foo") + assert.Error(err) + assert.Equal(1, fakeGetter.hits["foo"]) + }) + t.Run("Concurrent access", func(t *testing.T) { + assert := assert.New(t) + + fakeGetter := &fakeHTTPSGetter{ + content: map[string][]byte{ + "foo": []byte("bar"), + }, + hits: map[string]int{}, + } + numGets := 5 + stepTime := 5 * time.Minute + testClock := testingclock.NewFakeClock(time.Now()) + ticker := testClock.NewTicker(stepTime) + client := &cachedKDSHTTPClient{ + HTTPSGetter: fakeGetter, + gcTicker: ticker, + cache: memstore.New[string, []byte](), + logger: slog.Default(), + } + + var wg sync.WaitGroup + getFunc := func() { + defer wg.Done() + res, err := client.Get("foo") + assert.NoError(err) + assert.Equal([]byte("bar"), res) + } + + wg.Add(numGets) + go getFunc() + go getFunc() + go getFunc() + go getFunc() + go getFunc() + wg.Wait() + + // It's possible that the cache is not yet populated when it is checked by the second Get. + assert.Less(fakeGetter.hits["foo"], numGets) + }) +} + +type fakeHTTPSGetter struct { + content map[string][]byte + getErr error + + hitsMux sync.Mutex + hits map[string]int +} + +func (f *fakeHTTPSGetter) Get(url string) ([]byte, error) { + f.hitsMux.Lock() + defer f.hitsMux.Unlock() + f.hits[url]++ + return f.content[url], f.getErr +} diff --git a/internal/attestation/snp/validator.go b/internal/attestation/snp/validator.go index 4c4cd1ff17..ab2adcfcca 100644 --- a/internal/attestation/snp/validator.go +++ b/internal/attestation/snp/validator.go @@ -19,6 +19,7 @@ import ( "github.com/google/go-sev-guest/validate" "github.com/google/go-sev-guest/verify" "github.com/google/go-sev-guest/verify/trust" + "k8s.io/utils/clock" ) // Validator validates attestation statements. @@ -58,11 +59,11 @@ func NewValidator(optsGen validateOptsGenerator, log *slog.Logger) *Validator { } // NewValidatorWithCallbacks returns a new Validator with callbacks. -func NewValidatorWithCallbacks(optsGen validateOptsGenerator, log *slog.Logger, callbacks ...validateCallbacker) *Validator { +func NewValidatorWithCallbacks(optsGen validateOptsGenerator, ticker clock.Ticker, log *slog.Logger, callbacks ...validateCallbacker) *Validator { return &Validator{ validateOptsGen: optsGen, callbackers: callbacks, - kdsGetter: newCachedKDSHTTPClient(log), + kdsGetter: newCachedKDSHTTPClient(ticker, log), logger: log.WithGroup("snp-validator"), } } diff --git a/internal/memstore/memstore.go b/internal/memstore/memstore.go index f3b435e4e2..020763f805 100644 --- a/internal/memstore/memstore.go +++ b/internal/memstore/memstore.go @@ -40,3 +40,10 @@ func (s *Store[keyT, valueT]) GetAll() []valueT { } return values } + +// Clear clears all values from store. +func (s *Store[keyT, valueT]) Clear() { + s.mux.Lock() + defer s.mux.Unlock() + clear(s.m) +}