diff --git a/.github/workflows/bump-elastic-stack-snapshot.yml b/.github/workflows/bump-elastic-stack-snapshot.yml
index 564e1ef5f..a43be87df 100644
--- a/.github/workflows/bump-elastic-stack-snapshot.yml
+++ b/.github/workflows/bump-elastic-stack-snapshot.yml
@@ -34,5 +34,7 @@ jobs:
           vaultRoleId: ${{ secrets.VAULT_ROLE_ID }}
           vaultSecretId: ${{ secrets.VAULT_SECRET_ID }}
           pipeline: ./.ci/bump-elastic-stack-snapshot.yml
+          notifySlackChannel: "#fleet-notifications"
+          messageIfFailure: ":traffic_cone: updatecli failed for `${{ github.repository }}@${{ github.ref_name }}`, `@fleet_team` please look what's going on <${{ env.JOB_URL }}|here>"
         env:
           BRANCH: ${{ matrix.branch }}
diff --git a/changelog/fragments/1702574888-Drain-HTTP-connections-on-shutdown.yaml b/changelog/fragments/1702574888-Drain-HTTP-connections-on-shutdown.yaml
new file mode 100644
index 000000000..acfc0b9ab
--- /dev/null
+++ b/changelog/fragments/1702574888-Drain-HTTP-connections-on-shutdown.yaml
@@ -0,0 +1,34 @@
+# Kind can be one of:
+# - breaking-change: a change to previously-documented behavior
+# - deprecation: functionality that is being removed in a later release
+# - bug-fix: fixes a problem in a previous version
+# - enhancement: extends functionality but does not break or fix existing behavior
+# - feature: new functionality
+# - known-issue: problems that we are aware of in a given version
+# - security: impacts on the security of a product or a user’s deployment.
+# - upgrade: important information for someone upgrading from a prior version
+# - other: does not fit into any of the other categories
+kind: enhancement
+
+# Change summary; a 80ish characters long description of the change.
+summary: Drain HTTP connections on shutdown
+
+# Long description; in case the summary is not enough to describe the change
+# this field accommodate a description without length limits.
+# NOTE: This field will be rendered only for breaking-change and known-issue kinds at the moment.
+description: |
+  Attempt to safely drain HTTP connections on shutdown by using the http server's Shutdown method.
+  Add a new timeout.Drain config attribute that specifies how long the shutdown will wait (default 10s).
+
+# Affected component; a word indicating the component this changeset affects.
+component:
+
+# PR URL; optional; the PR number that added the changeset.
+# If not present is automatically filled by the tooling finding the PR where this changelog fragment has been added.
+# NOTE: the tooling supports backports, so it's able to fill the original PR number instead of the backport PR number.
+# Please provide it if you are adding a fragment for a different PR.
+pr: 3165
+
+# Issue URL; optional; the GitHub issue related to this changeset (either closes or is part of).
+# If not present is automatically filled by the tooling with the issue linked to the PR number.
+issue: 2902
diff --git a/dev-tools/integration/.env b/dev-tools/integration/.env
index 2f164c804..fe7acfa9e 100644
--- a/dev-tools/integration/.env
+++ b/dev-tools/integration/.env
@@ -1,4 +1,6 @@
-ELASTICSEARCH_VERSION=8.13.0-SNAPSHOT
+# If you change this version to an unpinned one, please update
+# .ci/bump-elastic-stack-snapshot.yml or .github/workflows/bump-golang.yml
+ELASTICSEARCH_VERSION=8.13.0-yufkxnwm-SNAPSHOT
 ELASTICSEARCH_USERNAME=elastic
 ELASTICSEARCH_PASSWORD=changeme
 TEST_ELASTICSEARCH_HOSTS=localhost:9200
diff --git a/fleet-server.reference.yml b/fleet-server.reference.yml
index 67af0bfea..99f5936fc 100644
--- a/fleet-server.reference.yml
+++ b/fleet-server.reference.yml
@@ -122,6 +122,8 @@ fleet:
 #    # checkin_jitter: 30s
 #    # checkin_max_poll is the maximum long_poll value a client can request.
 #    # checkin_max_poll: 1h
+#    # drain is the amount of time fleet-server will wait for HTTP connections to terminate on a shutdown signal before forcing all connections closed
+#    # drain: 10s
 #
 #  # profiler will bind Go's pprof endpoints to a new listener if enabled.
 #  profiler:
diff --git a/internal/pkg/api/handleCheckin.go b/internal/pkg/api/handleCheckin.go
index b3f2b47ab..7d4685535 100644
--- a/internal/pkg/api/handleCheckin.go
+++ b/internal/pkg/api/handleCheckin.go
@@ -319,7 +319,16 @@ func (ct *CheckinT) ProcessRequest(zlog zerolog.Logger, w http.ResponseWriter, r
 	for {
 		select {
 		case <-ctx.Done():
-			span.End()
+			defer span.End()
+			// If the request context is canceled, the API server is shutting down.
+			// We want to immediately stop the long-poll and return a 200 with the ackToken and no actions.
+			if errors.Is(ctx.Err(), context.Canceled) {
+				resp := CheckinResponse{
+					AckToken: &ackToken,
+					Action:   "checkin",
+				}
+				return ct.writeResponse(zlog, w, r, agent, resp)
+			}
 			return ctx.Err()
 		case acdocs := <-actCh:
 			var acs []Action
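
The handleCheckin.go hunk above is the drain behaviour from the agent's point of view: an in-flight long-poll is ended early with a 200 and the same ack token, so nothing is lost and the agent can simply check in again. A minimal, self-contained sketch of that long-poll pattern (illustrative names only, not the fleet-server types):

```go
package main

import (
	"encoding/json"
	"net/http"
	"time"
)

// checkinResponse is an illustrative stand-in for the real API type.
type checkinResponse struct {
	AckToken string   `json:"ack_token"`
	Action   string   `json:"action"`
	Actions  []string `json:"actions,omitempty"`
}

// longPollHandler waits for new actions until the poll window expires or the
// request context is canceled (e.g. the server is draining connections).
func longPollHandler(actions <-chan []string, pollTimeout time.Duration) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		timer := time.NewTimer(pollTimeout)
		defer timer.Stop()

		resp := checkinResponse{AckToken: r.URL.Query().Get("ack_token"), Action: "checkin"}
		select {
		case acs := <-actions:
			resp.Actions = acs
		case <-timer.C:
			// Poll window elapsed with no changes; return the same ack token.
		case <-r.Context().Done():
			// Server shutdown (or client disconnect): stop the poll now and
			// return the same ack token so the agent loses nothing.
		}
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(resp)
	}
}

func main() {
	actions := make(chan []string)
	http.Handle("/checkin", longPollHandler(actions, 2*time.Minute))
	_ = http.ListenAndServe("localhost:8240", nil) // illustrative address
}
```
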
diff --git a/internal/pkg/api/handleStatus.go b/internal/pkg/api/handleStatus.go
index 5875dfe4b..b17711aeb 100644
--- a/internal/pkg/api/handleStatus.go
+++ b/internal/pkg/api/handleStatus.go
@@ -91,6 +91,12 @@ func (st StatusT) handleStatus(zlog zerolog.Logger, sm policy.SelfMonitor, bi bu
 	span, _ = apm.StartSpan(r.Context(), "response", "write")
 	defer span.End()
 
+	// If the request context has been cancelled, such as when the server is stopping, we should return a 503.
+	// Note that the API server uses Shutdown, so no new requests should be accepted and this edge case will be rare.
+	if errors.Is(r.Context().Err(), context.Canceled) {
+		state = client.UnitStateStopping
+	}
+
 	data, err := json.Marshal(&resp)
 	if err != nil {
 		return err
diff --git a/internal/pkg/api/server.go b/internal/pkg/api/server.go
index 4b3e4e9da..2ec1efabc 100644
--- a/internal/pkg/api/server.go
+++ b/internal/pkg/api/server.go
@@ -8,6 +8,7 @@ import (
 	"context"
 	"crypto/tls"
 	"errors"
+	"fmt"
 	slog "log"
 	"net"
 	"net/http"
@@ -75,25 +76,7 @@ func (s *server) Run(ctx context.Context) error {
 		ConnState: diagConn,
 	}
 
-	forceCh := make(chan struct{})
-	defer close(forceCh)
-
-	// handler to close server
-	go func() {
-		select {
-		case <-ctx.Done():
-			zerolog.Ctx(ctx).Debug().Msg("force server close on ctx.Done()")
-			err := srv.Close()
-			if err != nil {
-				zerolog.Ctx(ctx).Error().Err(err).Msg("error while closing server")
-			}
-		case <-forceCh:
-			zerolog.Ctx(ctx).Debug().Msg("go routine forced closed on exit")
-		}
-	}()
-
 	var listenCfg net.ListenConfig
-
 	ln, err := listenCfg.Listen(ctx, "tcp", s.addr)
 	if err != nil {
 		return err
@@ -130,23 +113,29 @@ func (s *server) Run(ctx context.Context) error {
 		zerolog.Ctx(ctx).Warn().Msg("Exposed over insecure HTTP; enablement of TLS is strongly recommended")
 	}
 
+	// Start the API server on another goroutine and return any non-ErrServerClosed errors through a channel.
 	errCh := make(chan error)
-	baseCtx, cancel := context.WithCancel(ctx)
-	defer cancel()
-
 	go func(ctx context.Context, errCh chan error, ln net.Listener) {
 		zerolog.Ctx(ctx).Info().Msgf("Listening on %s", s.addr)
 		if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
 			errCh <- err
 		}
-	}(baseCtx, errCh, ln)
+	}(ctx, errCh, ln)
 
 	select {
+	// Listen and return any errors that occur from the server listener
 	case err := <-errCh:
 		if !errors.Is(err, context.Canceled) {
-			return err
+			return fmt.Errorf("error while serving API listener: %w", err)
+		}
+	// Do a clean shutdown if the context is cancelled
+	case <-ctx.Done():
+		sCtx, cancel := context.WithTimeout(context.Background(), s.cfg.Timeouts.Drain)
+		defer cancel()
+		if err := srv.Shutdown(sCtx); err != nil {
+			cErr := srv.Close() // force it closed
+			return errors.Join(fmt.Errorf("error while shutting down api listener: %w", err), cErr)
 		}
-	case <-baseCtx.Done():
 	}
 
 	return nil
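
The Run rewrite above is essentially the standard net/http drain pattern: serve on a goroutine, and on context cancellation give Shutdown a bounded window before forcing Close. A stripped-down sketch of that pattern, independent of fleet-server's own types (the BaseContext wiring here is an assumption for the sketch; the real srv is configured just above the hunk shown):

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/http"
	"time"
)

// runServer serves until ctx is canceled, then drains connections for up to
// the drain timeout before forcing any remaining ones closed.
func runServer(ctx context.Context, addr string, handler http.Handler, drain time.Duration) error {
	srv := &http.Server{
		Addr:    addr,
		Handler: handler,
		// Assumed wiring: request contexts derive from ctx, so in-flight
		// long-polls observe the shutdown signal and can finish early.
		BaseContext: func(net.Listener) context.Context { return ctx },
	}

	errCh := make(chan error, 1)
	go func() {
		if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			errCh <- err
		}
	}()

	select {
	case err := <-errCh:
		return fmt.Errorf("error while serving listener: %w", err)
	case <-ctx.Done():
		sCtx, cancel := context.WithTimeout(context.Background(), drain)
		defer cancel()
		if err := srv.Shutdown(sCtx); err != nil {
			// Drain window elapsed with connections still open; force them closed.
			cErr := srv.Close()
			return errors.Join(fmt.Errorf("error while shutting down listener: %w", err), cErr)
		}
	}
	return nil
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	_ = runServer(ctx, "localhost:8240", http.DefaultServeMux, 10*time.Second) // illustrative address and drain value
}
```
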
diff --git a/internal/pkg/checkin/bulk.go b/internal/pkg/checkin/bulk.go
index cab3caf3f..c097e3e4b 100644
--- a/internal/pkg/checkin/bulk.go
+++ b/internal/pkg/checkin/bulk.go
@@ -99,6 +99,7 @@ func (bc *Bulk) timestamp() string {
 
 // CheckIn will add the agent (identified by id) to the pending set.
 // The pending agents are sent to elasticsearch as a bulk update at each flush interval.
+// NOTE: If CheckIn is called after Run has returned, it will just add the entry to the pending map and not perform any operations; this may occur when fleet-server is shutting down.
 // WARNING: Bulk will take ownership of fields, so do not use after passing in.
 func (bc *Bulk) CheckIn(id string, status string, message string, meta []byte, components []byte, seqno sqn.SeqNo, newVer string) error {
 	// Separate out the extra data to minimize
diff --git a/internal/pkg/config/config_test.go b/internal/pkg/config/config_test.go
index 4e4552d57..c969a697f 100644
--- a/internal/pkg/config/config_test.go
+++ b/internal/pkg/config/config_test.go
@@ -119,6 +119,7 @@ func TestConfig(t *testing.T) {
 					CheckinLongPoll: 5 * time.Minute,
 					CheckinJitter:   30 * time.Second,
 					CheckinMaxPoll:  10 * time.Minute,
+					Drain:           10 * time.Second,
 				},
 				Profiler: ServerProfiler{
 					Enabled: false,
diff --git a/internal/pkg/config/timeouts.go b/internal/pkg/config/timeouts.go
index faf5860c9..86d75f7d5 100644
--- a/internal/pkg/config/timeouts.go
+++ b/internal/pkg/config/timeouts.go
@@ -18,6 +18,7 @@ type ServerTimeouts struct {
 	CheckinLongPoll time.Duration `config:"checkin_long_poll"`
 	CheckinJitter   time.Duration `config:"checkin_jitter"`
 	CheckinMaxPoll  time.Duration `config:"checkin_max_poll"`
+	Drain           time.Duration `config:"drain"`
 }
 
 // InitDefaults initializes the defaults for the configuration.
@@ -64,4 +65,9 @@ func (c *ServerTimeouts) InitDefaults() {
 	// The long poll value is poll_timeout-2m, and the request's write timeout is set to poll_timeout-1m
 	// CheckinMaxPoll values of less then 1m are effectively ignored and a 1m limit is used.
 	c.CheckinMaxPoll = time.Hour
+
+	// Drain is the max duration that a server will keep connections open when a shutdown signal is received in order to gracefully handle in-progress requests.
+	// It is used as a context timeout value for server.Shutdown(ctx).
+	// A long-poll checkin connection should immediately return with a 200 status and the same ackToken it was sent, the same as if the long-poll completed with no changes detected.
+	c.Drain = 10 * time.Second
 }
diff --git a/internal/pkg/profile/profile.go b/internal/pkg/profile/profile.go
index 931859480..1624d5159 100644
--- a/internal/pkg/profile/profile.go
+++ b/internal/pkg/profile/profile.go
@@ -7,6 +7,8 @@ package profile
 
 import (
 	"context"
+	"errors"
+	"fmt"
 	"net"
 	"net/http"
 	"net/http/pprof"
@@ -17,7 +19,6 @@ import (
 
 // RunProfiler exposes /debug/pprof on the passed address by staring a server.
 func RunProfiler(ctx context.Context, addr string) error {
-
 	if addr == "" {
 		zerolog.Ctx(ctx).Info().Msg("Profiler disabled")
 		return nil
@@ -47,10 +48,24 @@ func RunProfiler(ctx context.Context, addr string) error {
 	}
 
 	zerolog.Ctx(ctx).Info().Str("bind", addr).Msg("Installing profiler")
-	if err := server.ListenAndServe(); err != nil {
+	errCh := make(chan error)
+	go func() {
+		if err := server.ListenAndServe(); err != nil {
+			errCh <- err
+		}
+	}()
+
+	select {
+	case err := <-errCh:
 		zerolog.Ctx(ctx).Error().Err(err).Str("bind", addr).Msg("Fail install profiler")
 		return err
+	case <-ctx.Done():
+		sCtx, cancel := context.WithTimeout(context.Background(), cfg.Drain)
+		defer cancel()
+		if err := server.Shutdown(sCtx); err != nil {
+			cErr := server.Close() // force it closed
+			return errors.Join(fmt.Errorf("error while shutting down profile listener: %w", err), cErr)
+		}
 	}
-
 	return nil
 }
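
The profile.go hunks only show the shutdown handling; the `server` being drained is the pprof listener RunProfiler builds earlier in the function (not shown here). Assuming the explicit handler registration implied by the non-blank `net/http/pprof` import, that wiring looks roughly like this sketch:

```go
package main

import (
	"net/http"
	"net/http/pprof"
	"time"
)

// newProfilerMux registers the standard pprof handlers on a dedicated mux so
// the profiler can live on its own listener instead of http.DefaultServeMux.
func newProfilerMux() *http.ServeMux {
	mux := http.NewServeMux()
	mux.HandleFunc("/debug/pprof/", pprof.Index)
	mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
	mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
	mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
	mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
	return mux
}

func main() {
	// Illustrative address only; RunProfiler binds whatever addr it is given.
	srv := &http.Server{
		Addr:              "localhost:6060",
		Handler:           newProfilerMux(),
		ReadHeaderTimeout: 5 * time.Second,
	}
	_ = srv.ListenAndServe()
}
```
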
diff --git a/internal/pkg/profile/profile_test.go b/internal/pkg/profile/profile_test.go
new file mode 100644
index 000000000..45b900528
--- /dev/null
+++ b/internal/pkg/profile/profile_test.go
@@ -0,0 +1,56 @@
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+// or more contributor license agreements. Licensed under the Elastic License;
+// you may not use this file except in compliance with the Elastic License.
+
+package profile
+
+import (
+	"context"
+	"net"
+	"net/http"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestRunProfiler(t *testing.T) {
+	ln, err := net.Listen("tcp", "localhost:8081")
+	if err != nil {
+		t.Skip("Port 8081 must be free to run this test")
+	}
+	ln.Close()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	errCh := make(chan error)
+
+	go func() {
+		errCh <- RunProfiler(ctx, "localhost:8081")
+	}()
+
+	req, err := http.NewRequestWithContext(ctx, "GET", "http://localhost:8081/debug/pprof", nil)
+	require.NoError(t, err)
+
+	var resp *http.Response
+	for i := 0; i < 10; i++ {
+		resp, err = http.DefaultClient.Do(req) //nolint:bodyclose // closed outside the loop
+		if err == nil {
+			break
+		}
+		t.Logf("profile request %d failed with: %v, retrying...", i, err)
+		time.Sleep(time.Millisecond * 200)
+	}
+	require.NoError(t, err)
+	defer resp.Body.Close()
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+
+	cancel()
+
+	select {
+	case err := <-errCh:
+		require.NoError(t, err)
+	default:
+	}
+}
diff --git a/internal/pkg/server/fleet.go b/internal/pkg/server/fleet.go
index 1fb239a9a..5a0a9f3d5 100644
--- a/internal/pkg/server/fleet.go
+++ b/internal/pkg/server/fleet.go
@@ -207,7 +207,7 @@ LOOP:
 
 	// Server is coming down; wait for the server group to exit cleanly.
 	// Timeout if something is locked up.
-	err = safeWait(srvEg, time.Second)
+	err = safeWait(srvEg, curCfg.Inputs[0].Server.Timeouts.Drain)
 
 	// Eat cancel error to minimize confusion in logs
 	if errors.Is(err, context.Canceled) {
@@ -506,10 +506,10 @@ func (f *Fleet) runSubsystems(ctx context.Context, cfg *config.Config, g *errgro
 	if err != nil {
 		return err
 	}
-	g.Go(loggedRunFunc(ctx, "Revision monitor", am.Run))
+	g.Go(loggedRunFunc(ctx, "Action monitor", am.Run))
 
 	ad = action.NewDispatcher(am, cfg.Inputs[0].Server.Limits.ActionLimit.Interval, cfg.Inputs[0].Server.Limits.ActionLimit.Burst)
-	g.Go(loggedRunFunc(ctx, "Revision dispatcher", ad.Run))
+	g.Go(loggedRunFunc(ctx, "Action dispatcher", ad.Run))
 	tr, err = action.NewTokenResolver(bulker)
 	if err != nil {
 		return err
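
In fleet.go the same drain window now also bounds how long the coordinator waits for the server group (srvEg, presumably an errgroup) to exit. safeWait is defined elsewhere in fleet-server; conceptually it is a wait-with-timeout, along the lines of this hedged sketch:

```go
package main

import (
	"errors"
	"fmt"
	"time"

	"golang.org/x/sync/errgroup"
)

// waitWithTimeout waits for all goroutines in the group to finish, but gives
// up after the supplied duration so a stuck subsystem cannot block shutdown.
func waitWithTimeout(g *errgroup.Group, timeout time.Duration) error {
	done := make(chan error, 1)
	go func() { done <- g.Wait() }()

	select {
	case err := <-done:
		return err
	case <-time.After(timeout):
		return errors.New("timed out waiting for server group to exit")
	}
}

func main() {
	g := new(errgroup.Group)
	g.Go(func() error {
		time.Sleep(500 * time.Millisecond) // a subsystem finishing its drain
		return nil
	})
	fmt.Println(waitWithTimeout(g, 10*time.Second)) // prints <nil>
}
```
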
diff --git a/internal/pkg/server/fleet_integration_test.go b/internal/pkg/server/fleet_integration_test.go
index 74f088f74..de14688d2 100644
--- a/internal/pkg/server/fleet_integration_test.go
+++ b/internal/pkg/server/fleet_integration_test.go
@@ -1168,3 +1168,116 @@ func Test_SmokeTest_CheckinPollTimeout(t *testing.T) {
 	require.LessOrEqual(t, dur, 3*time.Minute) // include write timeout
 	require.GreaterOrEqual(t, dur, time.Minute)
 }
+
+func Test_SmokeTest_CheckinPollShutdown(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	// Start test server
+	srv, err := startTestServer(t, ctx, policyData)
+	require.NoError(t, err)
+	ctx = testlog.SetLogger(t).WithContext(ctx)
+
+	cli := cleanhttp.DefaultClient()
+
+	// enroll an agent
+	t.Log("Enroll an agent")
+	req, err := http.NewRequestWithContext(ctx, "POST", srv.baseURL()+"/api/fleet/agents/enroll", strings.NewReader(enrollBody))
+	require.NoError(t, err)
+	req.Header.Set("Authorization", "ApiKey "+srv.enrollKey)
+	req.Header.Set("User-Agent", "elastic agent "+serverVersion)
+	req.Header.Set("Content-Type", "application/json")
+	res, err := cli.Do(req)
+	require.NoError(t, err)
+
+	require.Equal(t, http.StatusOK, res.StatusCode)
+	dec := json.NewDecoder(res.Body)
+	var enrollResponse api.EnrollResponse
+	err = dec.Decode(&enrollResponse)
+	res.Body.Close()
+	require.NoError(t, err)
+	agentID := enrollResponse.Item.Id
+	apiKey := enrollResponse.Item.AccessApiKey
+
+	// checkin
+	t.Logf("checkin 1: agent %s no poll_timeout", agentID)
+	req, err = http.NewRequestWithContext(ctx, "POST", srv.baseURL()+"/api/fleet/agents/"+agentID+"/checkin", strings.NewReader(checkinBody))
+	require.NoError(t, err)
+	req.Header.Set("Authorization", "ApiKey "+apiKey)
+	req.Header.Set("User-Agent", "elastic agent "+serverVersion)
+	req.Header.Set("Content-Type", "application/json")
+	start := time.Now()
+	res, err = cli.Do(req)
+	require.NoError(t, err)
+	t.Logf("checkin 1: agent %s took %s", agentID, time.Since(start))
+
+	require.Equal(t, http.StatusOK, res.StatusCode)
+	var checkinResponse api.CheckinResponse
+	dec = json.NewDecoder(res.Body)
+	err = dec.Decode(&checkinResponse)
+	res.Body.Close()
+	require.NoError(t, err)
+
+	t.Logf("Ack actions for agent %s", agentID)
+	events := make([]api.AckRequest_Events_Item, 0, len(*checkinResponse.Actions))
+	for _, action := range *checkinResponse.Actions {
+		event := api.GenericEvent{
+			ActionId: action.Id,
+			AgentId:  agentID,
+			Message:  "test-message",
+			Type:     api.ACTIONRESULT,
+			Subtype:  api.ACKNOWLEDGED,
+		}
+		ev := api.AckRequest_Events_Item{}
+		err := ev.FromGenericEvent(event)
+		require.NoError(t, err)
+		events = append(events, ev)
+	}
+	p, err := json.Marshal(api.AckRequest{Events: events})
+	require.NoError(t, err)
+	req, err = http.NewRequestWithContext(ctx, "POST", srv.baseURL()+"/api/fleet/agents/"+agentID+"/acks", bytes.NewBuffer(p))
+	require.NoError(t, err)
+	req.Header.Set("Authorization", "ApiKey "+apiKey)
+	req.Header.Set("User-Agent", "elastic agent "+serverVersion)
+	req.Header.Set("Content-Type", "application/json")
+	res, err = cli.Do(req)
+	require.NoError(t, err)
+	res.Body.Close()
+	require.Equal(t, http.StatusOK, res.StatusCode)
+
+	t.Logf("checkin 2: agent %s poll_timeout 3m server will shutdown after 10s", agentID)
+	//nolint:noctx // we want to halt the request via the server context cancelation
+	req, err = http.NewRequest("POST", srv.baseURL()+"/api/fleet/agents/"+agentID+"/checkin", strings.NewReader(fmt.Sprintf(`{
+	    "ack_token": "%s",
+	    "status": "online",
+	    "message": "",
+	    "poll_timeout": "3m"
+	}`, *checkinResponse.AckToken)))
+	require.NoError(t, err)
+	req.Header.Set("Authorization", "ApiKey "+apiKey)
+	req.Header.Set("User-Agent", "elastic agent "+serverVersion)
+	req.Header.Set("Content-Type", "application/json")
+	start = time.Now()
+
+	go func() {
+		time.Sleep(time.Second * 10)
+		t.Log("Shutting down server")
+		cancel()
+	}()
+	res, err = cli.Do(req)
+	require.NoError(t, err)
+	dur := time.Since(start)
+	t.Logf("checkin 2: agent %s took %s", agentID, time.Since(start))
+	p, err = io.ReadAll(res.Body)
+	res.Body.Close()
+	require.NoError(t, err)
+	t.Logf("Response body: %s", string(p))
+	t.Logf("Request duration: %s", dur)
+	require.Equal(t, http.StatusOK, res.StatusCode)
+	require.LessOrEqual(t, dur, 2*time.Minute)
+	require.GreaterOrEqual(t, dur, time.Second*10)
+	token := *checkinResponse.AckToken
+	err = json.Unmarshal(p, &checkinResponse)
+	require.NoError(t, err)
+	require.Equal(t, token, *checkinResponse.AckToken)
+}