Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

attestation: add name to Validator as unique identifier #1095

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions coordinator/internal/authority/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"fmt"
"log/slog"
"net"
"strconv"
"strings"
"time"

"github.com/edgelesssys/contrast/internal/atls"
Expand Down Expand Up @@ -86,10 +88,11 @@ func (c *Credentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.A
return nil, nil, fmt.Errorf("generating SNP validation options: %w", err)
}

for _, opt := range opts {
for i, opt := range opts {
name := "snp-" + strconv.Itoa(i) + "-" + strings.TrimPrefix(opt.VerifyOpts.Product.Name.String(), "SEV_PRODUCT_")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
name := "snp-" + strconv.Itoa(i) + "-" + strings.TrimPrefix(opt.VerifyOpts.Product.Name.String(), "SEV_PRODUCT_")
name := fmt.Sprintf("snp-%d-%s", i, strings.TrimPrefix(opt.VerifyOpts.Product.Name.String(), "SEV_PRODUCT_"))

validator := snp.NewValidatorWithReportSetter(opt.VerifyOpts, opt.ValidateOpts,
logger.NewWithAttrs(logger.NewNamed(c.logger, "validator"), map[string]string{"tee-type": "snp"}),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we have a name in the constructor, we could remove the attrs here and add them there?

&authInfo)
&authInfo, name)
validators = append(validators, validator)
}

Expand All @@ -98,9 +101,10 @@ func (c *Credentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.A
log.Error("Could not generate TDX validation options", "error", err)
return nil, nil, fmt.Errorf("generating TDX validation options: %w", err)
}
for _, opt := range tdxOpts {
for i, opt := range tdxOpts {
name := "tdx" + strconv.Itoa(i)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
name := "tdx" + strconv.Itoa(i)
name := fmt.Sprintf("tdx-%d", i)

validators = append(validators, tdx.NewValidatorWithReportSetter(&tdx.StaticValidateOptsGenerator{Opts: opt},
logger.NewWithAttrs(logger.NewNamed(c.logger, "validator"), map[string]string{"tee-type": "tdx"}), &authInfo))
logger.NewWithAttrs(logger.NewNamed(c.logger, "validator"), map[string]string{"tee-type": "tdx"}), &authInfo, name))
}

