Skip to content

Commit

Permalink
sdk: move verify code into sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
derpsteb committed Nov 25, 2024
1 parent 6f2caf3 commit 1603bb5
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 94 deletions.
60 changes: 0 additions & 60 deletions cli/cmd/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,10 @@ package cmd
import (
"context"
_ "embed"
"fmt"
"log/slog"
"os"
"path/filepath"
"time"

"github.com/edgelesssys/contrast/cli/telemetry"
"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/attestation/certcache"
"github.com/edgelesssys/contrast/internal/attestation/snp"
"github.com/edgelesssys/contrast/internal/attestation/tdx"
"github.com/edgelesssys/contrast/internal/fsstore"
"github.com/edgelesssys/contrast/internal/logger"
"github.com/edgelesssys/contrast/internal/manifest"
"github.com/spf13/cobra"
)

Expand All @@ -35,7 +25,6 @@ const (
rulesFilename = "rules.rego"
layersCacheFilename = "layers-cache.json"
verifyDir = "verify"
cacheDirEnv = "CONTRAST_CACHE_DIR"
)

var (
Expand All @@ -48,18 +37,6 @@ var (
DefaultCoordinatorPolicyHash = ""
)

func cachedir(subdir string) (string, error) {
dir := os.Getenv(cacheDirEnv)
if dir == "" {
cachedir, err := os.UserCacheDir()
if err != nil {
return "", err
}
dir = filepath.Join(cachedir, "contrast")
}
return filepath.Join(dir, subdir), nil
}

func must(err error) {
if err != nil {
panic(err)
Expand All @@ -81,40 +58,3 @@ func withTelemetry(runFunc func(*cobra.Command, []string) error) func(*cobra.Com
return cmdErr
}
}

// validatorsFromManifest returns a list of validators corresponding to the reference values in the given manifest.
func validatorsFromManifest(m *manifest.Manifest, log *slog.Logger, hostData []byte) ([]atls.Validator, error) {
kdsDir, err := cachedir("kds")
if err != nil {
return nil, fmt.Errorf("getting cache dir: %w", err)
}
log.Debug("Using KDS cache dir", "dir", kdsDir)
kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache"))
kdsGetter := certcache.NewCachedHTTPSGetter(kdsCache, certcache.NeverGCTicker, log.WithGroup("kds-getter"))

var validators []atls.Validator

opts, err := m.SNPValidateOpts(kdsGetter)
if err != nil {
return nil, fmt.Errorf("getting SNP validate options: %w", err)
}
for _, opt := range opts {
opt.ValidateOpts.HostData = hostData
validators = append(validators, snp.NewValidator(opt.VerifyOpts, opt.ValidateOpts,
logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}),
))
}

tdxOpts, err := m.TDXValidateOpts()
if err != nil {
return nil, fmt.Errorf("generating TDX validation options: %w", err)
}
var mrConfigID [48]byte
copy(mrConfigID[:], hostData)
for _, opt := range tdxOpts {
opt.TdQuoteBodyOptions.MrConfigID = mrConfigID[:]
validators = append(validators, tdx.NewValidator(&tdx.StaticValidateOptsGenerator{Opts: opt}, logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "tdx"})))
}

return validators, nil
}
3 changes: 2 additions & 1 deletion cli/cmd/recover.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/edgelesssys/contrast/internal/grpc/dialer"
"github.com/edgelesssys/contrast/internal/manifest"
"github.com/edgelesssys/contrast/internal/userapi"
"github.com/edgelesssys/contrast/sdk"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -73,7 +74,7 @@ func runRecover(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("decrypting seed: %w", err)
}

validators, err := validatorsFromManifest(&m, log, flags.policy)
validators, err := sdk.ValidatorsFromManifest(&m, log, flags.policy)
if err != nil {
return fmt.Errorf("getting validators: %w", err)
}
Expand Down
3 changes: 2 additions & 1 deletion cli/cmd/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/edgelesssys/contrast/internal/retry"
"github.com/edgelesssys/contrast/internal/spinner"
"github.com/edgelesssys/contrast/internal/userapi"
"github.com/edgelesssys/contrast/sdk"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -98,7 +99,7 @@ func runSet(cmd *cobra.Command, args []string) error {
return fmt.Errorf("checking policies match manifest: %w", err)
}

