diff --git a/coordinator/coordapi_test.go b/coordinator/coordapi_test.go index a4194ede1..c31bb4bee 100644 --- a/coordinator/coordapi_test.go +++ b/coordinator/coordapi_test.go @@ -2,6 +2,12 @@ package main import ( "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" "encoding/json" "log/slog" "sync" @@ -13,6 +19,8 @@ import ( "github.com/edgelesssys/nunki/internal/memstore" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/peer" ) func TestManifestSet(t *testing.T) { @@ -28,12 +36,21 @@ func TestManifestSet(t *testing.T) { require.NoError(t, err) return b } + trustedKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + require.NoError(t, err) + untrustedKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + require.NoError(t, err) + manifestWithTrustedKey, err := manifestWithWorkloadOwnerKey(trustedKey) + require.NoError(t, err) + manifestWithoutTrustedKey, err := manifestWithWorkloadOwnerKey(nil) + require.NoError(t, err) testCases := map[string]struct { - req *coordapi.SetManifestRequest - mSGetter *stubManifestSetGetter - caGetter *stubCertChainGetter - wantErr bool + req *coordapi.SetManifestRequest + mSGetter *stubManifestSetGetter + caGetter *stubCertChainGetter + workloadOwnerKey *ecdsa.PrivateKey + wantErr bool }{ "empty request": { req: &coordapi.SetManifestRequest{}, @@ -110,6 +127,84 @@ func TestManifestSet(t *testing.T) { caGetter: &stubCertChainGetter{}, wantErr: true, }, + "workload owner key match": { + req: &coordapi.SetManifestRequest{ + Manifest: newManifestBytes(func(m *manifest.Manifest) { + m.Policies = map[manifest.HexString][]string{ + manifest.HexString("ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb"): {"a1", "a2"}, + manifest.HexString("3e23e8160039594a33894f6564e1b1348bbd7a0088d42c4acb73eeaed59c009d"): {"b1", "b2"}, + } + }), + Policies: [][]byte{ + []byte("a"), + []byte("b"), + }, + }, + mSGetter: &stubManifestSetGetter{ + getManifestResp: []*manifest.Manifest{manifestWithTrustedKey}, + }, + caGetter: &stubCertChainGetter{}, + workloadOwnerKey: trustedKey, + }, + "workload owner key mismatch": { + req: &coordapi.SetManifestRequest{ + Manifest: newManifestBytes(func(m *manifest.Manifest) { + m.Policies = map[manifest.HexString][]string{ + manifest.HexString("ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb"): {"a1", "a2"}, + manifest.HexString("3e23e8160039594a33894f6564e1b1348bbd7a0088d42c4acb73eeaed59c009d"): {"b1", "b2"}, + } + }), + Policies: [][]byte{ + []byte("a"), + []byte("b"), + }, + }, + mSGetter: &stubManifestSetGetter{ + getManifestResp: []*manifest.Manifest{manifestWithTrustedKey}, + }, + caGetter: &stubCertChainGetter{}, + workloadOwnerKey: untrustedKey, + wantErr: true, + }, + "workload owner key missing": { + req: &coordapi.SetManifestRequest{ + Manifest: newManifestBytes(func(m *manifest.Manifest) { + m.Policies = map[manifest.HexString][]string{ + manifest.HexString("ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb"): {"a1", "a2"}, + manifest.HexString("3e23e8160039594a33894f6564e1b1348bbd7a0088d42c4acb73eeaed59c009d"): {"b1", "b2"}, + } + }), + Policies: [][]byte{ + []byte("a"), + []byte("b"), + }, + }, + mSGetter: &stubManifestSetGetter{ + getManifestResp: []*manifest.Manifest{manifestWithTrustedKey}, + }, + caGetter: &stubCertChainGetter{}, + wantErr: true, + }, + "manifest not updatable": { + req: &coordapi.SetManifestRequest{ + Manifest: newManifestBytes(func(m *manifest.Manifest) { + m.Policies = map[manifest.HexString][]string{ + manifest.HexString("ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb"): {"a1", "a2"}, + manifest.HexString("3e23e8160039594a33894f6564e1b1348bbd7a0088d42c4acb73eeaed59c009d"): {"b1", "b2"}, + } + }), + Policies: [][]byte{ + []byte("a"), + []byte("b"), + }, + }, + mSGetter: &stubManifestSetGetter{ + getManifestResp: []*manifest.Manifest{manifestWithoutTrustedKey}, + }, + caGetter: &stubCertChainGetter{}, + workloadOwnerKey: trustedKey, + wantErr: true, + }, } for name, tc := range testCases { @@ -124,7 +219,7 @@ func TestManifestSet(t *testing.T) { logger: slog.Default(), } - ctx := context.Background() + ctx := rpcContext(tc.workloadOwnerKey) resp, err := coordinator.SetManifest(ctx, tc.req) if tc.wantErr { @@ -303,6 +398,36 @@ func (s *stubCertChainGetter) GetRootCACert() []byte { return []byte("root") } func (s *stubCertChainGetter) GetMeshCACert() []byte { return []byte("mesh") } func (s *stubCertChainGetter) GetIntermCert() []byte { return []byte("inter") } +func rpcContext(key *ecdsa.PrivateKey) context.Context { + var peerCertificates []*x509.Certificate + if key != nil { + peerCertificates = []*x509.Certificate{{ + PublicKey: key.Public(), + PublicKeyAlgorithm: x509.ECDSA, + }} + } + return peer.NewContext(context.Background(), &peer.Peer{ + AuthInfo: credentials.TLSInfo{State: tls.ConnectionState{ + PeerCertificates: peerCertificates, + }}, + }) +} + +func manifestWithWorkloadOwnerKey(key *ecdsa.PrivateKey) (*manifest.Manifest, error) { + m := manifest.Default() + if key == nil { + return &m, nil + } + pubKey, err := x509.MarshalPKIXPublicKey(&key.PublicKey) + if err != nil { + return nil, err + } + ownerKeyHash := sha256.Sum256(pubKey) + ownerKeyHex := manifest.NewHexString(ownerKeyHash[:]) + m.WorkloadOwnerKeyDigests = []manifest.HexString{ownerKeyHex} + return &m, nil +} + func toPtr[T any](t T) *T { return &t }