Skip to content

Commit

Permalink
coordinator: add coordAPI unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Paul Meyer <[email protected]>
  • Loading branch information
katexochen committed Jan 19, 2024
1 parent 6fcbbc6 commit 99958ac
Show file tree
Hide file tree
Showing 2 changed files with 289 additions and 2 deletions.
16 changes: 14 additions & 2 deletions coordinator/coordapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ func (s *coordAPIServer) SetManifest(_ context.Context, req *coordapi.SetManifes
return nil, status.Errorf(codes.InvalidArgument, "unmarshaling manifest: %v", err)
}

if len(m.Policies) == 0 {
return nil, status.Error(codes.InvalidArgument, "manifest must contain at least one policy")
}
if len(m.Policies) != len(req.Policies) {
return nil, status.Error(codes.InvalidArgument, "request must contain exactly the policies referenced in the manifest")
}

for _, policyBytes := range req.Policies {
policy := manifest.Policy(policyBytes)
if _, ok := m.Policies[policy.Hash()]; !ok {
Expand All @@ -89,17 +96,22 @@ func (s *coordAPIServer) GetManifests(_ context.Context, _ *coordapi.GetManifest

manifests := s.manifSetGetter.GetManifests()
if len(manifests) == 0 {
return nil, status.Errorf(codes.FailedPrecondition, "no manifests found")
return nil, status.Errorf(codes.FailedPrecondition, "no manifests set")
}

manifestBytes, err := manifestSliceToBytesSlice(manifests)
if err != nil {
return nil, status.Errorf(codes.Internal, "marshaling manifests: %v", err)
}

policies := s.policyTextStore.GetAll()
if len(policies) == 0 {
return nil, status.Error(codes.Internal, "no policies found in store")
}

resp := &coordapi.GetManifestsResponse{
Manifests: manifestBytes,
Policies: policySliceToBytesSlice(s.policyTextStore.GetAll()),
Policies: policySliceToBytesSlice(policies),
CACert: s.caChainGetter.GetRootCACert(),
IntermCert: s.caChainGetter.GetIntermCert(),
}
Expand Down
275 changes: 275 additions & 0 deletions coordinator/coordapi_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
package main

import (
"context"
"encoding/json"
"log/slog"
"sync"
"testing"

"github.com/edgelesssys/nunki/internal/coordapi"
"github.com/edgelesssys/nunki/internal/manifest"
"github.com/edgelesssys/nunki/internal/memstore"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestManifestSet(t *testing.T) {
newBaseManifest := func() manifest.Manifest {
return manifest.Default()
}
newManifestBytes := func(f func(*manifest.Manifest)) []byte {
m := newBaseManifest()
if f != nil {
f(&m)
}
b, err := json.Marshal(m)
require.NoError(t, err)
return b
}

testCases := map[string]struct {
req *coordapi.SetManifestRequest
mSGetter *stubManifestSetGetter
caGetter *stubCertChainGetter
wantErr bool
}{
"empty request": {
req: &coordapi.SetManifestRequest{},
wantErr: true,
},
"manifest without policies": {
req: &coordapi.SetManifestRequest{
Manifest: newManifestBytes(func(m *manifest.Manifest) {
m.Policies = nil
}),
},
wantErr: true,
},
"request without policies": {
req: &coordapi.SetManifestRequest{
Manifest: newManifestBytes(func(m *manifest.Manifest) {
m.Policies = map[manifest.HexString][]string{
manifest.HexString("a"): {"a1", "a2"},
manifest.HexString("b"): {"b1", "b2"},
}
}),
},
wantErr: true,
},
"policy not in manifest": {
req: &coordapi.SetManifestRequest{
Manifest: newManifestBytes(func(m *manifest.Manifest) {
m.Policies = map[manifest.HexString][]string{
manifest.HexString("ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb"): {"a1", "a2"},
manifest.HexString("3e23e8160039594a33894f6564e1b1348bbd7a0088d42c4acb73eeaed59c009d"): {"b1", "b2"},
}
}),
Policies: [][]byte{
[]byte("a"),
[]byte("c"),
},
},
wantErr: true,
},
"valid manifest": {
req: &coordapi.SetManifestRequest{
Manifest: newManifestBytes(func(m *manifest.Manifest) {
m.Policies = map[manifest.HexString][]string{
manifest.HexString("ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb"): {"a1", "a2"},
manifest.HexString("3e23e8160039594a33894f6564e1b1348bbd7a0088d42c4acb73eeaed59c009d"): {"b1", "b2"},
}
}),
Policies: [][]byte{
[]byte("a"),
[]byte("b"),
},
},
mSGetter: &stubManifestSetGetter{},
caGetter: &stubCertChainGetter{},
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)

coordinator := coordAPIServer{
manifSetGetter: tc.mSGetter,
caChainGetter: tc.caGetter,
policyTextStore: memstore.New[manifest.HexString, manifest.Policy](),
logger: slog.Default(),
}

ctx := context.Background()
resp, err := coordinator.SetManifest(ctx, tc.req)

if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal([]byte("root"), resp.CACert)
assert.Equal([]byte("inter"), resp.IntermCert)
assert.Equal(1, tc.mSGetter.setManifestCount)
})
}
}

func TestGetManifests(t *testing.T) {
testCases := map[string]struct {
mSGetter *stubManifestSetGetter
caGetter *stubCertChainGetter
policyStoreContent map[manifest.HexString]manifest.Policy
wantErr bool
}{
"no manifest set": {
mSGetter: &stubManifestSetGetter{},
caGetter: &stubCertChainGetter{},
wantErr: true,
},
"no policy in store": {
mSGetter: &stubManifestSetGetter{
getManifestResp: []*manifest.Manifest{
toPtr(manifest.Default()),
toPtr(manifest.Default()),
},
},
wantErr: true,
},
"one manifest set": {
mSGetter: &stubManifestSetGetter{
getManifestResp: []*manifest.Manifest{
toPtr(manifest.Default()),
toPtr(manifest.Default()),
},
},
policyStoreContent: map[manifest.HexString]manifest.Policy{
manifest.HexString("a"): {},
},
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)

policyStore := memstore.New[manifest.HexString, manifest.Policy]()
for k, v := range tc.policyStoreContent {
policyStore.Set(k, v)
}

coordinator := coordAPIServer{
manifSetGetter: tc.mSGetter,
caChainGetter: tc.caGetter,
policyTextStore: policyStore,
logger: slog.Default(),
}

ctx := context.Background()
resp, err := coordinator.GetManifests(ctx, &coordapi.GetManifestsRequest{})

if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal([]byte("root"), resp.CACert)
assert.Equal([]byte("inter"), resp.IntermCert)
assert.Len(resp.Policies, len(tc.policyStoreContent))
})
}
}

