diff --git a/go.mod b/go.mod index 184134455e..e6c0001789 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( gopkg.in/yaml.v2 v2.4.0 k8s.io/api v0.28.4 k8s.io/apimachinery v0.28.4 + k8s.io/utils v0.0.0-20230406110748-d93618cff8a2 ) require ( @@ -38,7 +39,6 @@ require ( gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect k8s.io/klog/v2 v2.100.1 // indirect - k8s.io/utils v0.0.0-20230406110748-d93618cff8a2 // indirect sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect sigs.k8s.io/structured-merge-diff/v4 v4.2.3 // indirect sigs.k8s.io/yaml v1.3.0 // indirect diff --git a/internal/attestation/snp/cachedClient.go b/internal/attestation/snp/cachedClient.go index eab4ce5e5a..cc0e575600 100644 --- a/internal/attestation/snp/cachedClient.go +++ b/internal/attestation/snp/cachedClient.go @@ -2,31 +2,40 @@ package snp import ( "log/slog" + "time" "github.com/edgelesssys/nunki/internal/memstore" "github.com/google/go-sev-guest/verify/trust" + "k8s.io/utils/clock" ) type cachedKDSHTTPClient struct { trust.HTTPSGetter logger *slog.Logger - cache *memstore.Store[string, cacheEntry] + gcTicker clock.Ticker + cache *memstore.Store[string, []byte] } func NewCachedKDSHTTPClient(log *slog.Logger) *cachedKDSHTTPClient { trust.DefaultHTTPSGetter() - return &cachedKDSHTTPClient{ + + gc := clock.RealClock{}.NewTicker(24 * time.Hour) + c := &cachedKDSHTTPClient{ HTTPSGetter: trust.DefaultHTTPSGetter(), logger: log.WithGroup("cached-kds-http-client"), - cache: memstore.New[string, cacheEntry](), + cache: memstore.New[string, []byte](), + gcTicker: gc, } + + go c.garbageCollect() + return c } func (c *cachedKDSHTTPClient) Get(url string) ([]byte, error) { if cached, ok := c.cache.Get(url); ok { c.logger.Debug("Get cached", "url", url) - return cached.data, nil + return cached, nil } c.logger.Debug("Get not cached", "url", url) @@ -34,12 +43,13 @@ func (c *cachedKDSHTTPClient) Get(url string) ([]byte, error) { if err != nil { return nil, err } - c.cache.Set(url, cacheEntry{ - data: res, - }) + c.cache.Set(url, res) return res, nil } -type cacheEntry struct { - data []byte +func (c *cachedKDSHTTPClient) garbageCollect() { + for range c.gcTicker.C() { + c.logger.Debug("Garbage collecting") + c.cache.Clear() + } } diff --git a/internal/memstore/memstore.go b/internal/memstore/memstore.go index dafcc72182..8734fe9b89 100644 --- a/internal/memstore/memstore.go +++ b/internal/memstore/memstore.go @@ -35,3 +35,9 @@ func (s *Store[keyT, valueT]) GetAll() []valueT { } return values } + +func (s *Store[keyT, valueT]) Clear() { + s.mux.Lock() + defer s.mux.Unlock() + clear(s.m) +}