Skip to content

Commit

Permalink
snp: clear kds cache daily
Browse files Browse the repository at this point in the history
  • Loading branch information
3u13r committed Jan 5, 2024
1 parent 6cbd7e0 commit c236551
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 16 deletions.
8 changes: 7 additions & 1 deletion coordinator/intercom.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ import (
"github.com/edgelesssys/nunki/internal/intercom"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"k8s.io/utils/clock"
)

type intercomServer struct {
grpc *grpc.Server
certGet certGetter
caChainGetter certChainGetter
ticker clock.Ticker
logger *slog.Logger

intercom.UnimplementedIntercomServer
Expand All @@ -29,7 +31,8 @@ type certGetter interface {
}

func newIntercomServer(meshAuth *meshAuthority, caGetter certChainGetter, log *slog.Logger) (*intercomServer, error) {
validator := snp.NewValidatorWithCallbacks(meshAuth, log, meshAuth)
ticker := clock.RealClock{}.NewTicker(24 * time.Hour)
validator := snp.NewValidatorWithCallbacks(meshAuth, ticker, log, meshAuth)
credentials := atlscredentials.New(atls.NoIssuer, []atls.Validator{validator})
grpcServer := grpc.NewServer(
grpc.Creds(credentials),
Expand All @@ -39,6 +42,7 @@ func newIntercomServer(meshAuth *meshAuthority, caGetter certChainGetter, log *s
grpc: grpcServer,
certGet: meshAuth,
caChainGetter: caGetter,
ticker: ticker,
logger: log.WithGroup("intercom"),
}
intercom.RegisterIntercomServer(s.grpc, s)
Expand All @@ -50,6 +54,8 @@ func (i *intercomServer) Serve(endpoint string) error {
if err != nil {
return fmt.Errorf("failed to listen: %v", err)
}

defer i.ticker.Stop()
return i.grpc.Serve(lis)
}

Expand Down
4 changes: 4 additions & 0 deletions coordinator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import (
"log/slog"
"net"
"os"
"time"

"github.com/edgelesssys/nunki/internal/ca"
"github.com/edgelesssys/nunki/internal/coordapi"
"github.com/edgelesssys/nunki/internal/intercom"
"k8s.io/utils/clock"
)

func main() {
Expand Down Expand Up @@ -56,6 +58,8 @@ func run() (retErr error) {
}
}()

ticker := clock.RealClock{}.NewTicker(24 * time.Hour)
defer ticker.Stop()
intercomS, err := newIntercomServer(meshAuth, caInstance, logger)
if err != nil {
return fmt.Errorf("creating intercom server: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ require (
gopkg.in/yaml.v2 v2.4.0
k8s.io/api v0.28.4
k8s.io/apimachinery v0.28.4
k8s.io/utils v0.0.0-20230406110748-d93618cff8a2
)

require (
Expand All @@ -39,7 +40,6 @@ require (
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
k8s.io/klog/v2 v2.100.1 // indirect
k8s.io/utils v0.0.0-20230406110748-d93618cff8a2 // indirect
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect
sigs.k8s.io/structured-merge-diff/v4 v4.2.3 // indirect
sigs.k8s.io/yaml v1.3.0 // indirect
Expand Down
31 changes: 19 additions & 12 deletions internal/attestation/snp/cachedClient.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,48 @@ import (

"github.com/edgelesssys/nunki/internal/memstore"
"github.com/google/go-sev-guest/verify/trust"
"k8s.io/utils/clock"
)

type cachedKDSHTTPClient struct {
trust.HTTPSGetter
logger *slog.Logger

cache *memstore.Store[string, cacheEntry]
gcTicker clock.Ticker
cache *memstore.Store[string, []byte]
}

func NewCachedKDSHTTPClient(log *slog.Logger) *cachedKDSHTTPClient {
func NewCachedKDSHTTPClient(ticker clock.Ticker, log *slog.Logger) *cachedKDSHTTPClient {
trust.DefaultHTTPSGetter()
return &cachedKDSHTTPClient{

c := &cachedKDSHTTPClient{
HTTPSGetter: trust.DefaultHTTPSGetter(),
logger: log.WithGroup("cached-kds-http-client"),
cache: memstore.New[string, cacheEntry](),
cache: memstore.New[string, []byte](),
gcTicker: ticker,
}

return c
}

func (c *cachedKDSHTTPClient) Get(url string) ([]byte, error) {
select {
case <-c.gcTicker.C():
c.logger.Debug("Garbage collecting")
c.cache.Clear()
default:
}

if cached, ok := c.cache.Get(url); ok {
c.logger.Debug("Get cached", "url", url)
return cached.data, nil
return cached, nil
}

c.logger.Debug("Get not cached", "url", url)
res, err := c.HTTPSGetter.Get(url)
if err != nil {
return nil, err
}
c.cache.Set(url, cacheEntry{
data: res,
})
c.cache.Set(url, res)
return res, nil
}

type cacheEntry struct {
data []byte
}
88 changes: 88 additions & 0 deletions internal/attestation/snp/cachedClient_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package snp

import (
"log/slog"
"testing"
"time"

"github.com/edgelesssys/nunki/internal/memstore"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
testingclock "k8s.io/utils/clock/testing"
)

func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}

func TestCachedKDSHTTPClient(t *testing.T) {
t.Run("Get", func(t *testing.T) {
assert := assert.New(t)

fakeGetter := &fakeHTTPSGetter{
content: map[string][]byte{
"foo": []byte("bar"),
},
hits: map[string]int{},
}
stepTime := 5 * time.Minute
testClock := testingclock.NewFakeClock(time.Now())
ticker := testClock.NewTicker(stepTime)
client := &cachedKDSHTTPClient{
HTTPSGetter: fakeGetter,
gcTicker: ticker,
cache: memstore.New[string, []byte](),
logger: slog.Default(),
}

res, err := client.Get("foo")
assert.NoError(err)
assert.Equal([]byte("bar"), res)
assert.Equal(1, fakeGetter.hits["foo"])

// Expect a second call to return the cached value and not increase the hit counter.
res, err = client.Get("foo")
assert.NoError(err)
assert.Equal([]byte("bar"), res)
assert.Equal(1, fakeGetter.hits["foo"])

// After the step time, the cache should be invalidated and hit the backend again.
testClock.Step(stepTime)
res, err = client.Get("foo")
assert.NoError(err)
assert.Equal([]byte("bar"), res)
assert.Equal(2, fakeGetter.hits["foo"])
})
t.Run("Get error", func(t *testing.T) {
fakeGetter := &fakeHTTPSGetter{
getErr: assert.AnError,
content: map[string][]byte{},
hits: map[string]int{},
}
testClock := testingclock.NewFakeClock(time.Now())
ticker := testClock.NewTicker(5 * time.Minute)
client := &cachedKDSHTTPClient{
HTTPSGetter: fakeGetter,
gcTicker: ticker,
cache: memstore.New[string, []byte](),
logger: slog.Default(),
}

assert := assert.New(t)

_, err := client.Get("foo")
assert.Error(err)
assert.Equal(1, fakeGetter.hits["foo"])
})
}

type fakeHTTPSGetter struct {
content map[string][]byte
hits map[string]int
getErr error
}

func (f *fakeHTTPSGetter) Get(url string) ([]byte, error) {
f.hits[url]++
return f.content[url], f.getErr
}
5 changes: 3 additions & 2 deletions internal/attestation/snp/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/google/go-sev-guest/validate"
"github.com/google/go-sev-guest/verify"
"github.com/google/go-sev-guest/verify/trust"
"k8s.io/utils/clock"
)

type Validator struct {
Expand Down Expand Up @@ -51,11 +52,11 @@ func NewValidator(optsGen validateOptsGenerator, log *slog.Logger) *Validator {
}
}

func NewValidatorWithCallbacks(optsGen validateOptsGenerator, log *slog.Logger, callbacks ...validateCallbacker) *Validator {
func NewValidatorWithCallbacks(optsGen validateOptsGenerator, ticker clock.Ticker, log *slog.Logger, callbacks ...validateCallbacker) *Validator {
return &Validator{
validateOptsGen: optsGen,
callbackers: callbacks,
kdsGetter: NewCachedKDSHTTPClient(log),
kdsGetter: NewCachedKDSHTTPClient(ticker, log),
logger: log.WithGroup("snp-validator"),
}
}
Expand Down
6 changes: 6 additions & 0 deletions internal/memstore/memstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,9 @@ func (s *Store[keyT, valueT]) GetAll() []valueT {
}
return values
}

func (s *Store[keyT, valueT]) Clear() {
s.mux.Lock()
defer s.mux.Unlock()
clear(s.m)
}

0 comments on commit c236551

Please sign in to comment.