serverCfg, err := atls.CreateAttestationServerTLSConfig(c.issuer, validators, c.attestationFailuresCounter)
Expand Down
17 changes: 15 additions & 2 deletions internal/atls/atls.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"crypto/x509/pkix"
"encoding/asn1"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -94,6 +95,7 @@ type Issuer interface {
type Validator interface {
Getter
Validate(attDoc []byte, nonce []byte, peerPublicKey []byte) error
Name() string
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More idiomatic might be to implement fmt.Stringer.

}

// getATLSConfigForClientFunc returns a config setup function that is called once for every client connecting to the server.
Expand Down Expand Up @@ -233,7 +235,7 @@ func verifyEmbeddedReport(validators []Validator, cert *x509.Certificate, peerPu

// We have a valid attestation document. Let's check it against all applicable validators.
foundExtension = true
for _, validator := range validators {
for i, validator := range validators {
// Optimization: Skip the validator if it doesn't match the attestation type of the document.
if !ex.Id.Equal(validator.OID()) {
continue
Expand All @@ -248,7 +250,13 @@ func verifyEmbeddedReport(validators []Validator, cert *x509.Certificate, peerPu
return nil
}
// Otherwise, we'll keep track of the error and continue with the next validator.
retErr = errors.Join(retErr, fmt.Errorf("validator %s failed: %w", validator.OID(), validationErr))
// The joined error should reveal the atls nonce once to maintain readability. Because the error is only revealed if all validators fail,
// we can implicitly include the nonce in the first validator error and the concatenate the other errors.
Comment on lines +253 to +254
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would make more sense to first collect the errors with Join, and then wrap them with a message that contains the nonce before returning from this function.

if i == 0 {
retErr = errors.Join(retErr, fmt.Errorf("AtlsConnectionNonce: %s || validator %s failed: %w ||", hex.EncodeToString(nonce), validator.Name(), validationErr))
} else {
retErr = errors.Join(retErr, fmt.Errorf(" validator %s failed: %w ||", validator.Name(), validationErr))
}
}
}

Expand Down Expand Up @@ -439,6 +447,11 @@ func (v FakeValidator) Validate(attDoc []byte, nonce []byte, _ []byte) error {
return v.err
}

// Name returns the name of the validator.
func (v *FakeValidator) Name() string {
return ""
}

// FakeAttestationDoc is a fake attestation document used for testing.
type FakeAttestationDoc struct {
UserData []byte
Expand Down
20 changes: 14 additions & 6 deletions internal/attestation/snp/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,42 +26,45 @@ type Validator struct {
validateOpts *validate.Options
reportSetter attestation.ReportSetter
logger *slog.Logger
name string
}

// NewValidator returns a new Validator.
func NewValidator(VerifyOpts *verify.Options, ValidateOpts *validate.Options, log *slog.Logger) *Validator {
func NewValidator(VerifyOpts *verify.Options, ValidateOpts *validate.Options, log *slog.Logger, name string) *Validator {
return &Validator{
verifyOpts: VerifyOpts,
validateOpts: ValidateOpts,
logger: log,
name: name,
}
}

// NewValidatorWithReportSetter returns a new Validator with a report setter.
func NewValidatorWithReportSetter(VerifyOpts *verify.Options, ValidateOpts *validate.Options,
log *slog.Logger, reportSetter attestation.ReportSetter,
log *slog.Logger, reportSetter attestation.ReportSetter, name string,
) *Validator {
return &Validator{
verifyOpts: VerifyOpts,
validateOpts: ValidateOpts,
reportSetter: reportSetter,
logger: log,
name: name,
}
}

// OID returns the OID of the validator.
// OID returns the root OID for the raw SNP report extension used by the validator.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why root? And, if you are changing the docstring here, please also change it for TDX.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm. I'd say it's technically a node, but not the root node.

func (v *Validator) OID() asn1.ObjectIdentifier {
return oid.RawSNPReport
}

// Validate a TPM based attestation.
func (v *Validator) Validate(attDocRaw []byte, nonce []byte, peerPublicKey []byte) (err error) {
v.logger.Info("Validate called", "nonce", hex.EncodeToString(nonce))
v.logger.Info("Validate called", "name", v.name, "nonce", hex.EncodeToString(nonce))
defer func() {
if err != nil {
v.logger.Error("Validation failed", "error", err)
v.logger.Debug("Validate failed", "name", v.name, "nonce", hex.EncodeToString(nonce), "error", err)
} else {
v.logger.Info("Validation successful")
v.logger.Info("Validate succeeded", "name", v.name, "nonce", hex.EncodeToString(nonce))
}
}()

Expand Down Expand Up @@ -99,6 +102,11 @@ func (v *Validator) Validate(attDocRaw []byte, nonce []byte, peerPublicKey []byt
return nil
}

// Name returns the name of the validator.
func (v *Validator) Name() string {
return v.name
}

type snpReport struct {
report *sevsnp.Report
}
Expand Down
19 changes: 13 additions & 6 deletions internal/attestation/tdx/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type Validator struct {
validateOptsGen validateOptsGenerator
reportSetter attestation.ReportSetter
logger *slog.Logger
name string
}

type validateOptsGenerator interface {
Expand All @@ -55,16 +56,17 @@ func (v *StaticValidateOptsGenerator) TDXValidateOpts(_ *tdx.QuoteV4) (*validate
}

// NewValidator returns a new Validator.
func NewValidator(optsGen validateOptsGenerator, log *slog.Logger) *Validator {
func NewValidator(optsGen validateOptsGenerator, log *slog.Logger, name string) *Validator {
return &Validator{
validateOptsGen: optsGen,
logger: log,
name: name,
}
}

// NewValidatorWithReportSetter returns a new Validator with a report setter.
func NewValidatorWithReportSetter(optsGen validateOptsGenerator, log *slog.Logger, reportSetter attestation.ReportSetter) *Validator {
v := NewValidator(optsGen, log)
func NewValidatorWithReportSetter(optsGen validateOptsGenerator, log *slog.Logger, reportSetter attestation.ReportSetter, name string) *Validator {
v := NewValidator(optsGen, log, name)
v.reportSetter = reportSetter
return v
}
Expand All @@ -78,12 +80,12 @@ func (v *Validator) OID() asn1.ObjectIdentifier {
func (v *Validator) Validate(attDocRaw []byte, nonce []byte, peerPublicKey []byte) (err error) {
// TODO(freax13): Validate the memory integrity mode (logical vs cryptographic) in the provisioning certificate.

v.logger.Info("Validate called", "nonce", hex.EncodeToString(nonce))
v.logger.Info("Validate called", "name", v.name, "nonce", hex.EncodeToString(nonce))
defer func() {
if err != nil {
v.logger.Error("Validation failed", "error", err)
v.logger.Debug("Validate failed", "name", v.name, "nonce", hex.EncodeToString(nonce), "error", err)
} else {
v.logger.Info("Validation successful")
v.logger.Info("Validate succeeded", "name", v.name, "nonce", hex.EncodeToString(nonce))
}
}()

Expand Down Expand Up @@ -137,6 +139,11 @@ func (v *Validator) Validate(attDocRaw []byte, nonce []byte, peerPublicKey []byt
return nil
}

// Name returns the name of the validator.
func (v *Validator) Name() string {
return v.name
}

func trustedRoots() (*x509.CertPool, error) {
rootCerts := x509.NewCertPool()
if ok := rootCerts.AppendCertsFromPEM(tdxRootCert); !ok {
Expand Down
12 changes: 8 additions & 4 deletions sdk/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ package sdk
import (
"fmt"
"log/slog"
"strconv"
"strings"

"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/attestation/certcache"
Expand All @@ -31,10 +33,11 @@ func ValidatorsFromManifest(kdsDir string, m *manifest.Manifest, log *slog.Logge
if err != nil {
return nil, fmt.Errorf("getting SNP validate options: %w", err)
}
for _, opt := range opts {
for i, opt := range opts {
opt.ValidateOpts.HostData = coordinatorPolicyChecksum
name := "snp-" + strconv.Itoa(i) + "-" + strings.TrimPrefix(opt.VerifyOpts.Product.Name.String(), "SEV_PRODUCT_")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fmt.Sprintf as above

validators = append(validators, snp.NewValidator(opt.VerifyOpts, opt.ValidateOpts,
logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}),
logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}), name,
))
}

Expand All @@ -44,9 +47,10 @@ func ValidatorsFromManifest(kdsDir string, m *manifest.Manifest, log *slog.Logge
}
var mrConfigID [48]byte
copy(mrConfigID[:], coordinatorPolicyChecksum)
for _, opt := range tdxOpts {
for i, opt := range tdxOpts {
name := "tdx-" + strconv.Itoa(i)
opt.TdQuoteBodyOptions.MrConfigID = mrConfigID[:]
validators = append(validators, tdx.NewValidator(&tdx.StaticValidateOptsGenerator{Opts: opt}, logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "tdx"})))
validators = append(validators, tdx.NewValidator(&tdx.StaticValidateOptsGenerator{Opts: opt}, logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "tdx"}), name))
}

return validators, nil
Expand Down