// TestCoordAPIConcurrent tests potential synchronization problems between the different
// gRPCs of the server.
func TestCoordAPIConcurrent(t *testing.T) {
newBaseManifest := func() manifest.Manifest {
return manifest.Default()
}
newManifestBytes := func(f func(*manifest.Manifest)) []byte {
m := newBaseManifest()
if f != nil {
f(&m)
}
b, err := json.Marshal(m)
require.NoError(t, err)
return b
}

coordinator := coordAPIServer{
manifSetGetter: &stubManifestSetGetter{},
caChainGetter: &stubCertChainGetter{},
policyTextStore: memstore.New[manifest.HexString, manifest.Policy](),
logger: slog.Default(),
}
setReq := &coordapi.SetManifestRequest{
Manifest: newManifestBytes(func(m *manifest.Manifest) {
m.Policies = map[manifest.HexString][]string{
manifest.HexString("ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb"): {"a1", "a2"},
manifest.HexString("3e23e8160039594a33894f6564e1b1348bbd7a0088d42c4acb73eeaed59c009d"): {"b1", "b2"},
}
}),
Policies: [][]byte{
[]byte("a"),
[]byte("c"),
},
}

ctx := context.Background()
wg := sync.WaitGroup{}

set := func() {
defer wg.Done()
_, _ = coordinator.SetManifest(ctx, setReq)
}
get := func() {
defer wg.Done()
_, _ = coordinator.GetManifests(ctx, &coordapi.GetManifestsRequest{})
}

wg.Add(12)
go set()
go set()
go set()
go get()
go get()
go get()
go set()
go set()
go set()
go get()
go get()
go get()
wg.Wait()
}

type stubManifestSetGetter struct {
mux sync.RWMutex
setManifestCount int
getManifestResp []*manifest.Manifest
}

func (s *stubManifestSetGetter) SetManifest(*manifest.Manifest) {
s.mux.Lock()
defer s.mux.Unlock()
s.setManifestCount++
}

func (s *stubManifestSetGetter) GetManifests() []*manifest.Manifest {
s.mux.RLock()
defer s.mux.RUnlock()
return s.getManifestResp
}

type stubCertChainGetter struct{}

func (s *stubCertChainGetter) GetRootCACert() []byte { return []byte("root") }
func (s *stubCertChainGetter) GetMeshCACert() []byte { return []byte("mesh") }
func (s *stubCertChainGetter) GetIntermCert() []byte { return []byte("inter") }

func toPtr[T any](t T) *T {
return &t
}

0 comments on commit 99958ac

Please sign in to comment.