diff --git a/cli/cmd/recover.go b/cli/cmd/recover.go index 46ab1a7a05..3f487f160a 100644 --- a/cli/cmd/recover.go +++ b/cli/cmd/recover.go @@ -81,17 +81,21 @@ func runRecover(cmd *cobra.Command, _ []string) error { return fmt.Errorf("getting cache dir: %w", err) } log.Debug("Using KDS cache dir", "dir", kdsDir) + kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache")) + kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter")) - validateOptsGen, err := newCoordinatorValidateOptsGen(m, flags.policy) + optsGens, err := m.SNPValidateOpts() if err != nil { - return fmt.Errorf("generating validate opts: %w", err) + return fmt.Errorf("getting AKS validate options: %w", err) } - kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache")) - kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter")) - validator := snp.NewValidator(validateOptsGen, kdsGetter, - logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}), - ) - dialer := dialer.NewWithKey(atls.NoIssuer, validator, &net.Dialer{}, workloadOwnerKey) + + var validators []atls.Validator + for _, gen := range optsGens { + validators = append(validators, snp.NewValidator(gen.WithStaticHostData(flags.policy), kdsGetter, + logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}), + )) + } + dialer := dialer.NewWithKey(atls.NoIssuer, validators, &net.Dialer{}, workloadOwnerKey) log.Debug("Dialing coordinator", "endpoint", flags.coordinator) conn, err := dialer.Dial(cmd.Context(), flags.coordinator) diff --git a/cli/cmd/set.go b/cli/cmd/set.go index b74400f7d5..3806711a07 100644 --- a/cli/cmd/set.go +++ b/cli/cmd/set.go @@ -106,17 +106,21 @@ func runSet(cmd *cobra.Command, args []string) error { return fmt.Errorf("getting cache dir: %w", err) } log.Debug("Using KDS cache dir", "dir", kdsDir) + kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache")) + kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter")) - validateOptsGen, err := newCoordinatorValidateOptsGen(m, flags.policy) + optsGens, err := m.SNPValidateOpts() if err != nil { - return fmt.Errorf("generating validate opts: %w", err) + return fmt.Errorf("getting AKS validate options: %w", err) } - kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache")) - kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter")) - validator := snp.NewValidator(validateOptsGen, kdsGetter, - logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}), - ) - dialer := dialer.NewWithKey(atls.NoIssuer, validator, &net.Dialer{}, workloadOwnerKey) + + var validators []atls.Validator + for _, gen := range optsGens { + validators = append(validators, snp.NewValidator(gen.WithStaticHostData(flags.policy), kdsGetter, + logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}), + )) + } + dialer := dialer.NewWithKey(atls.NoIssuer, validators, &net.Dialer{}, workloadOwnerKey) conn, err := dialer.Dial(cmd.Context(), flags.coordinator) if err != nil { diff --git a/cli/cmd/verify.go b/cli/cmd/verify.go index adb06568c5..34bb1146b0 100644 --- a/cli/cmd/verify.go +++ b/cli/cmd/verify.go @@ -76,17 +76,21 @@ func runVerify(cmd *cobra.Command, _ []string) error { return fmt.Errorf("getting cache dir: %w", err) } log.Debug("Using KDS cache dir", "dir", kdsDir) + kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache")) + kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter")) - validateOptsGen, err := newCoordinatorValidateOptsGen(m, flags.policy) + optsGens, err := m.SNPValidateOpts() if err != nil { - return fmt.Errorf("generating validate opts: %w", err) + return fmt.Errorf("getting AKS validate options: %w", err) } - kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache")) - kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter")) - validator := snp.NewValidator(validateOptsGen, kdsGetter, - logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}), - ) - dialer := dialer.New(atls.NoIssuer, validator, &net.Dialer{}) + + var validators []atls.Validator + for _, gen := range optsGens { + validators = append(validators, snp.NewValidator(gen.WithStaticHostData(flags.policy), kdsGetter, + logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}), + )) + } + dialer := dialer.New(atls.NoIssuer, validators, &net.Dialer{}) log.Debug("Dialing coordinator", "endpoint", flags.coordinator) conn, err := dialer.Dial(cmd.Context(), flags.coordinator) @@ -174,17 +178,6 @@ func parseVerifyFlags(cmd *cobra.Command) (*verifyFlags, error) { }, nil } -func newCoordinatorValidateOptsGen(mnfst manifest.Manifest, hostData []byte) (*snp.StaticValidateOptsGenerator, error) { - validateOpts, err := mnfst.AKSValidateOpts() - if err != nil { - return nil, err - } - validateOpts.HostData = hostData - return &snp.StaticValidateOptsGenerator{ - Opts: validateOpts, - }, nil -} - func writeFilelist(dir string, filelist map[string][]byte) error { if dir != "" { if err := os.MkdirAll(dir, 0o755); err != nil { diff --git a/coordinator/internal/authority/authority.go b/coordinator/internal/authority/authority.go index fb52ed217e..4c0a275a3e 100644 --- a/coordinator/internal/authority/authority.go +++ b/coordinator/internal/authority/authority.go @@ -15,8 +15,6 @@ import ( "github.com/edgelesssys/contrast/internal/ca" "github.com/edgelesssys/contrast/internal/manifest" "github.com/edgelesssys/contrast/internal/userapi" - "github.com/google/go-sev-guest/proto/sevsnp" - "github.com/google/go-sev-guest/validate" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" ) @@ -179,19 +177,3 @@ type State struct { latest *history.LatestTransition generation int } - -// SNPValidateOpts returns SNP validation options from reference values. -// -// It also ensures that the policy hash in the report's HOSTDATA is allowed by the current -// manifest. -// TODO(msanft): make the manifest authoritative and allow other types of reference values. -func (s *State) SNPValidateOpts(report *sevsnp.Report) (*validate.Options, error) { - mnfst := s.Manifest - - hostData := manifest.NewHexString(report.HostData) - if _, ok := mnfst.Policies[hostData]; !ok { - return nil, fmt.Errorf("hostdata %s not found in manifest", hostData) - } - - return mnfst.AKSValidateOpts() -} diff --git a/coordinator/internal/authority/authority_test.go b/coordinator/internal/authority/authority_test.go index e7cd8beb81..8d391d6bc3 100644 --- a/coordinator/internal/authority/authority_test.go +++ b/coordinator/internal/authority/authority_test.go @@ -47,16 +47,20 @@ func TestSNPValidateOpts(t *testing.T) { _, err := a.SetManifest(context.Background(), req) require.NoError(err) - opts, err := a.state.Load().SNPValidateOpts(report) + gens, err := a.state.Load().Manifest.SNPValidateOpts() require.NoError(err) - require.NotNil(opts) + require.NotNil(gens) // Change to unknown policy hash in HostData. report.HostData[0]++ - opts, err = a.state.Load().SNPValidateOpts(report) + gens, err = a.state.Load().Manifest.SNPValidateOpts() + require.NoError(err) + require.NotNil(gens) + + gen := gens[0].WithReportHostData() + _, err = gen.SNPValidateOpts(report) require.Error(err) - require.Nil(opts) } // TODO(burgerdev): test ValidateCallback and GetCertBundle diff --git a/coordinator/internal/authority/credentials.go b/coordinator/internal/authority/credentials.go index 9d1add75c3..effb3b9cac 100644 --- a/coordinator/internal/authority/credentials.go +++ b/coordinator/internal/authority/credentials.go @@ -72,11 +72,19 @@ func (c *Credentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.A authInfo := AuthInfo{State: state} - validator := snp.NewValidatorWithCallbacks(state, c.kdsGetter, - logger.NewWithAttrs(logger.NewNamed(c.logger, "validator"), map[string]string{"tee-type": "snp"}), - c.attestationFailuresCounter, &authInfo) + optsGens, err := state.Manifest.SNPValidateOpts() + if err != nil { + return nil, nil, fmt.Errorf("generating SNP validation options: %w", err) + } - serverCfg, err := atls.CreateAttestationServerTLSConfig(c.issuer, []atls.Validator{validator}) + var validators []atls.Validator + for _, gen := range optsGens { + validator := snp.NewValidatorWithCallbacks(gen.WithReportHostData(), c.kdsGetter, + logger.NewWithAttrs(logger.NewNamed(c.logger, "validator"), map[string]string{"tee-type": "snp"}), + c.attestationFailuresCounter, &authInfo) + validators = append(validators, validator) + } + serverCfg, err := atls.CreateAttestationServerTLSConfig(c.issuer, validators) if err != nil { return nil, nil, err } diff --git a/initializer/main.go b/initializer/main.go index 5d423f7e65..441c662c50 100644 --- a/initializer/main.go +++ b/initializer/main.go @@ -60,7 +60,9 @@ func run() (retErr error) { } requestCert := func() (*meshapi.NewMeshCertResponse, error) { - dial := dialer.NewWithKey(issuer, atls.NoValidator, &net.Dialer{}, privKey) + // Supply an empty list of validators, as the coordinator does not need to be + // validated by the initializer. + dial := dialer.NewWithKey(issuer, []atls.Validator{}, &net.Dialer{}, privKey) conn, err := dial.Dial(ctx, net.JoinHostPort(coordinatorHostname, meshapi.Port)) if err != nil { return nil, fmt.Errorf("dialing: %w", err) diff --git a/internal/atls/atls.go b/internal/atls/atls.go index 4b9490343c..5f17c42ea9 100644 --- a/internal/atls/atls.go +++ b/internal/atls/atls.go @@ -28,8 +28,6 @@ import ( const attestationTimeout = 30 * time.Second var ( - // NoValidator skips validation of the server's attestation document. - NoValidator Validator // NoIssuer skips embedding the client's attestation document. NoIssuer Issuer diff --git a/internal/attestation/snp/validator.go b/internal/attestation/snp/validator.go index 431d537f76..17fcb238b0 100644 --- a/internal/attestation/snp/validator.go +++ b/internal/attestation/snp/validator.go @@ -44,13 +44,13 @@ type validateOptsGenerator interface { SNPValidateOpts(report *sevsnp.Report) (*validate.Options, error) } -// StaticValidateOptsGenerator returns validate.Options generator that returns +// StaticValidateOptsGenerator is a [validate.Options] generator that returns // static validation options. type StaticValidateOptsGenerator struct { Opts *validate.Options } -// SNPValidateOpts return the SNP validation options. +// SNPValidateOpts returns the SNP validation options. func (v *StaticValidateOptsGenerator) SNPValidateOpts(_ *sevsnp.Report) (*validate.Options, error) { return v.Opts, nil } @@ -65,13 +65,13 @@ func NewValidator(optsGen validateOptsGenerator, kdsGetter trust.HTTPSGetter, lo } // NewValidatorWithCallbacks returns a new Validator with callbacks. -func NewValidatorWithCallbacks(optsGen validateOptsGenerator, kdsGetter trust.HTTPSGetter, log *slog.Logger, attestataionFailures prometheus.Counter, callbacks ...validateCallbacker) *Validator { +func NewValidatorWithCallbacks(optsGen validateOptsGenerator, kdsGetter trust.HTTPSGetter, log *slog.Logger, attestationFailures prometheus.Counter, callbacks ...validateCallbacker) *Validator { return &Validator{ validateOptsGen: optsGen, callbackers: callbacks, kdsGetter: kdsGetter, logger: log, - metrics: metrics{attestationFailures: attestataionFailures}, + metrics: metrics{attestationFailures: attestationFailures}, } } diff --git a/internal/grpc/dialer/dialer.go b/internal/grpc/dialer/dialer.go index 1ca6e0f622..e33c8521e0 100644 --- a/internal/grpc/dialer/dialer.go +++ b/internal/grpc/dialer/dialer.go @@ -18,38 +18,34 @@ import ( // Dialer can open grpc client connections with different levels of ATLS encryption / verification. type Dialer struct { - issuer atls.Issuer - validator atls.Validator - netDialer NetDialer - privKey *ecdsa.PrivateKey + issuer atls.Issuer + validators []atls.Validator + netDialer NetDialer + privKey *ecdsa.PrivateKey } // New creates a new Dialer. -func New(issuer atls.Issuer, validator atls.Validator, netDialer NetDialer) *Dialer { +func New(issuer atls.Issuer, validators []atls.Validator, netDialer NetDialer) *Dialer { return &Dialer{ - issuer: issuer, - validator: validator, - netDialer: netDialer, + issuer: issuer, + validators: validators, + netDialer: netDialer, } } // NewWithKey creates a new Dialer with the given private key. -func NewWithKey(issuer atls.Issuer, validator atls.Validator, netDialer NetDialer, privKey *ecdsa.PrivateKey) *Dialer { +func NewWithKey(issuer atls.Issuer, validators []atls.Validator, netDialer NetDialer, privKey *ecdsa.PrivateKey) *Dialer { return &Dialer{ - issuer: issuer, - validator: validator, - netDialer: netDialer, - privKey: privKey, + issuer: issuer, + validators: validators, + netDialer: netDialer, + privKey: privKey, } } // Dial creates a new grpc client connection to the given target using the atls validator. func (d *Dialer) Dial(_ context.Context, target string) (*grpc.ClientConn, error) { - var validators []atls.Validator - if d.validator != nil { - validators = append(validators, d.validator) - } - credentials := atlscredentials.NewWithKey(d.issuer, validators, d.privKey) + credentials := atlscredentials.NewWithKey(d.issuer, d.validators, d.privKey) return grpc.NewClient(target, d.grpcWithDialer(), diff --git a/internal/manifest/constants.go b/internal/manifest/constants.go index a9478b236b..d28f807190 100644 --- a/internal/manifest/constants.go +++ b/internal/manifest/constants.go @@ -13,27 +13,13 @@ import ( // Default returns a default manifest with reference values for the given platform. func Default(platform platforms.Platform) (*Manifest, error) { embeddedRefValues := GetEmbeddedReferenceValues() + refValues, err := embeddedRefValues.ForPlatform(platform) if err != nil { return nil, fmt.Errorf("get reference values for platform %s: %w", platform, err) } - mnfst := Manifest{} - switch platform { - case platforms.AKSCloudHypervisorSNP: - return &Manifest{ - ReferenceValues: ReferenceValues{ - AKS: refValues.AKS, - }, - }, nil - case platforms.RKE2QEMUTDX, platforms.K3sQEMUTDX: - return &Manifest{ - ReferenceValues: ReferenceValues{ - BareMetalTDX: refValues.BareMetalTDX, - }, - }, nil - } - return &mnfst, nil + return &Manifest{ReferenceValues: *refValues}, nil } // GetEmbeddedReferenceValues returns the reference values embedded in the binary. diff --git a/internal/manifest/manifest.go b/internal/manifest/manifest.go index 8bd5b7214f..431333b25a 100644 --- a/internal/manifest/manifest.go +++ b/internal/manifest/manifest.go @@ -6,10 +6,12 @@ package manifest import ( "crypto/sha256" "encoding/base64" + "errors" "fmt" "github.com/google/go-sev-guest/abi" "github.com/google/go-sev-guest/kds" + "github.com/google/go-sev-guest/proto/sevsnp" "github.com/google/go-sev-guest/validate" ) @@ -17,7 +19,7 @@ import ( type Manifest struct { // policyHash/HOSTDATA -> commonName Policies map[HexString]PolicyEntry - ReferenceValues ReferenceValues + ReferenceValues []ReferenceValues WorkloadOwnerKeyDigests []HexString SeedshareOwnerPubKeys []HexString } @@ -65,18 +67,18 @@ func (p Policy) Hash() HexString { // Validate checks the validity of all fields in the reference values. func (r ReferenceValues) Validate() error { - if r.AKS != nil { - if err := r.AKS.Validate(); err != nil { + if r.SNP != nil { + if err := r.SNP.Validate(); err != nil { return fmt.Errorf("validating AKS reference values: %w", err) } } - if r.BareMetalTDX != nil { - if err := r.BareMetalTDX.Validate(); err != nil { + if r.TDX != nil { + if err := r.TDX.Validate(); err != nil { return fmt.Errorf("validating bare metal TDX reference values: %w", err) } } - if r.BareMetalTDX == nil && r.AKS == nil { + if r.TDX == nil && r.SNP == nil { return fmt.Errorf("reference values in manifest cannot be empty. Is the chosen platform supported?") } @@ -84,14 +86,14 @@ func (r ReferenceValues) Validate() error { } // Validate checks the validity of all fields in the AKS reference values. -func (r AKSReferenceValues) Validate() error { - if r.SNP.MinimumTCB.BootloaderVersion == nil { +func (r SNPReferenceValues) Validate() error { + if r.MinimumTCB.BootloaderVersion == nil { return fmt.Errorf("field BootloaderVersion in manifest cannot be empty") - } else if r.SNP.MinimumTCB.TEEVersion == nil { + } else if r.MinimumTCB.TEEVersion == nil { return fmt.Errorf("field TEEVersion in manifest cannot be empty") - } else if r.SNP.MinimumTCB.SNPVersion == nil { + } else if r.MinimumTCB.SNPVersion == nil { return fmt.Errorf("field SNPVersion in manifest cannot be empty") - } else if r.SNP.MinimumTCB.MicrocodeVersion == nil { + } else if r.MinimumTCB.MicrocodeVersion == nil { return fmt.Errorf("field MicrocodeVersion in manifest cannot be empty") } @@ -103,7 +105,7 @@ func (r AKSReferenceValues) Validate() error { } // Validate checks the validity of all fields in the bare metal TDX reference values. -func (r BareMetalTDXReferenceValues) Validate() error { +func (r TDXReferenceValues) Validate() error { if r.TrustedMeasurement == "" { return fmt.Errorf("field TrustedMeasurement in manifest cannot be empty") } @@ -120,8 +122,10 @@ func (m *Manifest) Validate() error { } } - if err := m.ReferenceValues.Validate(); err != nil { - return fmt.Errorf("validating reference values: %w", err) + for i, rv := range m.ReferenceValues { + if err := rv.Validate(); err != nil { + return fmt.Errorf("validating reference values [%d]: %w", i, err) + } } for _, keyDigest := range m.WorkloadOwnerKeyDigests { @@ -140,40 +144,94 @@ func (m *Manifest) Validate() error { return nil } -// AKSValidateOpts returns validate options populated with the manifest's -// AKS reference values and trusted measurement. -func (m *Manifest) AKSValidateOpts() (*validate.Options, error) { - if m.ReferenceValues.AKS == nil { - return nil, fmt.Errorf("no AKS reference values present in manifest") +// SNPValidateOptsGenerator generates SNP validation options and +// can be instantiated from a manifest only. +type SNPValidateOptsGenerator struct { + opts *validate.Options + manifest *Manifest + extraChecks []func(report *sevsnp.Report) error // additional checks that need to pass for the validation to succeed. +} + +// SNPValidateOpts returns the SNP validation options. +func (g *SNPValidateOptsGenerator) SNPValidateOpts(report *sevsnp.Report) (*validate.Options, error) { + for _, check := range g.extraChecks { + if err := check(report); err != nil { + return nil, fmt.Errorf("additional check failed: %w", err) + } + } + return g.opts, nil +} + +// TODO(msanft): add generic validation interface for other attestation types. + +// SNPValidateOpts returns validate options generators populated with the manifest's +// SNP reference values and trusted measurement for the given runtime. +func (m *Manifest) SNPValidateOpts() ([]*SNPValidateOptsGenerator, error) { + if m.ReferenceValues == nil { + return nil, errors.New("reference values cannot be empty") } if err := m.Validate(); err != nil { return nil, fmt.Errorf("validating manifest: %w", err) } - trustedMeasurement, err := m.ReferenceValues.AKS.TrustedMeasurement.Bytes() - if err != nil { - return nil, fmt.Errorf("failed to convert TrustedMeasurement from manifest to byte slices: %w", err) - } - - return &validate.Options{ - Measurement: trustedMeasurement, - GuestPolicy: abi.SnpPolicy{ - Debug: false, - SMT: true, - }, - VMPL: new(int), // VMPL0 - MinimumTCB: kds.TCBParts{ - BlSpl: m.ReferenceValues.AKS.SNP.MinimumTCB.BootloaderVersion.UInt8(), - TeeSpl: m.ReferenceValues.AKS.SNP.MinimumTCB.TEEVersion.UInt8(), - SnpSpl: m.ReferenceValues.AKS.SNP.MinimumTCB.SNPVersion.UInt8(), - UcodeSpl: m.ReferenceValues.AKS.SNP.MinimumTCB.MicrocodeVersion.UInt8(), - }, - MinimumLaunchTCB: kds.TCBParts{ - BlSpl: m.ReferenceValues.AKS.SNP.MinimumTCB.BootloaderVersion.UInt8(), - TeeSpl: m.ReferenceValues.AKS.SNP.MinimumTCB.TEEVersion.UInt8(), - SnpSpl: m.ReferenceValues.AKS.SNP.MinimumTCB.SNPVersion.UInt8(), - UcodeSpl: m.ReferenceValues.AKS.SNP.MinimumTCB.MicrocodeVersion.UInt8(), - }, - PermitProvisionalFirmware: true, - }, nil + + var out []*SNPValidateOptsGenerator + for _, refVal := range m.ReferenceValues { + trustedMeasurement, err := refVal.SNP.TrustedMeasurement.Bytes() + if err != nil { + return nil, fmt.Errorf("failed to convert TrustedMeasurement from manifest to byte slices: %w", err) + } + + out = append(out, &SNPValidateOptsGenerator{ + manifest: m, + opts: &validate.Options{ + Measurement: trustedMeasurement, + GuestPolicy: abi.SnpPolicy{ + Debug: false, + SMT: true, + }, + VMPL: new(int), // VMPL0 + MinimumTCB: kds.TCBParts{ + BlSpl: refVal.SNP.MinimumTCB.BootloaderVersion.UInt8(), + TeeSpl: refVal.SNP.MinimumTCB.TEEVersion.UInt8(), + SnpSpl: refVal.SNP.MinimumTCB.SNPVersion.UInt8(), + UcodeSpl: refVal.SNP.MinimumTCB.MicrocodeVersion.UInt8(), + }, + MinimumLaunchTCB: kds.TCBParts{ + BlSpl: refVal.SNP.MinimumTCB.BootloaderVersion.UInt8(), + TeeSpl: refVal.SNP.MinimumTCB.TEEVersion.UInt8(), + SnpSpl: refVal.SNP.MinimumTCB.SNPVersion.UInt8(), + UcodeSpl: refVal.SNP.MinimumTCB.MicrocodeVersion.UInt8(), + }, + PermitProvisionalFirmware: true, + }, + }) + } + + if len(out) == 0 { + return nil, errors.New("no AKS reference values found in manifest") + } + + return out, nil +} + +// WithReportHostData augments the validate options generator with +// a check that verifies whether the policy hash in the report's +// HOSTDATA is allowed by the manifest. +func (g *SNPValidateOptsGenerator) WithReportHostData() *SNPValidateOptsGenerator { + g.extraChecks = append(g.extraChecks, func(report *sevsnp.Report) error { + hostData := NewHexString(report.HostData) + if _, ok := g.manifest.Policies[hostData]; !ok { + return fmt.Errorf("hostdata %s not found in manifest", hostData) + } + return nil + }) + return g +} + +// WithStaticHostData augments the validate options with +// the given host data. +func (g *SNPValidateOptsGenerator) WithStaticHostData(hostData []byte) *SNPValidateOptsGenerator { + g.opts.HostData = hostData + return g } diff --git a/internal/manifest/manifest_test.go b/internal/manifest/manifest_test.go index a6f98589e4..36e266352d 100644 --- a/internal/manifest/manifest_test.go +++ b/internal/manifest/manifest_test.go @@ -92,10 +92,12 @@ func TestValidate(t *testing.T) { { m: &Manifest{ Policies: map[HexString]PolicyEntry{HexString(""): {}}, - ReferenceValues: ReferenceValues{ - AKS: &AKSReferenceValues{ - SNP: mnf.ReferenceValues.AKS.SNP, - TrustedMeasurement: "", + ReferenceValues: []ReferenceValues{ + { + SNP: &SNPReferenceValues{ + MinimumTCB: mnf.ReferenceValues[0].SNP.MinimumTCB, + TrustedMeasurement: "", + }, }, }, }, @@ -109,6 +111,7 @@ func TestValidate(t *testing.T) { wantErr: true, }, } + for i, tc := range testCases { t.Run(strconv.Itoa(i), func(t *testing.T) { assert := assert.New(t) @@ -128,18 +131,20 @@ func TestAKSValidateOpts(t *testing.T) { m, err := Default(platforms.AKSCloudHypervisorSNP) require.NoError(t, err) - opts, err := m.AKSValidateOpts() + optsGen, err := m.SNPValidateOpts() assert.NoError(err) + assert.Len(optsGen, 1) - tcb := m.ReferenceValues.AKS.SNP.MinimumTCB + tcb := m.ReferenceValues[0].SNP.MinimumTCB assert.NotNil(tcb.BootloaderVersion) assert.NotNil(tcb.TEEVersion) assert.NotNil(tcb.SNPVersion) assert.NotNil(tcb.MicrocodeVersion) - trustedMeasurement, err := m.ReferenceValues.AKS.TrustedMeasurement.Bytes() + trustedMeasurement, err := m.ReferenceValues[0].SNP.TrustedMeasurement.Bytes() assert.NoError(err) - assert.Equal(trustedMeasurement, opts.Measurement) + + assert.Equal(trustedMeasurement, optsGen[0].opts.Measurement) tcbParts := kds.TCBParts{ BlSpl: tcb.BootloaderVersion.UInt8(), @@ -147,6 +152,6 @@ func TestAKSValidateOpts(t *testing.T) { SnpSpl: tcb.SNPVersion.UInt8(), UcodeSpl: tcb.MicrocodeVersion.UInt8(), } - assert.Equal(tcbParts, opts.MinimumTCB) - assert.Equal(tcbParts, opts.MinimumLaunchTCB) + assert.Equal(tcbParts, optsGen[0].opts.MinimumTCB) + assert.Equal(tcbParts, optsGen[0].opts.MinimumLaunchTCB) } diff --git a/internal/manifest/referencevalues.go b/internal/manifest/referencevalues.go index e63503ea6a..5a0de0b10a 100644 --- a/internal/manifest/referencevalues.go +++ b/internal/manifest/referencevalues.go @@ -19,35 +19,30 @@ import ( //go:embed assets/reference-values.json var EmbeddedReferenceValuesJSON []byte -// ReferenceValues contains the workload-independent reference values for each platform. +// ReferenceValues contains the workload-independent reference values for each TEE type. type ReferenceValues struct { - // AKS holds the reference values for AKS. - AKS *AKSReferenceValues `json:"aks,omitempty"` - // BareMetalTDX holds the reference values for TDX on bare metal. - BareMetalTDX *BareMetalTDXReferenceValues `json:"bareMetalTDX,omitempty"` + // SNP holds the reference values for SNP. + SNP *SNPReferenceValues `json:"snp,omitempty"` + // TDX holds the reference values for TDX. + TDX *TDXReferenceValues `json:"tdx,omitempty"` } -// EmbeddedReferenceValues is a map of runtime handler names to reference values, as -// embedded in the binary. -type EmbeddedReferenceValues map[string]ReferenceValues +// EmbeddedReferenceValues is a map of runtime handler names to a list of reference values +// for the runtime handler, as embedded in the binary. +type EmbeddedReferenceValues map[string][]ReferenceValues -// AKSReferenceValues contains reference values for AKS. -type AKSReferenceValues struct { - SNP SNPReferenceValues +// SNPReferenceValues contains reference values for SEV-SNP. +type SNPReferenceValues struct { + MinimumTCB SNPTCB TrustedMeasurement HexString } -// BareMetalTDXReferenceValues contains reference values for BareMetalTDX. -type BareMetalTDXReferenceValues struct { +// TDXReferenceValues contains reference values for TDX. +type TDXReferenceValues struct { TrustedMeasurement HexString } -// SNPReferenceValues contains reference values for the SNP report. -type SNPReferenceValues struct { - MinimumTCB SNPTCB -} - -// SNPTCB represents a set of SNP TCB values. +// SNPTCB represents a set of SEV-SNP TCB values. type SNPTCB struct { BootloaderVersion *SVN TEEVersion *SVN @@ -102,7 +97,7 @@ func (h HexString) Bytes() ([]byte, error) { } // ForPlatform returns the reference values for the given platform. -func (e *EmbeddedReferenceValues) ForPlatform(platform platforms.Platform) (*ReferenceValues, error) { +func (e *EmbeddedReferenceValues) ForPlatform(platform platforms.Platform) (*[]ReferenceValues, error) { var mapping EmbeddedReferenceValues if err := json.Unmarshal(EmbeddedReferenceValuesJSON, &mapping); err != nil { return nil, fmt.Errorf("unmarshal embedded reference values mapping: %w", err) diff --git a/packages/by-name/contrast/package.nix b/packages/by-name/contrast/package.nix index 518bdc06d0..a645faf2d8 100644 --- a/packages/by-name/contrast/package.nix +++ b/packages/by-name/contrast/package.nix @@ -54,8 +54,8 @@ let rke2-qemu-tdx-handler = runtimeHandler "rke2-qemu-tdx" "${kata.runtime-class-files}/runtime-hash-tdx.hex"; k3s-qemu-snp-handler = runtimeHandler "k3s-qemu-snp" "${kata.runtime-class-files}/runtime-hash-snp.hex"; - aksRefVals = { - aks = { + aksRefVals = [ + { snp = { minimumTCB = { bootloaderVersion = 3; @@ -63,27 +63,33 @@ let snpVersion = 8; microcodeVersion = 115; }; + trustedMeasurement = lib.removeSuffix "\n" ( + builtins.readFile "${microsoft.runtime-class-files}/launch-digest.hex" + ); }; - trustedMeasurement = lib.removeSuffix "\n" ( - builtins.readFile "${microsoft.runtime-class-files}/launch-digest.hex" - ); - }; - }; - - snpRefVals = { - inherit (aksRefVals.aks) snp; - trustedMeasurement = lib.removeSuffix "\n" ( - builtins.readFile "${kata.runtime-class-files}/launch-digest-snp.hex" - ); - }; - - tdxRefVals = { - bareMetalTDX = { - trustedMeasurement = lib.removeSuffix "\n" ( - builtins.readFile "${kata.runtime-class-files}/launch-digest-tdx.hex" - ); - }; - }; + } + ]; + + snpRefVals = [ + { + snp = { + inherit ((builtins.head aksRefVals).snp) minimumTCB; + trustedMeasurement = lib.removeSuffix "\n" ( + builtins.readFile "${kata.runtime-class-files}/launch-digest-snp.hex" + ); + }; + } + ]; + + tdxRefVals = [ + { + tdx = { + trustedMeasurement = lib.removeSuffix "\n" ( + builtins.readFile "${kata.runtime-class-files}/launch-digest-tdx.hex" + ); + }; + } + ]; in builtins.toFile "reference-values.json" ( builtins.toJSON {