validators, err := validatorsFromManifest(&m, log, flags.policy)
validators, err := sdk.ValidatorsFromManifest(&m, log, flags.policy)
if err != nil {
return fmt.Errorf("getting validators: %w", err)
}
Expand Down
38 changes: 6 additions & 32 deletions cli/cmd/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,13 @@
package cmd

import (
"bytes"
"crypto/sha256"
"encoding/json"
"fmt"
"net"
"os"
"path/filepath"

"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/grpc/dialer"
"github.com/edgelesssys/contrast/internal/manifest"
"github.com/edgelesssys/contrast/internal/userapi"
"github.com/edgelesssys/contrast/sdk"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -60,33 +55,13 @@ func runVerify(cmd *cobra.Command, _ []string) error {
if err != nil {
return fmt.Errorf("failed to read manifest file: %w", err)
}
var m manifest.Manifest
if err := json.Unmarshal(manifestBytes, &m); err != nil {
return fmt.Errorf("failed to unmarshal manifest: %w", err)
}
if err := m.Validate(); err != nil {
return fmt.Errorf("validating manifest: %w", err)
}

validators, err := validatorsFromManifest(&m, log, flags.policy)
sdkClient := sdk.New(log)
resp, err := sdkClient.GetManifests(cmd.Context(), manifestBytes, flags.coordinator, flags.policy)
if err != nil {
return fmt.Errorf("getting validators: %w", err)
return fmt.Errorf("getting manifests: %w", err)
}
dialer := dialer.New(atls.NoIssuer, validators, atls.NoMetrics, &net.Dialer{})

log.Debug("Dialing coordinator", "endpoint", flags.coordinator)
conn, err := dialer.Dial(cmd.Context(), flags.coordinator)
if err != nil {
return fmt.Errorf("Error: failed to dial coordinator: %w", err)
}
defer conn.Close()

log.Debug("Getting manifest")
client := userapi.NewUserAPIClient(conn)
resp, err := client.GetManifests(cmd.Context(), &userapi.GetManifestsRequest{})
if err != nil {
return fmt.Errorf("failed to get manifest: %w", err)
}
log.Debug("Got response")

fmt.Fprintln(cmd.OutOrStdout(), "✔️ Successfully verified Coordinator CVM based on reference values from manifest")
Expand All @@ -109,9 +84,8 @@ func runVerify(cmd *cobra.Command, _ []string) error {

fmt.Fprintf(cmd.OutOrStdout(), "✔️ Wrote Coordinator configuration and keys to %s\n", filepath.Join(flags.workspaceDir, verifyDir))

currentManifest := resp.Manifests[len(resp.Manifests)-1]
if !bytes.Equal(currentManifest, manifestBytes) {
return fmt.Errorf("manifest active at Coordinator does not match expected manifest")
if err := sdk.Verify(manifestBytes, resp.Manifests); err != nil {
return fmt.Errorf("failed to verify Coordinator manifest: %w", err)
}

fmt.Fprintln(cmd.OutOrStdout(), "✔️ Manifest active at Coordinator matches expected manifest")
Expand Down
6 changes: 6 additions & 0 deletions sdk/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Contrast SDK

**Caution:** This SDK is still under active development and not fit for external use yet.
Please expect breaking changes with new minor versions.

The SDK allows writing programs that interact with a Contrast deployment like the CLI does, without relying on the CLI.
72 changes: 72 additions & 0 deletions sdk/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright 2024 Edgeless Systems GmbH
// SPDX-License-Identifier: AGPL-3.0-only

package sdk

import (
"fmt"
"log/slog"
"os"
"path/filepath"

"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/attestation/certcache"
"github.com/edgelesssys/contrast/internal/attestation/snp"
"github.com/edgelesssys/contrast/internal/attestation/tdx"
"github.com/edgelesssys/contrast/internal/fsstore"
"github.com/edgelesssys/contrast/internal/logger"
"github.com/edgelesssys/contrast/internal/manifest"
)

const cacheDirEnv = "CONTRAST_CACHE_DIR"

// ValidatorsFromManifest returns a list of validators corresponding to the reference values in the given manifest.
// Originally an unexported function in the contrast CLI.
// Can be made unexported again, if we decide to move all userapi calls from the CLI to the SDK.
func ValidatorsFromManifest(m *manifest.Manifest, log *slog.Logger, hostData []byte) ([]atls.Validator, error) {
kdsDir, err := cachedir("kds")
if err != nil {
return nil, fmt.Errorf("getting cache dir: %w", err)
}
log.Debug("Using KDS cache dir", "dir", kdsDir)
kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache"))
kdsGetter := certcache.NewCachedHTTPSGetter(kdsCache, certcache.NeverGCTicker, log.WithGroup("kds-getter"))

var validators []atls.Validator

opts, err := m.SNPValidateOpts(kdsGetter)
if err != nil {
return nil, fmt.Errorf("getting SNP validate options: %w", err)
}
for _, opt := range opts {
opt.ValidateOpts.HostData = hostData
validators = append(validators, snp.NewValidator(opt.VerifyOpts, opt.ValidateOpts,
logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}),
))
}

tdxOpts, err := m.TDXValidateOpts()
if err != nil {
return nil, fmt.Errorf("generating TDX validation options: %w", err)
}
var mrConfigID [48]byte
copy(mrConfigID[:], hostData)
for _, opt := range tdxOpts {
opt.TdQuoteBodyOptions.MrConfigID = mrConfigID[:]
validators = append(validators, tdx.NewValidator(&tdx.StaticValidateOptsGenerator{Opts: opt}, logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "tdx"})))
}

return validators, nil
}

