diff --git a/cli/cmd/common.go b/cli/cmd/common.go index 86881a4d40..77bc3d347a 100644 --- a/cli/cmd/common.go +++ b/cli/cmd/common.go @@ -6,20 +6,10 @@ package cmd import ( "context" _ "embed" - "fmt" - "log/slog" "os" - "path/filepath" "time" "github.com/edgelesssys/contrast/cli/telemetry" - "github.com/edgelesssys/contrast/internal/atls" - "github.com/edgelesssys/contrast/internal/attestation/certcache" - "github.com/edgelesssys/contrast/internal/attestation/snp" - "github.com/edgelesssys/contrast/internal/attestation/tdx" - "github.com/edgelesssys/contrast/internal/fsstore" - "github.com/edgelesssys/contrast/internal/logger" - "github.com/edgelesssys/contrast/internal/manifest" "github.com/spf13/cobra" ) @@ -35,7 +25,6 @@ const ( rulesFilename = "rules.rego" layersCacheFilename = "layers-cache.json" verifyDir = "verify" - cacheDirEnv = "CONTRAST_CACHE_DIR" ) var ( @@ -48,18 +37,6 @@ var ( DefaultCoordinatorPolicyHash = "" ) -func cachedir(subdir string) (string, error) { - dir := os.Getenv(cacheDirEnv) - if dir == "" { - cachedir, err := os.UserCacheDir() - if err != nil { - return "", err - } - dir = filepath.Join(cachedir, "contrast") - } - return filepath.Join(dir, subdir), nil -} - func must(err error) { if err != nil { panic(err) @@ -81,40 +58,3 @@ func withTelemetry(runFunc func(*cobra.Command, []string) error) func(*cobra.Com return cmdErr } } - -// validatorsFromManifest returns a list of validators corresponding to the reference values in the given manifest. -func validatorsFromManifest(m *manifest.Manifest, log *slog.Logger, hostData []byte) ([]atls.Validator, error) { - kdsDir, err := cachedir("kds") - if err != nil { - return nil, fmt.Errorf("getting cache dir: %w", err) - } - log.Debug("Using KDS cache dir", "dir", kdsDir) - kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache")) - kdsGetter := certcache.NewCachedHTTPSGetter(kdsCache, certcache.NeverGCTicker, log.WithGroup("kds-getter")) - - var validators []atls.Validator - - opts, err := m.SNPValidateOpts(kdsGetter) - if err != nil { - return nil, fmt.Errorf("getting SNP validate options: %w", err) - } - for _, opt := range opts { - opt.ValidateOpts.HostData = hostData - validators = append(validators, snp.NewValidator(opt.VerifyOpts, opt.ValidateOpts, - logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}), - )) - } - - tdxOpts, err := m.TDXValidateOpts() - if err != nil { - return nil, fmt.Errorf("generating TDX validation options: %w", err) - } - var mrConfigID [48]byte - copy(mrConfigID[:], hostData) - for _, opt := range tdxOpts { - opt.TdQuoteBodyOptions.MrConfigID = mrConfigID[:] - validators = append(validators, tdx.NewValidator(&tdx.StaticValidateOptsGenerator{Opts: opt}, logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "tdx"}))) - } - - return validators, nil -} diff --git a/cli/cmd/recover.go b/cli/cmd/recover.go index d77e5e4d34..89741f86d9 100644 --- a/cli/cmd/recover.go +++ b/cli/cmd/recover.go @@ -14,6 +14,7 @@ import ( "github.com/edgelesssys/contrast/internal/grpc/dialer" "github.com/edgelesssys/contrast/internal/manifest" "github.com/edgelesssys/contrast/internal/userapi" + "github.com/edgelesssys/contrast/sdk" "github.com/spf13/cobra" ) @@ -73,7 +74,7 @@ func runRecover(cmd *cobra.Command, _ []string) error { return fmt.Errorf("decrypting seed: %w", err) } - validators, err := validatorsFromManifest(&m, log, flags.policy) + validators, err := sdk.ValidatorsFromManifest(&m, log, flags.policy) if err != nil { return fmt.Errorf("getting validators: %w", err) } diff --git a/cli/cmd/set.go b/cli/cmd/set.go index e07445c8d5..3e9eb0af77 100644 --- a/cli/cmd/set.go +++ b/cli/cmd/set.go @@ -24,6 +24,7 @@ import ( "github.com/edgelesssys/contrast/internal/retry" "github.com/edgelesssys/contrast/internal/spinner" "github.com/edgelesssys/contrast/internal/userapi" + "github.com/edgelesssys/contrast/sdk" "github.com/spf13/cobra" "github.com/spf13/pflag" "google.golang.org/grpc/codes" @@ -98,7 +99,7 @@ func runSet(cmd *cobra.Command, args []string) error { return fmt.Errorf("checking policies match manifest: %w", err) } - validators, err := validatorsFromManifest(&m, log, flags.policy) + validators, err := sdk.ValidatorsFromManifest(&m, log, flags.policy) if err != nil { return fmt.Errorf("getting validators: %w", err) } diff --git a/cli/cmd/verify.go b/cli/cmd/verify.go index 86f5753f93..680d4749d7 100644 --- a/cli/cmd/verify.go +++ b/cli/cmd/verify.go @@ -4,18 +4,13 @@ package cmd import ( - "bytes" "crypto/sha256" - "encoding/json" "fmt" - "net" "os" "path/filepath" - "github.com/edgelesssys/contrast/internal/atls" - "github.com/edgelesssys/contrast/internal/grpc/dialer" "github.com/edgelesssys/contrast/internal/manifest" - "github.com/edgelesssys/contrast/internal/userapi" + "github.com/edgelesssys/contrast/sdk" "github.com/spf13/cobra" ) @@ -60,33 +55,13 @@ func runVerify(cmd *cobra.Command, _ []string) error { if err != nil { return fmt.Errorf("failed to read manifest file: %w", err) } - var m manifest.Manifest - if err := json.Unmarshal(manifestBytes, &m); err != nil { - return fmt.Errorf("failed to unmarshal manifest: %w", err) - } - if err := m.Validate(); err != nil { - return fmt.Errorf("validating manifest: %w", err) - } - validators, err := validatorsFromManifest(&m, log, flags.policy) + sdkClient := sdk.New(log) + resp, err := sdkClient.GetManifests(cmd.Context(), manifestBytes, flags.coordinator, flags.policy) if err != nil { - return fmt.Errorf("getting validators: %w", err) + return fmt.Errorf("getting manifests: %w", err) } - dialer := dialer.New(atls.NoIssuer, validators, atls.NoMetrics, &net.Dialer{}) - log.Debug("Dialing coordinator", "endpoint", flags.coordinator) - conn, err := dialer.Dial(cmd.Context(), flags.coordinator) - if err != nil { - return fmt.Errorf("Error: failed to dial coordinator: %w", err) - } - defer conn.Close() - - log.Debug("Getting manifest") - client := userapi.NewUserAPIClient(conn) - resp, err := client.GetManifests(cmd.Context(), &userapi.GetManifestsRequest{}) - if err != nil { - return fmt.Errorf("failed to get manifest: %w", err) - } log.Debug("Got response") fmt.Fprintln(cmd.OutOrStdout(), "✔️ Successfully verified Coordinator CVM based on reference values from manifest") @@ -109,9 +84,8 @@ func runVerify(cmd *cobra.Command, _ []string) error { fmt.Fprintf(cmd.OutOrStdout(), "✔️ Wrote Coordinator configuration and keys to %s\n", filepath.Join(flags.workspaceDir, verifyDir)) - currentManifest := resp.Manifests[len(resp.Manifests)-1] - if !bytes.Equal(currentManifest, manifestBytes) { - return fmt.Errorf("manifest active at Coordinator does not match expected manifest") + if err := sdk.Verify(manifestBytes, resp.Manifests); err != nil { + return fmt.Errorf("failed to verify Coordinator manifest: %w", err) } fmt.Fprintln(cmd.OutOrStdout(), "✔️ Manifest active at Coordinator matches expected manifest") diff --git a/sdk/README.md b/sdk/README.md new file mode 100644 index 0000000000..e338cc8bc1 --- /dev/null +++ b/sdk/README.md @@ -0,0 +1,6 @@ +# Contrast SDK + +**Caution:** This SDK is still under active development and not fit for external use yet. +Please expect breaking changes with new minor versions. + +The SDK allows writing programs that interact with a Contrast deployment like the CLI does, without relying on the CLI. diff --git a/sdk/common.go b/sdk/common.go new file mode 100644 index 0000000000..1a6f8b1077 --- /dev/null +++ b/sdk/common.go @@ -0,0 +1,72 @@ +// Copyright 2024 Edgeless Systems GmbH +// SPDX-License-Identifier: AGPL-3.0-only + +package sdk + +import ( + "fmt" + "log/slog" + "os" + "path/filepath" + + "github.com/edgelesssys/contrast/internal/atls" + "github.com/edgelesssys/contrast/internal/attestation/certcache" + "github.com/edgelesssys/contrast/internal/attestation/snp" + "github.com/edgelesssys/contrast/internal/attestation/tdx" + "github.com/edgelesssys/contrast/internal/fsstore" + "github.com/edgelesssys/contrast/internal/logger" + "github.com/edgelesssys/contrast/internal/manifest" +) + +const cacheDirEnv = "CONTRAST_CACHE_DIR" + +// ValidatorsFromManifest returns a list of validators corresponding to the reference values in the given manifest. +// Originally an unexported function in the contrast CLI. +// Can be made unexported again, if we decide to move all userapi calls from the CLI to the SDK. +func ValidatorsFromManifest(m *manifest.Manifest, log *slog.Logger, hostData []byte) ([]atls.Validator, error) { + kdsDir, err := cachedir("kds") + if err != nil { + return nil, fmt.Errorf("getting cache dir: %w", err) + } + log.Debug("Using KDS cache dir", "dir", kdsDir) + kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache")) + kdsGetter := certcache.NewCachedHTTPSGetter(kdsCache, certcache.NeverGCTicker, log.WithGroup("kds-getter")) + + var validators []atls.Validator + + opts, err := m.SNPValidateOpts(kdsGetter) + if err != nil { + return nil, fmt.Errorf("getting SNP validate options: %w", err) + } + for _, opt := range opts { + opt.ValidateOpts.HostData = hostData + validators = append(validators, snp.NewValidator(opt.VerifyOpts, opt.ValidateOpts, + logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}), + )) + } + + tdxOpts, err := m.TDXValidateOpts() + if err != nil { + return nil, fmt.Errorf("generating TDX validation options: %w", err) + } + var mrConfigID [48]byte + copy(mrConfigID[:], hostData) + for _, opt := range tdxOpts { + opt.TdQuoteBodyOptions.MrConfigID = mrConfigID[:] + validators = append(validators, tdx.NewValidator(&tdx.StaticValidateOptsGenerator{Opts: opt}, logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "tdx"}))) + } + + return validators, nil +} + +func cachedir(subdir string) (string, error) { + dir := os.Getenv(cacheDirEnv) + if dir == "" { + cachedir, err := os.UserCacheDir() + if err != nil { + return "", err + } + dir = filepath.Join(cachedir, "contrast") + } + return filepath.Join(dir, subdir), nil +} diff --git a/sdk/sdk.go b/sdk/sdk.go new file mode 100644 index 0000000000..7965abc7a4 --- /dev/null +++ b/sdk/sdk.go @@ -0,0 +1,4 @@ +// Copyright 2024 Edgeless Systems GmbH +// SPDX-License-Identifier: AGPL-3.0-only + +package sdk diff --git a/sdk/verify.go b/sdk/verify.go new file mode 100644 index 0000000000..a004109240 --- /dev/null +++ b/sdk/verify.go @@ -0,0 +1,90 @@ +// Copyright 2024 Edgeless Systems GmbH +// SPDX-License-Identifier: AGPL-3.0-only + +package sdk + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "log/slog" + "net" + + "github.com/edgelesssys/contrast/internal/atls" + "github.com/edgelesssys/contrast/internal/grpc/dialer" + "github.com/edgelesssys/contrast/internal/manifest" + "github.com/edgelesssys/contrast/internal/userapi" +) + +// Client is used to interact with a Contrast deployment. +type Client struct { + log *slog.Logger +} + +// New returns a Client. +func New(log *slog.Logger) Client { + return Client{ + log: log, + } +} + +// Verify checks if a given manifest is the latest manifest in the given history. +func Verify(expected []byte, history [][]byte) error { + currentManifest := history[len(history)-1] + if !bytes.Equal(currentManifest, expected) { + return fmt.Errorf("active manifest does not match expected manifest") + } + + return nil +} + +// GetManifests calls GetManifests on the coordinator's userapi. +func (c Client) GetManifests(ctx context.Context, manifestBytes []byte, endpoint string, policyHash []byte) (GetManifestsResponse, error) { + var m manifest.Manifest + if err := json.Unmarshal(manifestBytes, &m); err != nil { + return GetManifestsResponse{}, fmt.Errorf("unmarshalling manifest: %w", err) + } + if err := m.Validate(); err != nil { + return GetManifestsResponse{}, fmt.Errorf("validating manifest: %w", err) + } + + validators, err := ValidatorsFromManifest(&m, c.log, policyHash) + if err != nil { + return GetManifestsResponse{}, fmt.Errorf("getting validators: %w", err) + } + dialer := dialer.New(atls.NoIssuer, validators, atls.NoMetrics, &net.Dialer{}) + + c.log.Debug("Dialing coordinator", "endpoint", endpoint) + + conn, err := dialer.Dial(ctx, endpoint) + if err != nil { + return GetManifestsResponse{}, fmt.Errorf("dialing coordinator: %w", err) + } + defer conn.Close() + + c.log.Debug("Getting manifest") + + client := userapi.NewUserAPIClient(conn) + resp, err := client.GetManifests(ctx, &userapi.GetManifestsRequest{}) + if err != nil { + return GetManifestsResponse{}, fmt.Errorf("getting manifests: %w", err) + } + + return GetManifestsResponse{ + Manifests: resp.Manifests, + Policies: resp.Policies, + RootCA: resp.RootCA, + MeshCA: resp.MeshCA, + }, nil +} + +// GetManifestsResponse contains the Coordinator's response to a GetManifests call. +type GetManifestsResponse struct { + Manifests [][]byte + Policies [][]byte + // PEM-encoded certificate + RootCA []byte + // PEM-encoded certificate + MeshCA []byte +}