diff --git a/internal/pkg/api/handleCheckin.go b/internal/pkg/api/handleCheckin.go
index 8512a637c1..5ad76b8fec 100644
--- a/internal/pkg/api/handleCheckin.go
+++ b/internal/pkg/api/handleCheckin.go
@@ -302,7 +302,16 @@ func (ct *CheckinT) ProcessRequest(zlog zerolog.Logger, w http.ResponseWriter, r
 	for {
 		select {
 		case <-ctx.Done():
-			span.End()
+			defer span.End()
+			// If the request context is canceled, the API server is shutting down.
+			// We want to immediately stop the long-poll and return a 200 with the ackToken and no actions.
+			if errors.Is(ctx.Err(), context.Canceled) {
+				resp := CheckinResponse{
+					AckToken: &ackToken,
+					Action:   "checkin",
+				}
+				return ct.writeResponse(zlog, w, r, agent, resp)
+			}
 			return ctx.Err()
 		case acdocs := <-actCh:
 			var acs []Action
diff --git a/internal/pkg/api/handleStatus.go b/internal/pkg/api/handleStatus.go
index 5875dfe4be..b17711aeb1 100644
--- a/internal/pkg/api/handleStatus.go
+++ b/internal/pkg/api/handleStatus.go
@@ -91,6 +91,12 @@ func (st StatusT) handleStatus(zlog zerolog.Logger, sm policy.SelfMonitor, bi bu
 	span, _ = apm.StartSpan(r.Context(), "response", "write")
 	defer span.End()
 
+	// If the request context has been cancelled, such as when the server is stopping, we should return a 503.
+	// Note that the API server uses Shutdown, so no new requests should be accepted and this edge case will be rare.
+	if errors.Is(r.Context().Err(), context.Canceled) {
+		state = client.UnitStateStopping
+	}
+
 	data, err := json.Marshal(&resp)
 	if err != nil {
 		return err
diff --git a/internal/pkg/api/server.go b/internal/pkg/api/server.go
index 4b3e4e9da5..f2bf4d7939 100644
--- a/internal/pkg/api/server.go
+++ b/internal/pkg/api/server.go
@@ -8,6 +8,7 @@ import (
 	"context"
 	"crypto/tls"
 	"errors"
+	"fmt"
 	slog "log"
 	"net"
 	"net/http"
@@ -75,25 +76,7 @@ func (s *server) Run(ctx context.Context) error {
 		ConnState: diagConn,
 	}
 
-	forceCh := make(chan struct{})
-	defer close(forceCh)
-
-	// handler to close server
-	go func() {
-		select {
-		case <-ctx.Done():
-			zerolog.Ctx(ctx).Debug().Msg("force server close on ctx.Done()")
-			err := srv.Close()
-			if err != nil {
-				zerolog.Ctx(ctx).Error().Err(err).Msg("error while closing server")
-			}
-		case <-forceCh:
-			zerolog.Ctx(ctx).Debug().Msg("go routine forced closed on exit")
-		}
-	}()
-
 	var listenCfg net.ListenConfig
-
 	ln, err := listenCfg.Listen(ctx, "tcp", s.addr)
 	if err != nil {
 		return err
@@ -130,23 +113,27 @@ func (s *server) Run(ctx context.Context) error {
 		zerolog.Ctx(ctx).Warn().Msg("Exposed over insecure HTTP; enablement of TLS is strongly recommended")
 	}
 
+	// Start the API server on another goroutine and return any non-ErrServerClosed errors through a channel.
 	errCh := make(chan error)
-	baseCtx, cancel := context.WithCancel(ctx)
-	defer cancel()
-
 	go func(ctx context.Context, errCh chan error, ln net.Listener) {
 		zerolog.Ctx(ctx).Info().Msgf("Listening on %s", s.addr)
 		if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
 			errCh <- err
 		}
-	}(baseCtx, errCh, ln)
+	}(ctx, errCh, ln)
 
 	select {
+	// Listen and return any errors that occur from the server listener.
	case err := <-errCh:
 		if !errors.Is(err, context.Canceled) {
-			return err
+			return fmt.Errorf("error while serving API listener: %w", err)
+		}
+	// Do a clean shutdown if the context is cancelled.
+	// Note that it will wait for connections to close; we may want to use a new timeout value here.
+	case <-ctx.Done():
+		if err := srv.Shutdown(context.TODO()); err != nil {
+			return fmt.Errorf("error while shutting down API listener: %w", err)
 		}
-	case <-baseCtx.Done():
 	}
 
 	return nil
diff --git a/internal/pkg/checkin/bulk.go b/internal/pkg/checkin/bulk.go
index cab3caf3f2..c097e3e4ba 100644
--- a/internal/pkg/checkin/bulk.go
+++ b/internal/pkg/checkin/bulk.go
@@ -99,6 +99,7 @@ func (bc *Bulk) timestamp() string {
 
 // CheckIn will add the agent (identified by id) to the pending set.
 // The pending agents are sent to elasticsearch as a bulk update at each flush interval.
+// NOTE: If CheckIn is called after Run has returned, it will just add the entry to the pending map and not do any operations; this may occur when the fleet-server is shutting down.
 // WARNING: Bulk will take ownership of fields, so do not use after passing in.
 func (bc *Bulk) CheckIn(id string, status string, message string, meta []byte, components []byte, seqno sqn.SeqNo, newVer string) error {
 	// Separate out the extra data to minimize
diff --git a/internal/pkg/server/fleet.go b/internal/pkg/server/fleet.go
index 95b08b9b7e..8100bb681a 100644
--- a/internal/pkg/server/fleet.go
+++ b/internal/pkg/server/fleet.go
@@ -205,6 +205,8 @@ LOOP:
 		}
 	}
 
+	// FIXME: cancelling the context will break out of the above loop and wait 1s; we should instead have a cancel context or some other way to ensure that the API servers have shut down correctly.
+
 	// Server is coming down; wait for the server group to exit cleanly.
 	// Timeout if something is locked up.
 	err = safeWait(srvEg, time.Second)
@@ -376,6 +378,7 @@ func (f *Fleet) runServer(ctx context.Context, cfg *config.Config) (err error) {
 	// unexpectedly (ie. not cancelled by the bulkCancel context).
 	errCh := make(chan error)
+	// FIXME: why not run this in the errgroup?
 	go func() {
 		runFunc := loggedRunFunc(bulkCtx, "Bulker", bulker.Run)
@@ -505,10 +508,10 @@ func (f *Fleet) runSubsystems(ctx context.Context, cfg *config.Config, g *errgro
 	if err != nil {
 		return err
 	}
-	g.Go(loggedRunFunc(ctx, "Revision monitor", am.Run))
+	g.Go(loggedRunFunc(ctx, "Action monitor", am.Run))
 
 	ad = action.NewDispatcher(am, cfg.Inputs[0].Server.Limits.ActionLimit.Interval, cfg.Inputs[0].Server.Limits.ActionLimit.Burst)
-	g.Go(loggedRunFunc(ctx, "Revision dispatcher", ad.Run))
+	g.Go(loggedRunFunc(ctx, "Action dispatcher", ad.Run))
 
 	tr, err = action.NewTokenResolver(bulker)
 	if err != nil {
 		return err
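
The `context.TODO()` passed to `srv.Shutdown` above waits indefinitely for in-flight requests (including checkin long-polls) to drain, which is what the "new timeout value" note refers to. Below is a minimal sketch of the bounded variant, assuming an illustrative 5-second deadline and a hypothetical `shutdownWithTimeout` helper (neither is part of this change). The deadline context must derive from `context.Background()`: deriving it from the already-cancelled run context would make Shutdown return immediately without draining anything.

```go
package main

import (
	"context"
	"fmt"
	"net/http"
	"time"
)

// shutdownWithTimeout bounds srv.Shutdown with its own deadline so a hung
// connection cannot block process exit forever. Shutdown stops accepting new
// connections, then waits for active requests to complete or for the deadline
// to expire, whichever comes first.
func shutdownWithTimeout(srv *http.Server, timeout time.Duration) error {
	// Use Background, not the (already cancelled) run context, or Shutdown
	// would return before any connections drain.
	sctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	if err := srv.Shutdown(sctx); err != nil {
		return fmt.Errorf("error while shutting down API listener: %w", err)
	}
	return nil
}

func main() {
	srv := &http.Server{Addr: ":8220"} // illustrative fleet-server default port
	go func() { _ = srv.ListenAndServe() }()
	if err := shutdownWithTimeout(srv, 5*time.Second); err != nil {
		fmt.Println(err)
	}
}
```

If the timeout fires before the drain completes, Shutdown returns context.DeadlineExceeded; the caller can then fall back to srv.Close() to force-drop the remaining connections, which is roughly the behavior the removed forceCh goroutine provided unconditionally.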