func cachedir(subdir string) (string, error) {
dir := os.Getenv(cacheDirEnv)
if dir == "" {
cachedir, err := os.UserCacheDir()
if err != nil {
return "", err
}
dir = filepath.Join(cachedir, "contrast")
}
return filepath.Join(dir, subdir), nil
}
4 changes: 4 additions & 0 deletions sdk/sdk.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// Copyright 2024 Edgeless Systems GmbH
// SPDX-License-Identifier: AGPL-3.0-only

package sdk
90 changes: 90 additions & 0 deletions sdk/verify.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright 2024 Edgeless Systems GmbH
// SPDX-License-Identifier: AGPL-3.0-only

package sdk

import (
"bytes"
"context"
"encoding/json"
"fmt"
"log/slog"
"net"

"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/grpc/dialer"
"github.com/edgelesssys/contrast/internal/manifest"
"github.com/edgelesssys/contrast/internal/userapi"
)

// Client is used to interact with a Contrast deployment.
type Client struct {
log *slog.Logger
}

// New returns a Client.
func New(log *slog.Logger) Client {
return Client{
log: log,
}
}

// Verify checks if a given manifest is the latest manifest in the given history.
func Verify(expected []byte, history [][]byte) error {
currentManifest := history[len(history)-1]
if !bytes.Equal(currentManifest, expected) {
return fmt.Errorf("active manifest does not match expected manifest")
}

return nil
}

// GetManifests calls GetManifests on the coordinator's userapi.
func (c Client) GetManifests(ctx context.Context, manifestBytes []byte, endpoint string, policyHash []byte) (GetManifestsResponse, error) {
var m manifest.Manifest
if err := json.Unmarshal(manifestBytes, &m); err != nil {
return GetManifestsResponse{}, fmt.Errorf("unmarshalling manifest: %w", err)
}
if err := m.Validate(); err != nil {
return GetManifestsResponse{}, fmt.Errorf("validating manifest: %w", err)
}

validators, err := ValidatorsFromManifest(&m, c.log, policyHash)
if err != nil {
return GetManifestsResponse{}, fmt.Errorf("getting validators: %w", err)
}
dialer := dialer.New(atls.NoIssuer, validators, atls.NoMetrics, &net.Dialer{})

c.log.Debug("Dialing coordinator", "endpoint", endpoint)

conn, err := dialer.Dial(ctx, endpoint)
if err != nil {
return GetManifestsResponse{}, fmt.Errorf("dialing coordinator: %w", err)
}
defer conn.Close()

c.log.Debug("Getting manifest")

client := userapi.NewUserAPIClient(conn)
resp, err := client.GetManifests(ctx, &userapi.GetManifestsRequest{})
if err != nil {
return GetManifestsResponse{}, fmt.Errorf("getting manifests: %w", err)
}

return GetManifestsResponse{
Manifests: resp.Manifests,
Policies: resp.Policies,
RootCA: resp.RootCA,
MeshCA: resp.MeshCA,
}, nil
}

// GetManifestsResponse contains the Coordinator's response to a GetManifests call.
type GetManifestsResponse struct {
Manifests [][]byte
Policies [][]byte
// PEM-encoded certificate
RootCA []byte
// PEM-encoded certificate
MeshCA []byte
}

0 comments on commit 1603bb5

Please sign in to comment.