Skip to content

Commit

Permalink
[Standalone] Disallow upgrade if upgrade is already in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
ycombinator committed Sep 25, 2023
1 parent 0c43005 commit b0a9962
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 0 deletions.
18 changes: 18 additions & 0 deletions internal/pkg/agent/cmd/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/elastic/elastic-agent/pkg/control"
"github.com/elastic/elastic-agent/pkg/control/v2/client"
"github.com/elastic/elastic-agent/pkg/control/v2/cproto"

"github.com/spf13/cobra"

Expand Down Expand Up @@ -64,6 +65,14 @@ func upgradeCmd(streams *cli.IOStreams, cmd *cobra.Command, args []string) error
}
defer c.Disconnect()

isBeingUpgraded, err := isUpgradeInProgress(c)
if err != nil {
return fmt.Errorf("failed to check if upgrade is already in progress: %w", err)
}
if isBeingUpgraded {
return errors.New("an upgrade is already in progress; please try again later.")
}

skipVerification, _ := cmd.Flags().GetBool(flagSkipVerify)
var pgpChecks []string
if !skipVerification {
Expand Down Expand Up @@ -102,3 +111,12 @@ func upgradeCmd(streams *cli.IOStreams, cmd *cobra.Command, args []string) error
fmt.Fprintf(streams.Out, "Upgrade triggered to version %s, Elastic Agent is currently restarting\n", version)
return nil
}

func isUpgradeInProgress(c client.Client) (bool, error) {
state, err := c.State(context.Background())
if err != nil {
return false, fmt.Errorf("failed to get agent state: %w", err)
}

return state.State == cproto.State_UPGRADING, nil
}
96 changes: 96 additions & 0 deletions internal/pkg/agent/cmd/upgrade_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License;
// you may not use this file except in compliance with the Elastic License.

package cmd

import (
"context"
"errors"
"testing"

"github.com/stretchr/testify/require"

"github.com/elastic/elastic-agent/pkg/control/v2/client"

"github.com/elastic/elastic-agent/pkg/control/v2/cproto"
)

type mockClient struct {
stateErr string
state cproto.State
}

func (mc *mockClient) Connect(ctx context.Context) error { return nil }
func (mc *mockClient) Disconnect() {}
func (mc *mockClient) Version(ctx context.Context) (client.Version, error) {
return client.Version{}, nil
}
func (mc *mockClient) State(ctx context.Context) (*client.AgentState, error) {
if mc.stateErr != "" {
return nil, errors.New(mc.stateErr)
}

return &client.AgentState{State: mc.state}, nil
}
func (mc *mockClient) StateWatch(ctx context.Context) (client.ClientStateWatch, error) {
return nil, nil
}
func (mc *mockClient) Restart(ctx context.Context) error { return nil }
func (mc *mockClient) Upgrade(ctx context.Context, version string, sourceURI string, skipVerify bool, skipDefaultPgp bool, pgpBytes ...string) (string, error) {
return "", nil
}
func (mc *mockClient) DiagnosticAgent(ctx context.Context, additionalDiags []client.AdditionalMetrics) ([]client.DiagnosticFileResult, error) {
return nil, nil
}
func (mc *mockClient) DiagnosticUnits(ctx context.Context, units ...client.DiagnosticUnitRequest) ([]client.DiagnosticUnitResult, error) {
return nil, nil
}
func (mc *mockClient) DiagnosticComponents(ctx context.Context, additionalDiags []client.AdditionalMetrics, components ...client.DiagnosticComponentRequest) ([]client.DiagnosticComponentResult, error) {
return nil, nil
}
func (mc *mockClient) Configure(ctx context.Context, config string) error { return nil }

func TestIsUpgradeInProgress(t *testing.T) {
tests := map[string]struct {
state cproto.State
stateErr string

expected bool
expectedErr string
}{
"state_error": {
state: cproto.State_STARTING,
stateErr: "some error",

expected: false,
expectedErr: "failed to get agent state: some error",
},
"state_upgrading": {
state: cproto.State_UPGRADING,
stateErr: "",

expected: true,
expectedErr: "",
},
"state_healthy": {
state: cproto.State_HEALTHY,
stateErr: "",

expected: false,
expectedErr: "",
},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
mc := mockClient{state: test.state, stateErr: test.stateErr}
inProgress, err := isUpgradeInProgress(&mc)
if test.expectedErr != "" {
require.Equal(t, test.expectedErr, err.Error())
} else {
require.Equal(t, test.expected, inProgress)
}
})
}
}

0 comments on commit b0a9962

Please sign in to comment.