Skip to content

Commit

Permalink
snp: clear kds cache daily
Browse files Browse the repository at this point in the history
  • Loading branch information
3u13r committed Jan 9, 2024
1 parent ec100d2 commit 2116f59
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 18 deletions.
8 changes: 7 additions & 1 deletion coordinator/intercom.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/status"
"k8s.io/utils/clock"
)

type intercomServer struct {
grpc *grpc.Server
certGet certGetter
caChainGetter certChainGetter
ticker clock.Ticker
logger *slog.Logger

intercom.UnimplementedIntercomServer
Expand All @@ -31,7 +33,8 @@ type certGetter interface {
}

func newIntercomServer(meshAuth *meshAuthority, caGetter certChainGetter, log *slog.Logger) *intercomServer {
validator := snp.NewValidatorWithCallbacks(meshAuth, log, meshAuth)
ticker := clock.RealClock{}.NewTicker(24 * time.Hour)
validator := snp.NewValidatorWithCallbacks(meshAuth, ticker, log, meshAuth)
credentials := atlscredentials.New(atls.NoIssuer, []atls.Validator{validator})
grpcServer := grpc.NewServer(
grpc.Creds(credentials),
Expand All @@ -41,6 +44,7 @@ func newIntercomServer(meshAuth *meshAuthority, caGetter certChainGetter, log *s
grpc: grpcServer,
certGet: meshAuth,
caChainGetter: caGetter,
ticker: ticker,
logger: log.WithGroup("intercom"),
}
intercom.RegisterIntercomServer(s.grpc, s)
Expand All @@ -52,6 +56,8 @@ func (i *intercomServer) Serve(endpoint string) error {
if err != nil {
return fmt.Errorf("failed to listen: %w", err)
}

defer i.ticker.Stop()
return i.grpc.Serve(lis)
}

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ require (
gopkg.in/yaml.v3 v3.0.1
k8s.io/api v0.29.0
k8s.io/apimachinery v0.29.0
k8s.io/utils v0.0.0-20230726121419-3b25d923346b
)

require (
Expand All @@ -39,7 +40,6 @@ require (
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
k8s.io/klog/v2 v2.110.1 // indirect
k8s.io/utils v0.0.0-20240102154912-e7106e64919e // indirect
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect
sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect
sigs.k8s.io/yaml v1.4.0 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ k8s.io/apimachinery v0.29.0 h1:+ACVktwyicPz0oc6MTMLwa2Pw3ouLAfAon1wPLtG48o=
k8s.io/apimachinery v0.29.0/go.mod h1:eVBxQ/cwiJxH58eK/jd/vAk4mrxmVlnpBH5J2GbMeis=
k8s.io/klog/v2 v2.110.1 h1:U/Af64HJf7FcwMcXyKm2RPM22WZzyR7OSpYj5tg3cL0=
k8s.io/klog/v2 v2.110.1/go.mod h1:YGtd1984u+GgbuZ7e08/yBuAfKLSO0+uR1Fhi6ExXjo=
k8s.io/utils v0.0.0-20240102154912-e7106e64919e h1:eQ/4ljkx21sObifjzXwlPKpdGLrCfRziVtos3ofG/sQ=
k8s.io/utils v0.0.0-20240102154912-e7106e64919e/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=
k8s.io/utils v0.0.0-20230726121419-3b25d923346b h1:sgn3ZU783SCgtaSJjpcVVlRqd6GSnlTLKgpAAttJvpI=
k8s.io/utils v0.0.0-20230726121419-3b25d923346b/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo=
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0=
sigs.k8s.io/structured-merge-diff/v4 v4.4.1 h1:150L+0vs/8DA78h1u02ooW1/fFq/Lwr+sGiqlzvrtq4=
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,48 @@ import (

"github.com/edgelesssys/nunki/internal/memstore"
"github.com/google/go-sev-guest/verify/trust"
"k8s.io/utils/clock"
)

type cachedKDSHTTPClient struct {
trust.HTTPSGetter
logger *slog.Logger

cache *memstore.Store[string, cacheEntry]
gcTicker clock.Ticker
cache *memstore.Store[string, []byte]
}

func newCachedKDSHTTPClient(log *slog.Logger) *cachedKDSHTTPClient {
func newCachedKDSHTTPClient(ticker clock.Ticker, log *slog.Logger) *cachedKDSHTTPClient {
trust.DefaultHTTPSGetter()
return &cachedKDSHTTPClient{

c := &cachedKDSHTTPClient{
HTTPSGetter: trust.DefaultHTTPSGetter(),
logger: log.WithGroup("cached-kds-http-client"),
cache: memstore.New[string, cacheEntry](),
cache: memstore.New[string, []byte](),
gcTicker: ticker,
}

return c
}

func (c *cachedKDSHTTPClient) Get(url string) ([]byte, error) {
select {
case <-c.gcTicker.C():
c.logger.Debug("Garbage collecting")
c.cache.Clear()
default:
}

if cached, ok := c.cache.Get(url); ok {
c.logger.Debug("Get cached", "url", url)
return cached.data, nil
return cached, nil
}

c.logger.Debug("Get not cached", "url", url)
res, err := c.HTTPSGetter.Get(url)
if err != nil {
return nil, err
}
c.cache.Set(url, cacheEntry{
data: res,
})
c.cache.Set(url, res)
return res, nil
}

type cacheEntry struct {
data []byte
}
132 changes: 132 additions & 0 deletions internal/attestation/snp/cached_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package snp

import (
"log/slog"
"sync"
"testing"
"time"

"github.com/edgelesssys/nunki/internal/memstore"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
testingclock "k8s.io/utils/clock/testing"
)

func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}

func TestCachedKDSHTTPClient(t *testing.T) {
t.Run("Get", func(t *testing.T) {
assert := assert.New(t)

fakeGetter := &fakeHTTPSGetter{
content: map[string][]byte{
"foo": []byte("bar"),
},
hits: map[string]int{},
}
stepTime := 5 * time.Minute
testClock := testingclock.NewFakeClock(time.Now())
ticker := testClock.NewTicker(stepTime)
client := &cachedKDSHTTPClient{
HTTPSGetter: fakeGetter,
gcTicker: ticker,
cache: memstore.New[string, []byte](),
logger: slog.Default(),
}

res, err := client.Get("foo")
assert.NoError(err)
assert.Equal([]byte("bar"), res)
assert.Equal(1, fakeGetter.hits["foo"])

// Expect a second call to return the cached value and not increase the hit counter.
res, err = client.Get("foo")
assert.NoError(err)
assert.Equal([]byte("bar"), res)
assert.Equal(1, fakeGetter.hits["foo"])

// After the step time, the cache should be invalidated and hit the backend again.
testClock.Step(stepTime)
res, err = client.Get("foo")
assert.NoError(err)
assert.Equal([]byte("bar"), res)
assert.Equal(2, fakeGetter.hits["foo"])
})
t.Run("Get error", func(t *testing.T) {
fakeGetter := &fakeHTTPSGetter{
getErr: assert.AnError,
content: map[string][]byte{},
hits: map[string]int{},
}
testClock := testingclock.NewFakeClock(time.Now())
ticker := testClock.NewTicker(5 * time.Minute)
client := &cachedKDSHTTPClient{
HTTPSGetter: fakeGetter,
gcTicker: ticker,
cache: memstore.New[string, []byte](),
logger: slog.Default(),
}

assert := assert.New(t)

_, err := client.Get("foo")
assert.Error(err)
assert.Equal(1, fakeGetter.hits["foo"])
})
t.Run("Concurrent access", func(t *testing.T) {
assert := assert.New(t)

fakeGetter := &fakeHTTPSGetter{
content: map[string][]byte{
"foo": []byte("bar"),
},
hits: map[string]int{},
}
numGets := 5
stepTime := 5 * time.Minute
testClock := testingclock.NewFakeClock(time.Now())
ticker := testClock.NewTicker(stepTime)
client := &cachedKDSHTTPClient{
HTTPSGetter: fakeGetter,
gcTicker: ticker,
cache: memstore.New[string, []byte](),
logger: slog.Default(),
}

var wg sync.WaitGroup
getFunc := func() {
defer wg.Done()
res, err := client.Get("foo")
assert.NoError(err)
assert.Equal([]byte("bar"), res)
}

wg.Add(numGets)
go getFunc()
go getFunc()
go getFunc()
go getFunc()
go getFunc()
wg.Wait()

// It's possible that the cache is not yet populated when it is checked by the second Get.
assert.Less(fakeGetter.hits["foo"], numGets)
})
}

type fakeHTTPSGetter struct {
content map[string][]byte
getErr error

hitsMux sync.Mutex
hits map[string]int
}

func (f *fakeHTTPSGetter) Get(url string) ([]byte, error) {
f.hitsMux.Lock()
defer f.hitsMux.Unlock()
f.hits[url]++
return f.content[url], f.getErr
}
5 changes: 3 additions & 2 deletions internal/attestation/snp/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/google/go-sev-guest/validate"
"github.com/google/go-sev-guest/verify"
"github.com/google/go-sev-guest/verify/trust"
"k8s.io/utils/clock"
)

// Validator validates attestation statements.
Expand Down Expand Up @@ -58,11 +59,11 @@ func NewValidator(optsGen validateOptsGenerator, log *slog.Logger) *Validator {
}

// NewValidatorWithCallbacks returns a new Validator with callbacks.
func NewValidatorWithCallbacks(optsGen validateOptsGenerator, log *slog.Logger, callbacks ...validateCallbacker) *Validator {
func NewValidatorWithCallbacks(optsGen validateOptsGenerator, ticker clock.Ticker, log *slog.Logger, callbacks ...validateCallbacker) *Validator {
return &Validator{
validateOptsGen: optsGen,
callbackers: callbacks,
kdsGetter: newCachedKDSHTTPClient(log),
kdsGetter: newCachedKDSHTTPClient(ticker, log),
logger: log.WithGroup("snp-validator"),
}
}
Expand Down
7 changes: 7 additions & 0 deletions internal/memstore/memstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,10 @@ func (s *Store[keyT, valueT]) GetAll() []valueT {
}
return values
}

// Clear clears all values from store.
func (s *Store[keyT, valueT]) Clear() {
s.mux.Lock()
defer s.mux.Unlock()
clear(s.m)
}

0 comments on commit 2116f59

Please sign in to comment.