diff --git a/.gitignore b/.gitignore index 81ce3726ae..219c4e13d8 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ layers_cache layers-cache.json mesh-root.pem coordinator-root.pem +workload-owner.pem justfile.env workspace workspace.cache diff --git a/cli/constants.go b/cli/constants.go index 1c82909b9c..62ec8ddf78 100644 --- a/cli/constants.go +++ b/cli/constants.go @@ -9,6 +9,7 @@ import ( const ( coordRootPEMFilename = "coordinator-root.pem" coordIntermPEMFilename = "mesh-root.pem" + workloadOwnerPEM = "workload-owner.pem" manifestFilename = "manifest.json" settingsFilename = "settings.json" rulesFilename = "rules.rego" diff --git a/cli/generate.go b/cli/generate.go index 0f9627fccd..4bd3e622c2 100644 --- a/cli/generate.go +++ b/cli/generate.go @@ -3,15 +3,21 @@ package main import ( "bytes" "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" "crypto/sha256" + "crypto/x509" "encoding/hex" "encoding/json" + "encoding/pem" "errors" "fmt" "log/slog" "os" "os/exec" "path/filepath" + "slices" "strings" "github.com/edgelesssys/nunki/internal/embedbin" @@ -49,7 +55,12 @@ func newGenerateCmd() *cobra.Command { cmd.Flags().StringP("policy", "p", policyDir, "path to policy (.rego) file") cmd.Flags().StringP("settings", "s", settingsFilename, "path to settings (.json) file") cmd.Flags().StringP("manifest", "m", manifestFilename, "path to manifest (.json) file") - + cmd.Flags().StringArrayP("workload-owner-key", "w", []string{workloadOwnerPEM}, "path to workload owner key (.pem) file") + cmd.Flags().BoolP("disable-updates", "d", false, "prevent further updates of the manifest") + must(cmd.MarkFlagFilename("policy", "rego")) + must(cmd.MarkFlagFilename("settings", "json")) + must(cmd.MarkFlagFilename("manifest", "json")) + cmd.MarkFlagsMutuallyExclusive("workload-owner-key", "disable-updates") return cmd } @@ -82,6 +93,10 @@ func runGenerate(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to create policy map: %w", err) } + if err := generateWorkloadOwnerKey(flags); err != nil { + return fmt.Errorf("generating workload owner key: %w", err) + } + defaultManifest := manifest.Default() defaultManifestData, err := json.MarshalIndent(&defaultManifest, "", " ") if err != nil { @@ -96,6 +111,18 @@ func runGenerate(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to unmarshal manifest: %w", err) } manifest.Policies = policyMap + + if flags.disableUpdates { + manifest.WorkloadOwnerKeyDigests = nil + } else { + for _, keyPath := range flags.workloadOwnerKeys { + if err := addWorkloadOwnerKeyToManifest(manifest, keyPath); err != nil { + return fmt.Errorf("adding workload owner key to manifest: %w", err) + } + } + } + slices.Sort(manifest.WorkloadOwnerKeyDigests) + manifestData, err = json.MarshalIndent(manifest, "", " ") if err != nil { return fmt.Errorf("failed to marshal manifest: %w", err) @@ -162,10 +189,10 @@ func filterNonCoCoRuntime(runtimeClassName string, paths []string, logger *slog. } func generatePolicies(ctx context.Context, regoPath, policyPath string, yamlPaths []string, logger *slog.Logger) error { - if err := createFileWithDefault(filepath.Join(regoPath, policyPath), defaultGenpolicySettings); err != nil { + if err := createFileWithDefault(filepath.Join(regoPath, policyPath), func() ([]byte, error) { return defaultGenpolicySettings, nil }); err != nil { return fmt.Errorf("creating default policy file: %w", err) } - if err := createFileWithDefault(filepath.Join(regoPath, rulesFilename), defaultRules); err != nil { + if err := createFileWithDefault(filepath.Join(regoPath, rulesFilename), func() ([]byte, error) { return defaultRules, nil }); err != nil { return fmt.Errorf("creating default policy.rego file: %w", err) } binaryInstallDir, err := installDir() @@ -195,6 +222,43 @@ func generatePolicies(ctx context.Context, regoPath, policyPath string, yamlPath return nil } +func addWorkloadOwnerKeyToManifest(manifst *manifest.Manifest, keyPath string) error { + keyData, err := os.ReadFile(keyPath) + if err != nil { + return fmt.Errorf("reading workload owner key: %w", err) + } + block, _ := pem.Decode(keyData) + if block == nil { + return errors.New("failed to decode PEM block") + } + var publicKey []byte + switch block.Type { + case "PUBLIC KEY": + publicKey = block.Bytes + case "EC PRIVATE KEY": + privateKey, err := x509.ParseECPrivateKey(block.Bytes) + if err != nil { + return fmt.Errorf("parsing EC private key: %w", err) + } + publicKey, err = x509.MarshalPKIXPublicKey(&privateKey.PublicKey) + if err != nil { + return fmt.Errorf("marshaling public key: %w", err) + } + default: + return fmt.Errorf("unsupported PEM block type: %s", block.Type) + } + + hash := sha256.Sum256(publicKey) + hashString := manifest.NewHexString(hash[:]) + for _, existingHash := range manifst.WorkloadOwnerKeyDigests { + if existingHash == hashString { + return nil + } + } + manifst.WorkloadOwnerKeyDigests = append(manifst.WorkloadOwnerKeyDigests, hashString) + return nil +} + func generatePolicyForFile(ctx context.Context, genpolicyPath, regoPath, policyPath, yamlPath string, logger *slog.Logger) ([32]byte, error) { args := []string{ "--raw-out", @@ -223,10 +287,39 @@ func generatePolicyForFile(ctx context.Context, genpolicyPath, regoPath, policyP return policyHash, nil } +func generateWorkloadOwnerKey(flags *generateFlags) error { + if flags.disableUpdates || len(flags.workloadOwnerKeys) != 1 { + // No need to generate keys + // either updates are disabled or + // the user has provided a set of (presumably already generated) public keys + return nil + } + keyPath := flags.workloadOwnerKeys[0] + + if err := createFileWithDefault(keyPath, newKeyPair); err != nil { + return fmt.Errorf("creating default workload owner key file: %w", err) + } + return nil +} + +func newKeyPair() ([]byte, error) { + privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + return nil, fmt.Errorf("generating private key: %w", err) + } + privateKeyBytes, err := x509.MarshalECPrivateKey(privateKey) + if err != nil { + return nil, fmt.Errorf("marshaling private key: %w", err) + } + return pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privateKeyBytes}), nil +} + type generateFlags struct { - policyPath string - settingsPath string - manifestPath string + policyPath string + settingsPath string + manifestPath string + workloadOwnerKeys []string + disableUpdates bool } func parseGenerateFlags(cmd *cobra.Command) (*generateFlags, error) { @@ -242,10 +335,21 @@ func parseGenerateFlags(cmd *cobra.Command) (*generateFlags, error) { if err != nil { return nil, err } + workloadOwnerKeys, err := cmd.Flags().GetStringArray("workload-owner-key") + if err != nil { + return nil, err + } + disableUpdates, err := cmd.Flags().GetBool("disable-updates") + if err != nil { + return nil, err + } + return &generateFlags{ - policyPath: policyPath, - settingsPath: settingsPath, - manifestPath: manifestPath, + policyPath: policyPath, + settingsPath: settingsPath, + manifestPath: manifestPath, + workloadOwnerKeys: workloadOwnerKeys, + disableUpdates: disableUpdates, }, nil } @@ -264,7 +368,7 @@ func readFileOrDefault(path string, deflt []byte) ([]byte, error) { // createFileWithDefault creates the file at path with the default value, // if it doesn't exist. -func createFileWithDefault(path string, deflt []byte) error { +func createFileWithDefault(path string, dflt func() ([]byte, error)) error { file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0o644) if os.IsExist(err) { return nil @@ -273,7 +377,11 @@ func createFileWithDefault(path string, deflt []byte) error { return err } defer file.Close() - _, err = file.Write(deflt) + content, err := dflt() + if err != nil { + return err + } + _, err = file.Write(content) return err } diff --git a/cli/set.go b/cli/set.go index ce221f9807..4837654d2e 100644 --- a/cli/set.go +++ b/cli/set.go @@ -2,12 +2,18 @@ package main import ( "context" + "crypto/ecdsa" + "crypto/sha256" + "crypto/x509" "encoding/hex" "encoding/json" + "encoding/pem" "fmt" "io" + "log/slog" "net" "os" + "slices" "time" "github.com/edgelesssys/nunki/internal/atls" @@ -18,6 +24,8 @@ import ( "github.com/edgelesssys/nunki/internal/manifest" "github.com/edgelesssys/nunki/internal/spinner" "github.com/spf13/cobra" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func newSetCmd() *cobra.Command { @@ -43,6 +51,7 @@ func newSetCmd() *cobra.Command { cmd.Flags().StringP("coordinator", "c", "", "endpoint the coordinator can be reached at") must(cobra.MarkFlagRequired(cmd.Flags(), "coordinator")) cmd.Flags().String("coordinator-policy-hash", DefaultCoordinatorPolicyHash, "expected policy hash of the coordinator, will not be checked if empty") + cmd.Flags().String("workload-owner-key", workloadOwnerPEM, "path to workload owner key (.pem) file") return cmd } @@ -67,6 +76,11 @@ func runSet(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to unmarshal manifest: %w", err) } + workloadOwnerKey, err := loadWorkloadOwnerKey(flags.workloadOwnerKeyPath, m, log) + if err != nil { + return fmt.Errorf("loading workload owner key: %w", err) + } + paths, err := findGenerateTargets(args, log) if err != nil { return fmt.Errorf("finding yaml files: %w", err) @@ -90,7 +104,7 @@ func runSet(cmd *cobra.Command, args []string) error { kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache")) kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter")) validator := snp.NewValidator(validateOptsGen, kdsGetter, log.WithGroup("snp-validator")) - dialer := dialer.New(atls.NoIssuer, validator, &net.Dialer{}) + dialer := dialer.NewWithKey(atls.NoIssuer, validator, &net.Dialer{}, workloadOwnerKey) conn, err := dialer.Dial(cmd.Context(), flags.coordinator) if err != nil { @@ -105,6 +119,18 @@ func runSet(cmd *cobra.Command, args []string) error { } resp, err := setLoop(cmd.Context(), client, cmd.OutOrStdout(), req) if err != nil { + grpcSt, ok := status.FromError(err) + if ok { + if grpcSt.Code() == codes.PermissionDenied { + msg := "Permission denied." + if workloadOwnerKey == nil { + msg += " Specify a workload owner key with --workload-owner-key." + } else { + msg += " Ensure you are using a trusted workload owner key." + } + fmt.Fprintln(cmd.OutOrStdout(), msg) + } + } return fmt.Errorf("failed to set manifest: %w", err) } @@ -122,9 +148,10 @@ func runSet(cmd *cobra.Command, args []string) error { } type setFlags struct { - manifestPath string - coordinator string - policy []byte + manifestPath string + coordinator string + policy []byte + workloadOwnerKeyPath string } func parseSetFlags(cmd *cobra.Command) (*setFlags, error) { @@ -147,6 +174,10 @@ func parseSetFlags(cmd *cobra.Command) (*setFlags, error) { if err != nil { return nil, fmt.Errorf("hex-decoding coordinator-policy-hash flag: %w", err) } + flags.workloadOwnerKeyPath, err = cmd.Flags().GetString("workload-owner-key") + if err != nil { + return nil, fmt.Errorf("getting workload-owner-key flag: %w", err) + } return flags, nil } @@ -159,6 +190,42 @@ func policyMapToBytesList(m map[string]deployment) [][]byte { return policies } +func loadWorkloadOwnerKey(path string, manifst manifest.Manifest, log *slog.Logger) (*ecdsa.PrivateKey, error) { + key, err := os.ReadFile(path) + if os.IsNotExist(err) { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("reading workload owner key: %w", err) + } + pemBlock, _ := pem.Decode(key) + if pemBlock == nil { + return nil, fmt.Errorf("decoding workload owner key: %w", err) + } + if pemBlock.Type != "EC PRIVATE KEY" { + return nil, fmt.Errorf("workload owner key is not an EC private key") + } + workloadOwnerKey, err := x509.ParseECPrivateKey(pemBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("parsing workload owner key: %w", err) + } + pubKey, err := x509.MarshalPKIXPublicKey(&workloadOwnerKey.PublicKey) + if err != nil { + return nil, fmt.Errorf("marshaling public key: %w", err) + } + ownerKeyHash := sha256.Sum256(pubKey) + ownerKeyHex := manifest.NewHexString(ownerKeyHash[:]) + if len(manifst.WorkloadOwnerKeyDigests) == 0 { + log.Warn("No workload owner keys in manifest. Further manifest updates will be rejected by the coordinator") + return workloadOwnerKey, nil + } + log.Debug("Workload owner keys in manifest", "keys", manifst.WorkloadOwnerKeyDigests) + if !slices.Contains(manifst.WorkloadOwnerKeyDigests, ownerKeyHex) { + log.Warn("Workload owner key not found in manifest. This may lock you out from further updates") + } + return workloadOwnerKey, nil +} + func setLoop( ctx context.Context, client coordapi.CoordAPIClient, out io.Writer, req *coordapi.SetManifestRequest, ) (resp *coordapi.SetManifestResponse, retErr error) { @@ -178,6 +245,14 @@ func setLoop( if rpcErr == nil { return resp, nil } + grpcSt, ok := status.FromError(rpcErr) + if ok { + switch grpcSt.Code() { + case codes.PermissionDenied, codes.InvalidArgument: + // These errors are not retryable + return nil, rpcErr + } + } timer := time.NewTimer(1 * time.Second) select { case <-ctx.Done(): diff --git a/coordinator/coordapi.go b/coordinator/coordapi.go index d35392e878..fcf573eb09 100644 --- a/coordinator/coordapi.go +++ b/coordinator/coordapi.go @@ -1,13 +1,19 @@ package main import ( + "bytes" "context" + "crypto/sha256" + "crypto/x509" "encoding/json" + "errors" "fmt" "log/slog" "net" + "sync" "time" + "github.com/edgelesssys/nunki/internal/appendable" "github.com/edgelesssys/nunki/internal/attestation/snp" "github.com/edgelesssys/nunki/internal/coordapi" "github.com/edgelesssys/nunki/internal/grpc/atlscredentials" @@ -16,7 +22,9 @@ import ( "github.com/edgelesssys/nunki/internal/memstore" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/peer" "google.golang.org/grpc/status" ) @@ -26,6 +34,7 @@ type coordAPIServer struct { manifSetGetter manifestSetGetter caChainGetter certChainGetter logger *slog.Logger + mux sync.RWMutex coordapi.UnimplementedCoordAPIServer } @@ -56,9 +65,16 @@ func (s *coordAPIServer) Serve(endpoint string) error { return s.grpc.Serve(lis) } -func (s *coordAPIServer) SetManifest(_ context.Context, req *coordapi.SetManifestRequest, +func (s *coordAPIServer) SetManifest(ctx context.Context, req *coordapi.SetManifestRequest, ) (*coordapi.SetManifestResponse, error) { s.logger.Info("SetManifest called") + s.mux.Lock() + defer s.mux.Unlock() + + if err := s.validatePeer(ctx); err != nil { + s.logger.Warn("SetManifest peer validation failed", "err", err) + return nil, status.Errorf(codes.PermissionDenied, "validating peer: %v", err) + } var m *manifest.Manifest if err := json.Unmarshal(req.Manifest, &m); err != nil { @@ -96,6 +112,8 @@ func (s *coordAPIServer) SetManifest(_ context.Context, req *coordapi.SetManifes func (s *coordAPIServer) GetManifests(_ context.Context, _ *coordapi.GetManifestsRequest, ) (*coordapi.GetManifestsResponse, error) { s.logger.Info("GetManifest called") + s.mux.RLock() + defer s.mux.RUnlock() manifests := s.manifSetGetter.GetManifests() if len(manifests) == 0 { @@ -123,6 +141,54 @@ func (s *coordAPIServer) GetManifests(_ context.Context, _ *coordapi.GetManifest return resp, nil } +func (s *coordAPIServer) validatePeer(ctx context.Context) error { + latest, err := s.manifSetGetter.LatestManifest() + if err != nil && errors.Is(err, appendable.ErrIsEmpty) { + // in the initial state, no peer validation is required + return nil + } + if err != nil { + return fmt.Errorf("getting latest manifest: %w", err) + } + if len(latest.WorkloadOwnerKeyDigests) == 0 { + return errors.New("setting manifest is disabled") + } + + peerPubKey, err := getPeerPublicKey(ctx) + if err != nil { + return err + } + peerPub256Sum := sha256.Sum256(peerPubKey) + for _, key := range latest.WorkloadOwnerKeyDigests { + trustedWorkloadOwnerSHA256, err := key.Bytes() + if err != nil { + return fmt.Errorf("parsing key: %w", err) + } + if bytes.Equal(peerPub256Sum[:], trustedWorkloadOwnerSHA256) { + return nil + } + } + return errors.New("peer not authorized workload owner") +} + +func getPeerPublicKey(ctx context.Context) ([]byte, error) { + peer, ok := peer.FromContext(ctx) + if !ok { + return nil, errors.New("no peer found in context") + } + tlsInfo, ok := peer.AuthInfo.(credentials.TLSInfo) + if !ok { + return nil, errors.New("peer auth info is not of type TLSInfo") + } + if len(tlsInfo.State.PeerCertificates) == 0 || tlsInfo.State.PeerCertificates[0] == nil { + return nil, errors.New("no peer certificates found") + } + if tlsInfo.State.PeerCertificates[0].PublicKeyAlgorithm != x509.ECDSA { + return nil, errors.New("peer public key is not of type ECDSA") + } + return x509.MarshalPKIXPublicKey(tlsInfo.State.PeerCertificates[0].PublicKey) +} + func policySliceToBytesSlice(s []manifest.Policy) [][]byte { var policies [][]byte for _, policy := range s { @@ -152,6 +218,7 @@ type certChainGetter interface { type manifestSetGetter interface { SetManifest(*manifest.Manifest) error GetManifests() []*manifest.Manifest + LatestManifest() (*manifest.Manifest, error) } type store[keyT comparable, valueT any] interface { diff --git a/coordinator/coordapi_test.go b/coordinator/coordapi_test.go index f275a6c577..c31bb4bee3 100644 --- a/coordinator/coordapi_test.go +++ b/coordinator/coordapi_test.go @@ -2,16 +2,25 @@ package main import ( "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" "encoding/json" "log/slog" "sync" "testing" + "github.com/edgelesssys/nunki/internal/appendable" "github.com/edgelesssys/nunki/internal/coordapi" "github.com/edgelesssys/nunki/internal/manifest" "github.com/edgelesssys/nunki/internal/memstore" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/peer" ) func TestManifestSet(t *testing.T) { @@ -27,16 +36,26 @@ func TestManifestSet(t *testing.T) { require.NoError(t, err) return b } + trustedKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + require.NoError(t, err) + untrustedKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + require.NoError(t, err) + manifestWithTrustedKey, err := manifestWithWorkloadOwnerKey(trustedKey) + require.NoError(t, err) + manifestWithoutTrustedKey, err := manifestWithWorkloadOwnerKey(nil) + require.NoError(t, err) testCases := map[string]struct { - req *coordapi.SetManifestRequest - mSGetter *stubManifestSetGetter - caGetter *stubCertChainGetter - wantErr bool + req *coordapi.SetManifestRequest + mSGetter *stubManifestSetGetter + caGetter *stubCertChainGetter + workloadOwnerKey *ecdsa.PrivateKey + wantErr bool }{ "empty request": { - req: &coordapi.SetManifestRequest{}, - wantErr: true, + req: &coordapi.SetManifestRequest{}, + mSGetter: &stubManifestSetGetter{}, + wantErr: true, }, "manifest without policies": { req: &coordapi.SetManifestRequest{ @@ -44,7 +63,8 @@ func TestManifestSet(t *testing.T) { m.Policies = nil }), }, - wantErr: true, + mSGetter: &stubManifestSetGetter{}, + wantErr: true, }, "request without policies": { req: &coordapi.SetManifestRequest{ @@ -55,7 +75,8 @@ func TestManifestSet(t *testing.T) { } }), }, - wantErr: true, + mSGetter: &stubManifestSetGetter{}, + wantErr: true, }, "policy not in manifest": { req: &coordapi.SetManifestRequest{ @@ -70,7 +91,8 @@ func TestManifestSet(t *testing.T) { []byte("c"), }, }, - wantErr: true, + mSGetter: &stubManifestSetGetter{}, + wantErr: true, }, "valid manifest": { req: &coordapi.SetManifestRequest{ @@ -105,6 +127,84 @@ func TestManifestSet(t *testing.T) { caGetter: &stubCertChainGetter{}, wantErr: true, }, + "workload owner key match": { + req: &coordapi.SetManifestRequest{ + Manifest: newManifestBytes(func(m *manifest.Manifest) { + m.Policies = map[manifest.HexString][]string{ + manifest.HexString("ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb"): {"a1", "a2"}, + manifest.HexString("3e23e8160039594a33894f6564e1b1348bbd7a0088d42c4acb73eeaed59c009d"): {"b1", "b2"}, + } + }), + Policies: [][]byte{ + []byte("a"), + []byte("b"), + }, + }, + mSGetter: &stubManifestSetGetter{ + getManifestResp: []*manifest.Manifest{manifestWithTrustedKey}, + }, + caGetter: &stubCertChainGetter{}, + workloadOwnerKey: trustedKey, + }, + "workload owner key mismatch": { + req: &coordapi.SetManifestRequest{ + Manifest: newManifestBytes(func(m *manifest.Manifest) { + m.Policies = map[manifest.HexString][]string{ + manifest.HexString("ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb"): {"a1", "a2"}, + manifest.HexString("3e23e8160039594a33894f6564e1b1348bbd7a0088d42c4acb73eeaed59c009d"): {"b1", "b2"}, + } + }), + Policies: [][]byte{ + []byte("a"), + []byte("b"), + }, + }, + mSGetter: &stubManifestSetGetter{ + getManifestResp: []*manifest.Manifest{manifestWithTrustedKey}, + }, + caGetter: &stubCertChainGetter{}, + workloadOwnerKey: untrustedKey, + wantErr: true, + }, + "workload owner key missing": { + req: &coordapi.SetManifestRequest{ + Manifest: newManifestBytes(func(m *manifest.Manifest) { + m.Policies = map[manifest.HexString][]string{ + manifest.HexString("ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb"): {"a1", "a2"}, + manifest.HexString("3e23e8160039594a33894f6564e1b1348bbd7a0088d42c4acb73eeaed59c009d"): {"b1", "b2"}, + } + }), + Policies: [][]byte{ + []byte("a"), + []byte("b"), + }, + }, + mSGetter: &stubManifestSetGetter{ + getManifestResp: []*manifest.Manifest{manifestWithTrustedKey}, + }, + caGetter: &stubCertChainGetter{}, + wantErr: true, + }, + "manifest not updatable": { + req: &coordapi.SetManifestRequest{ + Manifest: newManifestBytes(func(m *manifest.Manifest) { + m.Policies = map[manifest.HexString][]string{ + manifest.HexString("ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb"): {"a1", "a2"}, + manifest.HexString("3e23e8160039594a33894f6564e1b1348bbd7a0088d42c4acb73eeaed59c009d"): {"b1", "b2"}, + } + }), + Policies: [][]byte{ + []byte("a"), + []byte("b"), + }, + }, + mSGetter: &stubManifestSetGetter{ + getManifestResp: []*manifest.Manifest{manifestWithoutTrustedKey}, + }, + caGetter: &stubCertChainGetter{}, + workloadOwnerKey: trustedKey, + wantErr: true, + }, } for name, tc := range testCases { @@ -119,7 +219,7 @@ func TestManifestSet(t *testing.T) { logger: slog.Default(), } - ctx := context.Background() + ctx := rpcContext(tc.workloadOwnerKey) resp, err := coordinator.SetManifest(ctx, tc.req) if tc.wantErr { @@ -283,12 +383,51 @@ func (s *stubManifestSetGetter) GetManifests() []*manifest.Manifest { return s.getManifestResp } +func (s *stubManifestSetGetter) LatestManifest() (*manifest.Manifest, error) { + s.mux.RLock() + defer s.mux.RUnlock() + if len(s.getManifestResp) == 0 { + return nil, appendable.ErrIsEmpty + } + return s.getManifestResp[len(s.getManifestResp)-1], nil +} + type stubCertChainGetter struct{} func (s *stubCertChainGetter) GetRootCACert() []byte { return []byte("root") } func (s *stubCertChainGetter) GetMeshCACert() []byte { return []byte("mesh") } func (s *stubCertChainGetter) GetIntermCert() []byte { return []byte("inter") } +func rpcContext(key *ecdsa.PrivateKey) context.Context { + var peerCertificates []*x509.Certificate + if key != nil { + peerCertificates = []*x509.Certificate{{ + PublicKey: key.Public(), + PublicKeyAlgorithm: x509.ECDSA, + }} + } + return peer.NewContext(context.Background(), &peer.Peer{ + AuthInfo: credentials.TLSInfo{State: tls.ConnectionState{ + PeerCertificates: peerCertificates, + }}, + }) +} + +func manifestWithWorkloadOwnerKey(key *ecdsa.PrivateKey) (*manifest.Manifest, error) { + m := manifest.Default() + if key == nil { + return &m, nil + } + pubKey, err := x509.MarshalPKIXPublicKey(&key.PublicKey) + if err != nil { + return nil, err + } + ownerKeyHash := sha256.Sum256(pubKey) + ownerKeyHex := manifest.NewHexString(ownerKeyHash[:]) + m.WorkloadOwnerKeyDigests = []manifest.HexString{ownerKeyHex} + return &m, nil +} + func toPtr[T any](t T) *T { return &t } diff --git a/coordinator/mesh.go b/coordinator/mesh.go index b2569cc57a..5d850fa568 100644 --- a/coordinator/mesh.go +++ b/coordinator/mesh.go @@ -140,6 +140,10 @@ func (m *meshAuthority) SetManifest(mnfst *manifest.Manifest) error { return nil } +func (m *meshAuthority) LatestManifest() (*manifest.Manifest, error) { + return m.manifests.Latest() +} + type appendableList[T any] interface { Append(T) All() []T diff --git a/internal/appendable/appendable.go b/internal/appendable/appendable.go index dd97727883..04e3e90772 100644 --- a/internal/appendable/appendable.go +++ b/internal/appendable/appendable.go @@ -31,8 +31,11 @@ func (a *Appendable[T]) Latest() (T, error) { defer a.mux.RUnlock() if len(a.list) == 0 { - return *new(T), errors.New("appendable is empty") + return *new(T), ErrIsEmpty } return a.list[len(a.list)-1], nil } + +// ErrIsEmpty is returned when trying to get the latest value from an empty list. +var ErrIsEmpty = errors.New("appendable is empty") diff --git a/internal/atls/atls.go b/internal/atls/atls.go index 4d97199f3a..6b2e066ae7 100644 --- a/internal/atls/atls.go +++ b/internal/atls/atls.go @@ -117,19 +117,20 @@ func getATLSConfigForClientFunc(issuer Issuer, validators []Validator) (func(*tl } cfg := &tls.Config{ - VerifyPeerCertificate: serverConn.verify, - GetCertificate: serverConn.getCertificate, - MinVersion: tls.VersionTLS12, + GetCertificate: serverConn.getCertificate, + MinVersion: tls.VersionTLS12, + ClientAuth: tls.RequestClientCert, // request client certificate but don't require it + } + + // ugly hack: abuse acceptable client CAs as a channel to transmit the nonce + if cfg.ClientCAs, err = encodeNonceToCertPool(serverNonce, priv); err != nil { + return nil, fmt.Errorf("encode nonce: %w", err) } // enable mutual aTLS if any validators are set if len(validators) > 0 { cfg.ClientAuth = tls.RequireAnyClientCert // validity of certificate will be checked by our custom verify function - - // ugly hack: abuse acceptable client CAs as a channel to transmit the nonce - if cfg.ClientCAs, err = encodeNonceToCertPool(serverNonce, priv); err != nil { - return nil, fmt.Errorf("encode nonce: %w", err) - } + cfg.VerifyPeerCertificate = serverConn.verify } return cfg, nil diff --git a/internal/manifest/manifest.go b/internal/manifest/manifest.go index 50118d3f9c..e2a1030213 100644 --- a/internal/manifest/manifest.go +++ b/internal/manifest/manifest.go @@ -12,8 +12,9 @@ import ( // Manifest is the Coordinator manifest and contains the reference values of the deployment. type Manifest struct { // policyHash/HOSTDATA -> commonName - Policies map[HexString][]string - ReferenceValues ReferenceValues + Policies map[HexString][]string + ReferenceValues ReferenceValues + WorkloadOwnerKeyDigests []HexString } // ReferenceValues contains the workload independent reference values.