Skip to content

Commit

Permalink
Fix flaky cert tests (#3572) (#3577)
Browse files Browse the repository at this point in the history
(cherry picked from commit 763a7ad)

Co-authored-by: Michel Laterman <[email protected]>
  • Loading branch information
mergify[bot] and michel-laterman authored May 23, 2024
1 parent e6f64e6 commit d67d34a
Showing 1 changed file with 83 additions and 42 deletions.
125 changes: 83 additions & 42 deletions internal/pkg/api/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,8 @@ import (
"github.com/stretchr/testify/require"

fbuild "github.com/elastic/fleet-server/v7/internal/pkg/build"
"github.com/elastic/fleet-server/v7/internal/pkg/cache"
"github.com/elastic/fleet-server/v7/internal/pkg/checkin"
"github.com/elastic/fleet-server/v7/internal/pkg/config"
"github.com/elastic/fleet-server/v7/internal/pkg/monitor/mock"
"github.com/elastic/fleet-server/v7/internal/pkg/policy"
ftesting "github.com/elastic/fleet-server/v7/internal/pkg/testing"
"github.com/elastic/fleet-server/v7/internal/pkg/testing/certs"
testlog "github.com/elastic/fleet-server/v7/internal/pkg/testing/log"
Expand All @@ -46,40 +43,38 @@ func Test_server_Run(t *testing.T) {
cfg.Port = port
addr := cfg.BindEndpoints()[0]

verCon := mustBuildConstraints("8.0.0")
c, err := cache.New(config.Cache{NumCounters: 100, MaxCost: 100000})
require.NoError(t, err)
bulker := ftesting.NewMockBulk()
pim := mock.NewMockMonitor()
pm := policy.NewMonitor(bulker, pim, config.ServerLimits{PolicyLimit: config.Limit{Interval: 5 * time.Millisecond, Burst: 1}})
bc := checkin.NewBulk(nil)
ct := NewCheckinT(verCon, cfg, c, bc, pm, nil, nil, nil, nil)
et, err := NewEnrollerT(verCon, cfg, nil, c)
require.NoError(t, err)

srv := NewServer(addr, cfg, ct, et, nil, nil, nil, nil, fbuild.Info{}, nil, nil, nil, nil, nil)
errCh := make(chan error)
srv := NewServer(addr, cfg, nil, nil, nil, nil, nil, nil, fbuild.Info{}, nil, nil, nil, nil, nil)

started := make(chan struct{}, 1)
errCh := make(chan error, 1)
var wg sync.WaitGroup
wg.Add(1)
go func() {
if err := srv.Run(ctx); err != nil {
started <- struct{}{}
if err := srv.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
errCh <- err
}
wg.Done()
}()
var errFromChan error
select {

select { // if the goroutine has started within 500ms something is wrong, test has timed out
case <-started:
case <-time.After(500 * time.Millisecond):
require.Fail(t, "timed out waiting for server to start")
}
select { // check if there is an error in the 1st 500ms of the server running
case err := <-errCh:
errFromChan = err
require.NoError(t, err, "error during startup")
case <-time.After(500 * time.Millisecond):
break
}

cancel()
wg.Wait()
require.NoError(t, errFromChan)
if !errors.Is(err, http.ErrServerClosed) {
select {
case err := <-errCh:
require.NoError(t, err)
default:
}
}

Expand Down Expand Up @@ -130,7 +125,6 @@ func Test_server_ClientCert(t *testing.T) {

st := NewStatusT(cfg, nil, nil)
srv := NewServer(addr, cfg, nil, nil, nil, nil, st, sm, fbuild.Info{}, nil, nil, nil, nil, nil)
errCh := make(chan error)

// make http client with no client certs
certPool := x509.NewCertPool()
Expand All @@ -144,17 +138,29 @@ func Test_server_ClientCert(t *testing.T) {
}

started := make(chan struct{}, 1)
errCh := make(chan error, 1)
var wg sync.WaitGroup
wg.Add(1)
go func() {
started <- struct{}{}
if err := srv.Run(ctx); err != nil {
if err := srv.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
errCh <- err
}
wg.Done()
}()

<-started
select { // make sure goroutine starts within 500ms
case <-started:
case <-time.After(500 * time.Millisecond):
require.Fail(t, "timed out waiting for server to start")
}
select { // make sure there are no errors within 500ms of api server running
case err := <-errCh:
require.NoError(t, err, "error during startup")
case <-time.After(500 * time.Millisecond):
break
}

rCtx, rCancel := context.WithTimeout(ctx, time.Second)
defer rCancel()
req, err := http.NewRequestWithContext(rCtx, "GET", "https://"+addr+"/api/status", nil)
Expand All @@ -164,13 +170,13 @@ func Test_server_ClientCert(t *testing.T) {
resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)

cancel()
wg.Wait()
select {
case err := <-errCh:
require.NoError(t, err)
default:
}
cancel()
wg.Wait()
})

t.Run("valid client certs", func(t *testing.T) {
Expand All @@ -189,7 +195,6 @@ func Test_server_ClientCert(t *testing.T) {

st := NewStatusT(cfg, nil, nil)
srv := NewServer(addr, cfg, nil, nil, nil, nil, st, sm, fbuild.Info{}, nil, nil, nil, nil, nil)
errCh := make(chan error)

// make http client with valid client certs
clientCert := certs.GenCert(t, ca)
Expand All @@ -205,17 +210,29 @@ func Test_server_ClientCert(t *testing.T) {
}

started := make(chan struct{}, 1)
errCh := make(chan error, 1)
var wg sync.WaitGroup
wg.Add(1)
go func() {
started <- struct{}{}
if err := srv.Run(ctx); err != nil {
if err := srv.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
errCh <- err
}
wg.Done()
}()

<-started
select {
case <-started:
case <-time.After(500 * time.Millisecond):
require.Fail(t, "timed out waiting for server to start")
}
select {
case err := <-errCh:
require.NoError(t, err, "error during startup")
case <-time.After(500 * time.Millisecond):
break
}

rCtx, rCancel := context.WithTimeout(ctx, time.Second)
defer rCancel()
req, err := http.NewRequestWithContext(rCtx, "GET", "https://"+addr+"/api/status", nil)
Expand All @@ -225,13 +242,13 @@ func Test_server_ClientCert(t *testing.T) {
resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)

cancel()
wg.Wait()
select {
case err := <-errCh:
require.NoError(t, err)
default:
}
cancel()
wg.Wait()
})

t.Run("invalid client certs", func(t *testing.T) {
Expand All @@ -250,7 +267,6 @@ func Test_server_ClientCert(t *testing.T) {

st := NewStatusT(cfg, nil, nil)
srv := NewServer(addr, cfg, nil, nil, nil, nil, st, sm, fbuild.Info{}, nil, nil, nil, nil, nil)
errCh := make(chan error)

// make http client with invalid client certs
clientCA := certs.GenCA(t)
Expand All @@ -267,35 +283,46 @@ func Test_server_ClientCert(t *testing.T) {
}

started := make(chan struct{}, 1)
errCh := make(chan error, 1)
var wg sync.WaitGroup
wg.Add(1)
go func() {
started <- struct{}{}
if err := srv.Run(ctx); err != nil {
if err := srv.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
errCh <- err
}
wg.Done()
}()

<-started
select {
case <-started:
case <-time.After(500 * time.Millisecond):
require.Fail(t, "timed out waiting for server to start")
}
select {
case err := <-errCh:
require.NoError(t, err, "error during startup")
case <-time.After(500 * time.Millisecond):
break
}

rCtx, rCancel := context.WithTimeout(ctx, time.Second)
defer rCancel()
req, err := http.NewRequestWithContext(rCtx, "GET", "https://"+addr+"/api/status", nil)
require.NoError(t, err)
_, err = httpClient.Do(req)
require.Error(t, err)

cancel()
wg.Wait()
select {
case err := <-errCh:
require.NoError(t, err)
default:
}
cancel()
wg.Wait()
})

t.Run("valid client certs no certs requested", func(t *testing.T) {
t.Skip("test is flakey see fleet-server/issue/3266")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx = testlog.SetLogger(t).WithContext(ctx)
Expand Down Expand Up @@ -323,7 +350,6 @@ key: %s`,

st := NewStatusT(cfg, nil, nil)
srv := NewServer(addr, cfg, nil, nil, nil, nil, st, sm, fbuild.Info{}, nil, nil, nil, nil, nil)
errCh := make(chan error)

// make http client with valid client certs
clientCert := certs.GenCert(t, ca)
Expand All @@ -338,15 +364,30 @@ key: %s`,
},
}

started := make(chan struct{}, 1)
errCh := make(chan error, 1)
var wg sync.WaitGroup
wg.Add(1)
go func() {
if err := srv.Run(ctx); err != nil {
started <- struct{}{}
if err := srv.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
errCh <- err
}
wg.Done()
}()

select {
case <-started:
case <-time.After(500 * time.Millisecond):
require.Fail(t, "timed out waiting for server to start")
}
select {
case err := <-errCh:
require.NoError(t, err, "error during startup")
case <-time.After(500 * time.Millisecond):
break
}

rCtx, rCancel := context.WithTimeout(ctx, time.Second)
defer rCancel()
req, err := http.NewRequestWithContext(rCtx, "GET", "https://"+addr+"/api/status", nil)
Expand All @@ -356,12 +397,12 @@ key: %s`,
resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)

cancel()
wg.Wait()
select {
case err := <-errCh:
require.NoError(t, err)
default:
}
cancel()
wg.Wait()
})
}

0 comments on commit d67d34a

Please sign in to comment.