Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

snp: KDS cache garbage collection #37

Merged
merged 1 commit into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion coordinator/intercom.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/status"
"k8s.io/utils/clock"
)

type intercomServer struct {
grpc *grpc.Server
certGet certGetter
caChainGetter certChainGetter
ticker clock.Ticker
logger *slog.Logger

intercom.UnimplementedIntercomServer
Expand All @@ -31,7 +33,8 @@ type certGetter interface {
}

func newIntercomServer(meshAuth *meshAuthority, caGetter certChainGetter, log *slog.Logger) *intercomServer {
validator := snp.NewValidatorWithCallbacks(meshAuth, log, meshAuth)
ticker := clock.RealClock{}.NewTicker(24 * time.Hour)
validator := snp.NewValidatorWithCallbacks(meshAuth, ticker, log, meshAuth)
credentials := atlscredentials.New(atls.NoIssuer, []atls.Validator{validator})
grpcServer := grpc.NewServer(
grpc.Creds(credentials),
Expand All @@ -41,6 +44,7 @@ func newIntercomServer(meshAuth *meshAuthority, caGetter certChainGetter, log *s
grpc: grpcServer,
certGet: meshAuth,
caChainGetter: caGetter,
ticker: ticker,
logger: log.WithGroup("intercom"),
}
intercom.RegisterIntercomServer(s.grpc, s)
Expand All @@ -52,6 +56,8 @@ func (i *intercomServer) Serve(endpoint string) error {
if err != nil {
return fmt.Errorf("failed to listen: %w", err)
}

defer i.ticker.Stop()
return i.grpc.Serve(lis)
}

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ require (
gopkg.in/yaml.v3 v3.0.1
k8s.io/api v0.29.0
k8s.io/apimachinery v0.29.0
k8s.io/utils v0.0.0-20240102154912-e7106e64919e
)

require (
Expand All @@ -39,7 +40,6 @@ require (
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
k8s.io/klog/v2 v2.110.1 // indirect
k8s.io/utils v0.0.0-20240102154912-e7106e64919e // indirect
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect
sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect
sigs.k8s.io/yaml v1.4.0 // indirect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,48 @@ import (

"github.com/edgelesssys/nunki/internal/memstore"
"github.com/google/go-sev-guest/verify/trust"
"k8s.io/utils/clock"
)

type cachedKDSHTTPClient struct {
trust.HTTPSGetter
logger *slog.Logger

cache *memstore.Store[string, cacheEntry]
gcTicker clock.Ticker
cache *memstore.Store[string, []byte]
}

func newCachedKDSHTTPClient(log *slog.Logger) *cachedKDSHTTPClient {
func newCachedKDSHTTPClient(ticker clock.Ticker, log *slog.Logger) *cachedKDSHTTPClient {
trust.DefaultHTTPSGetter()
return &cachedKDSHTTPClient{

c := &cachedKDSHTTPClient{
HTTPSGetter: trust.DefaultHTTPSGetter(),
logger: log.WithGroup("cached-kds-http-client"),
cache: memstore.New[string, cacheEntry](),
cache: memstore.New[string, []byte](),
gcTicker: ticker,
}

return c
}

func (c *cachedKDSHTTPClient) Get(url string) ([]byte, error) {
select {
case <-c.gcTicker.C():
c.logger.Debug("Garbage collecting")
c.cache.Clear()
default:
}

if cached, ok := c.cache.Get(url); ok {
c.logger.Debug("Get cached", "url", url)
return cached.data, nil
return cached, nil
}

c.logger.Debug("Get not cached", "url", url)
res, err := c.HTTPSGetter.Get(url)
if err != nil {
return nil, err
}
c.cache.Set(url, cacheEntry{
data: res,
})
c.cache.Set(url, res)
return res, nil
}

type cacheEntry struct {
data []byte
}
132 changes: 132 additions & 0 deletions internal/attestation/snp/cached_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package snp

import (
"log/slog"
"sync"
"testing"
"time"

"github.com/edgelesssys/nunki/internal/memstore"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
testingclock "k8s.io/utils/clock/testing"
)

func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}

func TestCachedKDSHTTPClient(t *testing.T) {
t.Run("Get", func(t *testing.T) {
assert := assert.New(t)

fakeGetter := &fakeHTTPSGetter{
content: map[string][]byte{
"foo": []byte("bar"),
},
hits: map[string]int{},
}
stepTime := 5 * time.Minute
testClock := testingclock.NewFakeClock(time.Now())
ticker := testClock.NewTicker(stepTime)
client := &cachedKDSHTTPClient{
HTTPSGetter: fakeGetter,
gcTicker: ticker,
cache: memstore.New[string, []byte](),
logger: slog.Default(),
}

res, err := client.Get("foo")
assert.NoError(err)
assert.Equal([]byte("bar"), res)
assert.Equal(1, fakeGetter.hits["foo"])

// Expect a second call to return the cached value and not increase the hit counter.
res, err = client.Get("foo")
assert.NoError(err)
assert.Equal([]byte("bar"), res)
assert.Equal(1, fakeGetter.hits["foo"])

// After the step time, the cache should be invalidated and hit the backend again.
testClock.Step(stepTime)
res, err = client.Get("foo")
assert.NoError(err)
assert.Equal([]byte("bar"), res)
assert.Equal(2, fakeGetter.hits["foo"])
})
t.Run("Get error", func(t *testing.T) {
fakeGetter := &fakeHTTPSGetter{
getErr: assert.AnError,
content: map[string][]byte{},
hits: map[string]int{},
}
testClock := testingclock.NewFakeClock(time.Now())
ticker := testClock.NewTicker(5 * time.Minute)
client := &cachedKDSHTTPClient{
HTTPSGetter: fakeGetter,
gcTicker: ticker,
cache: memstore.New[string, []byte](),
logger: slog.Default(),
}

assert := assert.New(t)

_, err := client.Get("foo")
assert.Error(err)
assert.Equal(1, fakeGetter.hits["foo"])
})
t.Run("Concurrent access", func(t *testing.T) {
assert := assert.New(t)

fakeGetter := &fakeHTTPSGetter{
content: map[string][]byte{
"foo": []byte("bar"),
},
hits: map[string]int{},
}
numGets := 5
stepTime := 5 * time.Minute
testClock := testingclock.NewFakeClock(time.Now())
ticker := testClock.NewTicker(stepTime)
client := &cachedKDSHTTPClient{
HTTPSGetter: fakeGetter,
gcTicker: ticker,
cache: memstore.New[string, []byte](),
logger: slog.Default(),
}

var wg sync.WaitGroup
getFunc := func() {
defer wg.Done()
res, err := client.Get("foo")
assert.NoError(err)
assert.Equal([]byte("bar"), res)
}

wg.Add(numGets)
go getFunc()
go getFunc()
go getFunc()
go getFunc()
go getFunc()
wg.Wait()

// It's possible that the cache is not yet populated when it is checked by the second Get.
assert.Less(fakeGetter.hits["foo"], numGets)
})
}

type fakeHTTPSGetter struct {
content map[string][]byte
getErr error

hitsMux sync.Mutex
hits map[string]int
}

func (f *fakeHTTPSGetter) Get(url string) ([]byte, error) {
f.hitsMux.Lock()
defer f.hitsMux.Unlock()
f.hits[url]++
return f.content[url], f.getErr
}
5 changes: 3 additions & 2 deletions internal/attestation/snp/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/google/go-sev-guest/validate"
"github.com/google/go-sev-guest/verify"
"github.com/google/go-sev-guest/verify/trust"
"k8s.io/utils/clock"
)

// Validator validates attestation statements.
Expand Down Expand Up @@ -58,11 +59,11 @@ func NewValidator(optsGen validateOptsGenerator, log *slog.Logger) *Validator {
}

// NewValidatorWithCallbacks returns a new Validator with callbacks.
func NewValidatorWithCallbacks(optsGen validateOptsGenerator, log *slog.Logger, callbacks ...validateCallbacker) *Validator {
func NewValidatorWithCallbacks(optsGen validateOptsGenerator, ticker clock.Ticker, log *slog.Logger, callbacks ...validateCallbacker) *Validator {
return &Validator{
validateOptsGen: optsGen,
callbackers: callbacks,
kdsGetter: newCachedKDSHTTPClient(log),
kdsGetter: newCachedKDSHTTPClient(ticker, log),
logger: log.WithGroup("snp-validator"),
}
}
Expand Down
7 changes: 7 additions & 0 deletions internal/memstore/memstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,10 @@ func (s *Store[keyT, valueT]) GetAll() []valueT {
}
return values
}

// Clear clears all values from store.
func (s *Store[keyT, valueT]) Clear() {
s.mux.Lock()
defer s.mux.Unlock()
clear(s.m)
}