Skip to content

Commit

Permalink
treewide: allow multiple validators
Browse files Browse the repository at this point in the history
This changes the attestation (as of now, only SEV-SNP) to be passed
multiple validators. The aTLS code already handles multiple validators,
but the code previously passed only one. This way, attestation will now
work by being handed a list of validators, and returning success as soon
as one can successfully validate a report. Furthermore, the
`atls.NoValidator` is now obsolete, and semantically represented by
passing an empty list of validators.
  • Loading branch information
msanft committed Aug 2, 2024
1 parent 0697c37 commit a353ac8
Show file tree
Hide file tree
Showing 17 changed files with 286 additions and 218 deletions.
26 changes: 18 additions & 8 deletions cli/cmd/recover.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import (
"github.com/edgelesssys/contrast/internal/logger"
"github.com/edgelesssys/contrast/internal/manifest"
"github.com/edgelesssys/contrast/internal/userapi"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/google/go-sev-guest/validate"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -81,17 +83,25 @@ func runRecover(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("getting cache dir: %w", err)
}
log.Debug("Using KDS cache dir", "dir", kdsDir)
kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache"))
kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter"))

validateOptsGen, err := newCoordinatorValidateOptsGen(m, flags.policy)
validateOpts, err := m.AKSValidateOpts()
if err != nil {
return fmt.Errorf("generating validate opts: %w", err)
return fmt.Errorf("getting AKS validate options: %w", err)
}
kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache"))
kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter"))
validator := snp.NewValidator(validateOptsGen, kdsGetter,
logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}),
)
dialer := dialer.NewWithKey(atls.NoIssuer, validator, &net.Dialer{}, workloadOwnerKey)

var validators []atls.Validator
for _, vo := range validateOpts {
vo.HostData = flags.policy
optsGen := func(_ *sevsnp.Report) (*validate.Options, error) {
return vo, nil
}
validators = append(validators, snp.NewValidator(optsGen, kdsGetter,
logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}),
))
}
dialer := dialer.NewWithKey(atls.NoIssuer, validators, &net.Dialer{}, workloadOwnerKey)

log.Debug("Dialing coordinator", "endpoint", flags.coordinator)
conn, err := dialer.Dial(cmd.Context(), flags.coordinator)
Expand Down
26 changes: 18 additions & 8 deletions cli/cmd/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import (
"github.com/edgelesssys/contrast/internal/retry"
"github.com/edgelesssys/contrast/internal/spinner"
"github.com/edgelesssys/contrast/internal/userapi"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/google/go-sev-guest/validate"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -106,17 +108,25 @@ func runSet(cmd *cobra.Command, args []string) error {
return fmt.Errorf("getting cache dir: %w", err)
}
log.Debug("Using KDS cache dir", "dir", kdsDir)
kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache"))
kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter"))

validateOptsGen, err := newCoordinatorValidateOptsGen(m, flags.policy)
validateOpts, err := m.AKSValidateOpts()
if err != nil {
return fmt.Errorf("generating validate opts: %w", err)
return fmt.Errorf("getting AKS validate options: %w", err)
}
kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache"))
kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter"))
validator := snp.NewValidator(validateOptsGen, kdsGetter,
logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}),
)
dialer := dialer.NewWithKey(atls.NoIssuer, validator, &net.Dialer{}, workloadOwnerKey)

var validators []atls.Validator
for _, vo := range validateOpts {
vo.HostData = flags.policy
optsGen := func(_ *sevsnp.Report) (*validate.Options, error) {
return vo, nil
}
validators = append(validators, snp.NewValidator(optsGen, kdsGetter,
logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}),
))
}
dialer := dialer.NewWithKey(atls.NoIssuer, validators, &net.Dialer{}, workloadOwnerKey)

conn, err := dialer.Dial(cmd.Context(), flags.coordinator)
if err != nil {
Expand Down
37 changes: 18 additions & 19 deletions cli/cmd/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"github.com/edgelesssys/contrast/internal/logger"
"github.com/edgelesssys/contrast/internal/manifest"
"github.com/edgelesssys/contrast/internal/userapi"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/google/go-sev-guest/validate"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -76,17 +78,25 @@ func runVerify(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("getting cache dir: %w", err)
}
log.Debug("Using KDS cache dir", "dir", kdsDir)
kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache"))
kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter"))

validateOptsGen, err := newCoordinatorValidateOptsGen(m, flags.policy)
validateOpts, err := m.AKSValidateOpts()
if err != nil {
return fmt.Errorf("generating validate opts: %w", err)
return fmt.Errorf("getting AKS validate options: %w", err)
}
kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache"))
kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter"))
validator := snp.NewValidator(validateOptsGen, kdsGetter,
logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}),
)
dialer := dialer.New(atls.NoIssuer, validator, &net.Dialer{})

