From ae5a52d4b184ee2ce3526c73df7e1f2e1ddbc604 Mon Sep 17 00:00:00 2001 From: Leonard Cohnen Date: Fri, 29 Dec 2023 17:01:03 +0100 Subject: [PATCH] snp: cache amd kds requests --- internal/attestation/snp/cachedClient.go | 45 ++++++++++++++++++++++++ internal/attestation/snp/validator.go | 7 +++- 2 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 internal/attestation/snp/cachedClient.go diff --git a/internal/attestation/snp/cachedClient.go b/internal/attestation/snp/cachedClient.go new file mode 100644 index 0000000000..eab4ce5e5a --- /dev/null +++ b/internal/attestation/snp/cachedClient.go @@ -0,0 +1,45 @@ +package snp + +import ( + "log/slog" + + "github.com/edgelesssys/nunki/internal/memstore" + "github.com/google/go-sev-guest/verify/trust" +) + +type cachedKDSHTTPClient struct { + trust.HTTPSGetter + logger *slog.Logger + + cache *memstore.Store[string, cacheEntry] +} + +func NewCachedKDSHTTPClient(log *slog.Logger) *cachedKDSHTTPClient { + trust.DefaultHTTPSGetter() + return &cachedKDSHTTPClient{ + HTTPSGetter: trust.DefaultHTTPSGetter(), + logger: log.WithGroup("cached-kds-http-client"), + cache: memstore.New[string, cacheEntry](), + } +} + +func (c *cachedKDSHTTPClient) Get(url string) ([]byte, error) { + if cached, ok := c.cache.Get(url); ok { + c.logger.Debug("Get cached", "url", url) + return cached.data, nil + } + + c.logger.Debug("Get not cached", "url", url) + res, err := c.HTTPSGetter.Get(url) + if err != nil { + return nil, err + } + c.cache.Set(url, cacheEntry{ + data: res, + }) + return res, nil +} + +type cacheEntry struct { + data []byte +} diff --git a/internal/attestation/snp/validator.go b/internal/attestation/snp/validator.go index 8fa88e9b79..de43acb0b4 100644 --- a/internal/attestation/snp/validator.go +++ b/internal/attestation/snp/validator.go @@ -18,11 +18,13 @@ import ( "github.com/google/go-sev-guest/proto/sevsnp" "github.com/google/go-sev-guest/validate" "github.com/google/go-sev-guest/verify" + "github.com/google/go-sev-guest/verify/trust" ) type Validator struct { validateOptsGen validateOptsGenerator callbackers []validateCallbacker + kdsGetter trust.HTTPSGetter logger *slog.Logger } @@ -53,6 +55,7 @@ func NewValidatorWithCallbacks(optsGen validateOptsGenerator, log *slog.Logger, return &Validator{ validateOptsGen: optsGen, callbackers: callbacks, + kdsGetter: NewCachedKDSHTTPClient(log), logger: log.WithGroup("snp-validator"), } } @@ -85,7 +88,9 @@ func (v *Validator) Validate(ctx context.Context, attDocRaw []byte, nonce []byte // Report signature verification. - verifyOpts := &verify.Options{} + verifyOpts := &verify.Options{ + Getter: v.kdsGetter, + } attestation, err := verify.GetAttestationFromReport(report, verifyOpts) if err != nil { return fmt.Errorf("getting attestation from report: %w", err)