Skip to content

Commit

Permalink
cli: embed multiple reference values
Browse files Browse the repository at this point in the history
This adds support for embedding a more versatile format of reference values (i.e. a structured type) into the Contrast binaries. This will allow us to embed all reference values at build-time from a single source (the Nix build file) rather than having SVNs in Go code and inserting trusted measurements via the go build commandline. It will now embed a JSON file containing the reference values, which is unmarshaled at first default manifest generation.
  • Loading branch information
msanft committed Jul 23, 2024
1 parent badd3a2 commit 39dd4d8
Show file tree
Hide file tree
Showing 13 changed files with 410 additions and 288 deletions.
82 changes: 42 additions & 40 deletions cli/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ subcommands.`,
func runGenerate(cmd *cobra.Command, args []string) error {
flags, err := parseGenerateFlags(cmd)
if err != nil {
return fmt.Errorf("failed to parse flags: %w", err)
return fmt.Errorf("parse flags: %w", err)
}

log, err := newCLILogger(cmd)
Expand All @@ -101,23 +101,44 @@ func runGenerate(cmd *cobra.Command, args []string) error {
return err
}

if err := patchTargets(paths, flags.imageReplacementsFile, flags.skipInitializer, log); err != nil {
return fmt.Errorf("failed to patch targets: %w", err)
// generate manifest
defaultManifest := manifest.Default(flags.referenceValuesPlatform)

defaultManifestData, err := json.MarshalIndent(&defaultManifest, "", " ")
if err != nil {
return fmt.Errorf("marshaling default manifest: %w", err)
}
manifestData, err := readFileOrDefault(flags.manifestPath, defaultManifestData)
if err != nil {
return fmt.Errorf("read manifest file: %w", err)
}
var manifest *manifest.Manifest
if err := json.Unmarshal(manifestData, &manifest); err != nil {
return fmt.Errorf("unmarshal manifest: %w", err)
}

runtimeHandler, err := manifest.RuntimeHandler(flags.referenceValuesPlatform)
if err != nil {
return fmt.Errorf("get runtime handler: %w", err)
}

if err := patchTargets(paths, flags.imageReplacementsFile, runtimeHandler, flags.skipInitializer, log); err != nil {
return fmt.Errorf("patch targets: %w", err)
}
fmt.Fprintln(cmd.OutOrStdout(), "✔️ Patched targets")

if err := generatePolicies(cmd.Context(), flags.policyPath, flags.settingsPath, flags.genpolicyCachePath, paths, log); err != nil {
return fmt.Errorf("failed to generate policies: %w", err)
return fmt.Errorf("generate policies: %w", err)
}
fmt.Fprintln(cmd.OutOrStdout(), "✔️ Generated workload policy annotations")

policies, err := policiesFromKubeResources(paths)
if err != nil {
return fmt.Errorf("failed to find kube resources with policy: %w", err)
return fmt.Errorf("find kube resources with policy: %w", err)
}
policyMap, err := manifestPolicyMapFromPolicies(policies)
if err != nil {
return fmt.Errorf("failed to create policy map: %w", err)
return fmt.Errorf("create policy map: %w", err)
}

if err := generateWorkloadOwnerKey(flags); err != nil {
Expand All @@ -127,24 +148,6 @@ func runGenerate(cmd *cobra.Command, args []string) error {
return fmt.Errorf("generating seedshare owner key: %w", err)
}

defaultManifest := manifest.Default()
switch flags.referenceValuesPlatform {
case platforms.AKSCloudHypervisorSNP:
defaultManifest = manifest.DefaultAKS()
}

defaultManifestData, err := json.MarshalIndent(&defaultManifest, "", " ")
if err != nil {
return fmt.Errorf("marshaling default manifest: %w", err)
}
manifestData, err := readFileOrDefault(flags.manifestPath, defaultManifestData)
if err != nil {
return fmt.Errorf("failed to read manifest file: %w", err)
}
var manifest *manifest.Manifest
if err := json.Unmarshal(manifestData, &manifest); err != nil {
return fmt.Errorf("failed to unmarshal manifest: %w", err)
}
manifest.Policies = policyMap
if err := manifest.Validate(); err != nil {
return fmt.Errorf("validating manifest: %w", err)
Expand All @@ -170,18 +173,18 @@ func runGenerate(cmd *cobra.Command, args []string) error {

manifestData, err = json.MarshalIndent(manifest, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal manifest: %w", err)
return fmt.Errorf("marshal manifest: %w", err)
}
if err := os.WriteFile(flags.manifestPath, append(manifestData, '\n'), 0o644); err != nil {
return fmt.Errorf("failed to write manifest: %w", err)
return fmt.Errorf("write manifest: %w", err)
}

fmt.Fprintf(cmd.OutOrStdout(), "✔️ Updated manifest %s\n", flags.manifestPath)

if hash := getCoordinatorPolicyHash(policies, log); hash != "" {
coordHashPath := filepath.Join(flags.workspaceDir, coordHashFilename)
if err := os.WriteFile(coordHashPath, []byte(hash), 0o644); err != nil {
return fmt.Errorf("failed to write coordinator policy hash: %w", err)
return fmt.Errorf("write coordinator policy hash: %w", err)
}
}

Expand All @@ -207,7 +210,7 @@ func findGenerateTargets(args []string, logger *slog.Logger) ([]string, error) {
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to walk %s: %w", path, err)
return nil, fmt.Errorf("walk %s: %w", path, err)
}
}
if len(paths) == 0 {
Expand All @@ -227,7 +230,7 @@ func filterNonCoCoRuntime(runtimeClassNamePrefix string, paths []string, logger
for _, path := range paths {
data, err := os.ReadFile(path)
if err != nil {
logger.Warn("Failed to read file", "path", path, "err", err)
logger.Warn("read file", "path", path, "err", err)
continue
}
if !bytes.Contains(data, []byte(runtimeClassNamePrefix)) {
Expand All @@ -248,21 +251,21 @@ func generatePolicies(ctx context.Context, regoRulesPath, policySettingsPath, ge
}
binaryInstallDir, err := installDir()
if err != nil {
return fmt.Errorf("failed to get install dir: %w", err)
return fmt.Errorf("get install dir: %w", err)
}
genpolicyInstall, err := embedbin.New().Install(binaryInstallDir, genpolicyBin)
if err != nil {
return fmt.Errorf("failed to install genpolicy: %w", err)
return fmt.Errorf("install genpolicy: %w", err)
}
defer func() {
if err := genpolicyInstall.Uninstall(); err != nil {
logger.Warn("Failed to uninstall genpolicy tool", "err", err)
logger.Warn("uninstall genpolicy tool", "err", err)
}
}()
for _, yamlPath := range yamlPaths {
policyHash, err := generatePolicyForFile(ctx, genpolicyInstall.Path(), regoRulesPath, policySettingsPath, yamlPath, genpolicyCachePath, logger)
if err != nil {
return fmt.Errorf("failed to generate policy for %s: %w", yamlPath, err)
return fmt.Errorf("generate policy for %s: %w", yamlPath, err)
}
if policyHash == [32]byte{} {
continue
Expand All @@ -273,7 +276,7 @@ func generatePolicies(ctx context.Context, regoRulesPath, policySettingsPath, ge
return nil
}

func patchTargets(paths []string, imageReplacementsFile string, skipInitializer bool, logger *slog.Logger) error {
func patchTargets(paths []string, imageReplacementsFile, runtimeHandler string, skipInitializer bool, logger *slog.Logger) error {
var replacements map[string]string
var err error
if imageReplacementsFile != "" {
Expand All @@ -296,11 +299,11 @@ func patchTargets(paths []string, imageReplacementsFile string, skipInitializer
for _, path := range paths {
data, err := os.ReadFile(path)
if err != nil {
return fmt.Errorf("failed to read %s: %w", path, err)
return fmt.Errorf("read %s: %w", path, err)
}
kubeObjs, err := kuberesource.UnmarshalApplyConfigurations(data)
if err != nil {
return fmt.Errorf("failed to unmarshal %s: %w", path, err)
return fmt.Errorf("unmarshal %s: %w", path, err)
}

if !skipInitializer {
Expand All @@ -314,7 +317,7 @@ func patchTargets(paths []string, imageReplacementsFile string, skipInitializer

kubeObjs = kuberesource.PatchImages(kubeObjs, replacements)

replaceRuntimeClassName := runtimeClassNamePatcher()
replaceRuntimeClassName := runtimeClassNamePatcher(runtimeHandler)
for i := range kubeObjs {
kubeObjs[i] = kuberesource.MapPodSpec(kubeObjs[i], replaceRuntimeClassName)
}
Expand All @@ -325,7 +328,7 @@ func patchTargets(paths []string, imageReplacementsFile string, skipInitializer
return err
}
if err := os.WriteFile(path, resource, os.ModePerm); err != nil {
return fmt.Errorf("failed to write %s: %w", path, err)
return fmt.Errorf("write %s: %w", path, err)
}
}
return nil
Expand Down Expand Up @@ -362,8 +365,7 @@ func injectServiceMesh(resources []any) error {
return nil
}

func runtimeClassNamePatcher() func(*applycorev1.PodSpecApplyConfiguration) *applycorev1.PodSpecApplyConfiguration {
handler := runtimeHandler(manifest.TrustedMeasurement)
func runtimeClassNamePatcher(handler string) func(*applycorev1.PodSpecApplyConfiguration) *applycorev1.PodSpecApplyConfiguration {
return func(spec *applycorev1.PodSpecApplyConfiguration) *applycorev1.PodSpecApplyConfiguration {
if spec.RuntimeClassName == nil || *spec.RuntimeClassName == handler {
return spec
Expand Down
2 changes: 1 addition & 1 deletion cli/cmd/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func parseVerifyFlags(cmd *cobra.Command) (*verifyFlags, error) {
}

func newCoordinatorValidateOptsGen(mnfst manifest.Manifest, hostData []byte) (*snp.StaticValidateOptsGenerator, error) {
validateOpts, err := mnfst.SNPValidateOpts()
validateOpts, err := mnfst.AKSValidateOpts()
if err != nil {
return nil, err
}
Expand Down
3 changes: 1 addition & 2 deletions cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ func buildVersionString() string {
}
fmt.Fprint(versionsWriter, "\n")
fmt.Fprintf(versionsWriter, "reference values for %s platform:\n", platforms.AKSCloudHypervisorSNP.String())
fmt.Fprintf(versionsWriter, "\truntime handler:\tcontrast-cc-%s\n", manifest.TrustedMeasurement[:32])
fmt.Fprintf(versionsWriter, "\tlaunch digest:\t%s\n", manifest.TrustedMeasurement)
fmt.Fprintf(versionsWriter, "\tembedded reference values:\t%s\n", manifest.EmbeddedReferenceValuesJSON)
fmt.Fprintf(versionsWriter, "\tgenpolicy version:\t%s\n", genpolicyVersion)
versionsWriter.Flush()
return versionsBuilder.String()
Expand Down
2 changes: 1 addition & 1 deletion coordinator/internal/authority/authority.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (m *Authority) SNPValidateOpts(report *sevsnp.Report) (*validate.Options, e
return nil, fmt.Errorf("hostdata %s not found in manifest", hostData)
}

return mnfst.SNPValidateOpts()
return mnfst.AKSValidateOpts()
}

// ValidateCallback creates a certificate bundle for the verified client.
Expand Down
3 changes: 2 additions & 1 deletion coordinator/internal/authority/authority_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/edgelesssys/contrast/coordinator/history"
"github.com/edgelesssys/contrast/internal/manifest"
"github.com/edgelesssys/contrast/internal/userapi"
"github.com/edgelesssys/contrast/node-installer/platforms"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/testutil"
Expand Down Expand Up @@ -79,7 +80,7 @@ func newManifest(t *testing.T) (*manifest.Manifest, []byte, [][]byte) {
policyHash := sha256.Sum256(policy)
policyHashHex := manifest.NewHexString(policyHash[:])

mnfst := manifest.DefaultAKS()
mnfst := manifest.Default(platforms.AKSCloudHypervisorSNP)
mnfst.Policies = map[manifest.HexString][]string{policyHashHex: {"test"}}
mnfst.WorkloadOwnerKeyDigests = []manifest.HexString{keyDigest}
mnfstBytes, err := json.Marshal(mnfst)
Expand Down
9 changes: 5 additions & 4 deletions coordinator/internal/authority/userapi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/edgelesssys/contrast/coordinator/history"
"github.com/edgelesssys/contrast/internal/manifest"
"github.com/edgelesssys/contrast/internal/userapi"
"github.com/edgelesssys/contrast/node-installer/platforms"
"github.com/prometheus/client_golang/prometheus"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
Expand All @@ -33,7 +34,7 @@ import (

func TestManifestSet(t *testing.T) {
newBaseManifest := func() manifest.Manifest {
return manifest.Default()
return manifest.Default(platforms.AKSCloudHypervisorSNP)
}
newManifestBytes := func(f func(*manifest.Manifest)) []byte {
m := newBaseManifest()
Expand Down Expand Up @@ -221,7 +222,7 @@ func TestGetManifests(t *testing.T) {
require.Equal(codes.FailedPrecondition, status.Code(err))
assert.Nil(resp)

m := manifest.Default()
m := manifest.Default(platforms.AKSCloudHypervisorSNP)
m.Policies = map[manifest.HexString][]string{
manifest.HexString("ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb"): {"a1", "a2"},
manifest.HexString("3e23e8160039594a33894f6564e1b1348bbd7a0088d42c4acb73eeaed59c009d"): {"b1", "b2"},
Expand Down Expand Up @@ -374,7 +375,7 @@ func TestRecoveryFlow(t *testing.T) {
// gRPCs of the server.
func TestUserAPIConcurrent(t *testing.T) {
newBaseManifest := func() manifest.Manifest {
return manifest.Default()
return manifest.Default(platforms.AKSCloudHypervisorSNP)
}
newManifestBytes := func(f func(*manifest.Manifest)) []byte {
m := newBaseManifest()
Expand Down Expand Up @@ -459,7 +460,7 @@ func rpcContext(key *ecdsa.PrivateKey) context.Context {
}

func manifestWithWorkloadOwnerKey(key *ecdsa.PrivateKey) (*manifest.Manifest, error) {
m := manifest.Default()
m := manifest.Default(platforms.AKSCloudHypervisorSNP)
if key == nil {
return &m, nil
}
Expand Down
1 change: 1 addition & 0 deletions internal/manifest/assets/reference-values.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"THIS FILE IS REPLACED DURING BUILD AND ONLY HERE TO SATISFY GO TOOLING"
45 changes: 22 additions & 23 deletions internal/manifest/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,31 @@

package manifest

// TrustedMeasurement contains the expected launch digest and is injected at build time.
var TrustedMeasurement = "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
import (
"encoding/json"
"fmt"
"os"

// Default returns a default manifest.
func Default() Manifest {
return Manifest{
ReferenceValues: ReferenceValues{
TrustedMeasurement: HexString(TrustedMeasurement),
},
"github.com/edgelesssys/contrast/node-installer/platforms"
)

// Default returns a default manifest with reference values for the given platform.
func Default(platform platforms.Platform) Manifest {
if embeddedReferenceValues == nil {
// If we're here, this is the first time this function is called, and the global state is not
// yet initialized. So let's unmarshal the embedded reference values.
if err := json.Unmarshal(EmbeddedReferenceValuesJSON, &embeddedReferenceValues); err != nil {
fmt.Printf("Failed to unmarshal embedded reference values: %s\n", err)
os.Exit(1)
}
}
}

// DefaultAKS returns a default manifest with AKS reference values.
func DefaultAKS() Manifest {
mnfst := Default()
mnfst.ReferenceValues.SNP = SNPReferenceValues{
MinimumTCB: SNPTCB{
BootloaderVersion: toPtr(SVN(3)),
TEEVersion: toPtr(SVN(0)),
SNPVersion: toPtr(SVN(8)),
MicrocodeVersion: toPtr(SVN(115)),
},
mnfst := Manifest{}
switch platform {
case platforms.AKSCloudHypervisorSNP:
mnfst.ReferenceValues.AKS = embeddedReferenceValues.AKS
case platforms.RKE2QEMUTDX, platforms.K3sQEMUTDX:
mnfst.ReferenceValues.BareMetalTDX = embeddedReferenceValues.BareMetalTDX
}
return mnfst
}

func toPtr[T any](t T) *T {
return &t
}
Loading

0 comments on commit 39dd4d8

Please sign in to comment.