var validators []atls.Validator
for _, vo := range validateOpts {
vo.HostData = flags.policy
optsGen := func(_ *sevsnp.Report) (*validate.Options, error) {
return vo, nil
}
validators = append(validators, snp.NewValidator(optsGen, kdsGetter,
logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}),
))
}
dialer := dialer.New(atls.NoIssuer, validators, &net.Dialer{})

log.Debug("Dialing coordinator", "endpoint", flags.coordinator)
conn, err := dialer.Dial(cmd.Context(), flags.coordinator)
Expand Down Expand Up @@ -174,17 +184,6 @@ func parseVerifyFlags(cmd *cobra.Command) (*verifyFlags, error) {
}, nil
}

func newCoordinatorValidateOptsGen(mnfst manifest.Manifest, hostData []byte) (*snp.StaticValidateOptsGenerator, error) {
validateOpts, err := mnfst.AKSValidateOpts()
if err != nil {
return nil, err
}
validateOpts.HostData = hostData
return &snp.StaticValidateOptsGenerator{
Opts: validateOpts,
}, nil
}

func writeFilelist(dir string, filelist map[string][]byte) error {
if dir != "" {
if err := os.MkdirAll(dir, 0o755); err != nil {
Expand Down
80 changes: 71 additions & 9 deletions coordinator/internal/authority/authority.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,27 @@ import (
"errors"
"fmt"
"log/slog"
"net"
"sync"
"sync/atomic"
"time"

"github.com/edgelesssys/contrast/coordinator/history"
"github.com/edgelesssys/contrast/coordinator/internal/seedengine"
"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/attestation/snp"
"github.com/edgelesssys/contrast/internal/ca"
"github.com/edgelesssys/contrast/internal/grpc/atlscredentials"
"github.com/edgelesssys/contrast/internal/logger"
"github.com/edgelesssys/contrast/internal/manifest"
"github.com/edgelesssys/contrast/internal/memstore"
"github.com/edgelesssys/contrast/internal/userapi"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/google/go-sev-guest/validate"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"google.golang.org/grpc/credentials"
"k8s.io/utils/clock"
)

// ErrNoManifest is returned when a manifest is needed but not present.
Expand All @@ -52,6 +60,11 @@ type Authority struct {
logger *slog.Logger
metrics metrics

kdsGetter *snp.CachedHTTPSGetter
attestationFailuresCounter prometheus.Counter

atlsCredentials atomic.Pointer[atlscredentials.Credentials]

userapi.UnimplementedUserAPIServer
}

Expand All @@ -68,33 +81,57 @@ func New(hist *history.History, reg *prometheus.Registry, log *slog.Logger) *Aut
})
manifestGeneration.Set(0)

return &Authority{
bundles: make(map[string]Bundle),
hist: hist,
logger: log.WithGroup("mesh-authority"),
ticker := clock.RealClock{}.NewTicker(24 * time.Hour)

a := &Authority{
bundles: make(map[string]Bundle),
hist: hist,
logger: log.WithGroup("mesh-authority"),
kdsGetter: snp.NewCachedHTTPSGetter(memstore.New[string, []byte](), ticker, logger.NewNamed(log, "kds-getter")),
attestationFailuresCounter: promauto.With(reg).NewCounter(prometheus.CounterOpts{
Subsystem: "contrast_meshapi",
Name: "attestation_failures_total",
Help: "Number of attestation failures from workloads to the Coordinator.",
}),
metrics: metrics{
manifestGeneration: manifestGeneration,
},
}

a.atlsCredentials.Store(atlscredentials.New(atls.NoIssuer, []atls.Validator{}))

return a
}

// SNPValidateOpts returns SNP validation options from reference values.
//
// It also ensures that the policy hash in the report's HOSTDATA is allowed by the current
// manifest.
func (m *Authority) SNPValidateOpts(report *sevsnp.Report) (*validate.Options, error) {
func (m *Authority) SNPValidateOpts() ([]func(report *sevsnp.Report) (*validate.Options, error), error) {
state := m.state.Load()
if state == nil {
return nil, ErrNoManifest
}
mnfst := state.manifest

hostData := manifest.NewHexString(report.HostData)
if _, ok := mnfst.Policies[hostData]; !ok {
return nil, fmt.Errorf("hostdata %s not found in manifest", hostData)
validateOpts, err := mnfst.AKSValidateOpts()
if err != nil {
return nil, fmt.Errorf("get validate options: %w", err)
}

return mnfst.AKSValidateOpts()
var out []func(report *sevsnp.Report) (*validate.Options, error)
for _, vo := range validateOpts {
out = append(out, func(report *sevsnp.Report) (*validate.Options, error) {
hostData := manifest.NewHexString(report.HostData)
if _, ok := mnfst.Policies[hostData]; !ok {
return nil, fmt.Errorf("hostdata %s not found in manifest", hostData)
}

return vo, nil
})
}

return out, nil
}

// ValidateCallback creates a certificate bundle for the verified client.
Expand Down Expand Up @@ -264,3 +301,28 @@ type state struct {
ca *ca.CA
generation int
}

// ClientHandshake is a stub for accessing the currently stored aTLS credentials.
func (m *Authority) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return m.atlsCredentials.Load().ClientHandshake(ctx, authority, rawConn)
}

// ServerHandshake is a stub for accessing the currently stored aTLS credentials.
func (m *Authority) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return m.atlsCredentials.Load().ServerHandshake(rawConn)
}

// Info is a stub for accessing the currently stored aTLS credentials.
func (m *Authority) Info() credentials.ProtocolInfo {
return m.atlsCredentials.Load().Info()
}

// Clone is a stub for accessing the currently stored aTLS credentials.
func (m *Authority) Clone() credentials.TransportCredentials {
return m.atlsCredentials.Load().Clone()
}

// OverrideServerName is a stub for accessing the currently stored aTLS credentials.
func (m *Authority) OverrideServerName(serverName string) error {
return m.atlsCredentials.Load().OverrideServerName(serverName)
}
7 changes: 3 additions & 4 deletions coordinator/internal/authority/authority_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestSNPValidateOpts(t *testing.T) {
policyHash := sha256.Sum256(policies[0])
report := &sevsnp.Report{HostData: policyHash[:]}

opts, err := a.SNPValidateOpts(report)
opts, err := a.SNPValidateOpts()
require.Error(err)
require.Nil(opts)

Expand All @@ -51,16 +51,15 @@ func TestSNPValidateOpts(t *testing.T) {
_, err = a.SetManifest(context.Background(), req)
require.NoError(err)

opts, err = a.SNPValidateOpts(report)
opts, err = a.SNPValidateOpts()
require.NoError(err)
require.NotNil(opts)

// Change to unknown policy hash in HostData.
report.HostData[0]++

opts, err = a.SNPValidateOpts(report)
_, err = opts[0](report)
require.Error(err)
require.Nil(opts)
}

// TODO(burgerdev): test ValidateCallback and GetCertBundle
Expand Down
20 changes: 20 additions & 0 deletions coordinator/internal/authority/userapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@ import (
"slices"

"github.com/edgelesssys/contrast/coordinator/history"
"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/attestation/snp"
"github.com/edgelesssys/contrast/internal/ca"
"github.com/edgelesssys/contrast/internal/crypto"
"github.com/edgelesssys/contrast/internal/grpc/atlscredentials"
"github.com/edgelesssys/contrast/internal/logger"
"github.com/edgelesssys/contrast/internal/manifest"
"github.com/edgelesssys/contrast/internal/userapi"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -150,6 +154,22 @@ func (a *Authority) SetManifest(ctx context.Context, req *userapi.SetManifestReq
resp.RootCA = ca.GetRootCACert()
resp.MeshCA = ca.GetMeshCACert()

validateOpts, err := a.SNPValidateOpts()
if err != nil {
return nil, fmt.Errorf("getting SNP validate options: %w", err)
}

// create a validator for each [manifest.ReferenceValues] in the manifest.
var validators []atls.Validator
for _, opt := range validateOpts {
validator := snp.NewValidatorWithCallbacks(opt, a.kdsGetter,
logger.NewWithAttrs(logger.NewNamed(a.logger, "validator"), map[string]string{"tee-type": "snp"}),
a.attestationFailuresCounter, a)
validators = append(validators, validator)
}
a.atlsCredentials.Swap(atlscredentials.New(atls.NoIssuer, validators))
a.logger.Info("Swapped aTLS credentials", "validators", validators)

a.logger.Info("SetManifest succeeded")
return &resp, nil
}
Expand Down
2 changes: 1 addition & 1 deletion coordinator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func run() (retErr error) {

userapi.RegisterUserAPIServer(grpcServer, meshAuth)
serverMetrics.InitializeMetrics(grpcServer)
meshAPI := newMeshAPIServer(meshAuth, meshAuth, promRegistry, serverMetrics, logger)
meshAPI := newMeshAPIServer(meshAuth, meshAuth, serverMetrics, logger)
metricsServer := &http.Server{}

eg, ctx := errgroup.WithContext(ctx)
Expand Down
Loading

0 comments on commit a353ac8

Please sign in to comment.