From b082669449eb58e73f3003e67ba7b2640b790468 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Wei=C3=9Fe?= Date: Thu, 19 Oct 2023 10:15:51 +0200 Subject: [PATCH] Rework skipPhases logic MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Daniel Weiße --- cli/internal/cmd/apply.go | 15 ++++++++------- cli/internal/cmd/apply_test.go | 2 +- cli/internal/cmd/upgradeapply.go | 18 ++++++++++++------ cli/internal/cmd/upgradeapply_test.go | 18 +++++++++++++----- 4 files changed, 34 insertions(+), 19 deletions(-) diff --git a/cli/internal/cmd/apply.go b/cli/internal/cmd/apply.go index 238b801d3c..937a5a1e1a 100644 --- a/cli/internal/cmd/apply.go +++ b/cli/internal/cmd/apply.go @@ -16,6 +16,7 @@ import ( "net" "os" "path/filepath" + "strings" "time" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" @@ -60,11 +61,11 @@ func (f *applyFlags) parse(flags *pflag.FlagSet) error { if err != nil { return fmt.Errorf("getting 'skip-phases' flag: %w", err) } - var skipPhases []skipPhase + var skipPhases skipPhases for _, phase := range rawSkipPhases { - switch skipPhase(phase) { + switch skipPhase(strings.ToLower(phase)) { case skipInfrastructurePhase, skipHelmPhase, skipImagePhase, skipK8sPhase: - skipPhases = append(skipPhases, skipPhase(phase)) + skipPhases.add(skipPhase(phase)) default: return fmt.Errorf("invalid phase %s", phase) } @@ -350,7 +351,7 @@ func (a *applyCmd) validateInputs(cmd *cobra.Command, configFetcher attestationc // It is the user's responsibility to make sure the cluster is in a valid state a.log.Debugf("Checking if %s exists", a.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename)) if _, err := a.fileHandler.Stat(constants.AdminConfFilename); err == nil { - a.flags.skipPhases = append(a.flags.skipPhases, skipInitPhase) + a.flags.skipPhases.add(skipInitPhase) } else if !errors.Is(err, os.ErrNotExist) { return nil, nil, fmt.Errorf("checking for %s: %w", a.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename), err) } @@ -384,7 +385,7 @@ func (a *applyCmd) validateInputs(cmd *cobra.Command, configFetcher attestationc return nil, nil, fmt.Errorf("aborted by user") } } - a.flags.skipPhases = append(a.flags.skipPhases, skipK8sPhase) + a.flags.skipPhases.add(skipK8sPhase) a.log.Debugf("Outdated Kubernetes version accepted, Kubernetes upgrade will be skipped") } if versions.IsPreviewK8sVersion(validVersion) { @@ -409,14 +410,14 @@ func (a *applyCmd) validateInputs(cmd *cobra.Command, configFetcher attestationc } // Skip Terraform phase a.log.Debugf("Skipping Infrastructure upgrade") - a.flags.skipPhases = append(a.flags.skipPhases, skipInfrastructurePhase) + a.flags.skipPhases.add(skipInfrastructurePhase) } // Check if Terraform state exists if tfStateExists, err := a.tfStateExists(); err != nil { return nil, nil, fmt.Errorf("checking Terraform state: %w", err) } else if !tfStateExists { - a.flags.skipPhases = append(a.flags.skipPhases, skipInfrastructurePhase) + a.flags.skipPhases.add(skipInfrastructurePhase) a.log.Debugf("No Terraform state found in current working directory. Assuming self-managed infrastructure. Infrastructure upgrades will not be performed.") } diff --git a/cli/internal/cmd/apply_test.go b/cli/internal/cmd/apply_test.go index 65953ca63e..9218cd3cae 100644 --- a/cli/internal/cmd/apply_test.go +++ b/cli/internal/cmd/apply_test.go @@ -56,7 +56,7 @@ func TestParseApplyFlags(t *testing.T) { return flags }(), wantFlags: applyFlags{ - skipPhases: []skipPhase{skipHelmPhase, skipK8sPhase}, + skipPhases: skipPhases{skipHelmPhase: struct{}{}, skipK8sPhase: struct{}{}}, helmWaitMode: helm.WaitModeAtomic, }, }, diff --git a/cli/internal/cmd/upgradeapply.go b/cli/internal/cmd/upgradeapply.go index 9f98f8da2f..b9e9faf21c 100644 --- a/cli/internal/cmd/upgradeapply.go +++ b/cli/internal/cmd/upgradeapply.go @@ -82,16 +82,22 @@ func diffAttestationCfg(currentAttestationCfg config.AttestationCfg, newAttestat } // skipPhases is a list of phases that can be skipped during the upgrade process. -type skipPhases []skipPhase +type skipPhases map[skipPhase]struct{} // contains returns true if the list of phases contains the given phase. func (s skipPhases) contains(phase skipPhase) bool { - for _, p := range s { - if strings.EqualFold(string(p), string(phase)) { - return true - } + _, ok := s[skipPhase(strings.ToLower(string(phase)))] + return ok +} + +// add a phase to the list of phases. +func (s *skipPhases) add(phases ...skipPhase) { + if *s == nil { + *s = make(skipPhases) + } + for _, phase := range phases { + (*s)[skipPhase(strings.ToLower(string(phase)))] = struct{}{} } - return false } type kubernetesUpgrader interface { diff --git a/cli/internal/cmd/upgradeapply_test.go b/cli/internal/cmd/upgradeapply_test.go index 4e8ecf59f2..227c141874 100644 --- a/cli/internal/cmd/upgradeapply_test.go +++ b/cli/internal/cmd/upgradeapply_test.go @@ -188,8 +188,11 @@ func TestUpgradeApply(t *testing.T) { helmUpgrader: &mockApplier{}, // mocks ensure that no methods are called terraformUpgrader: &mockTerraformUpgrader{}, flags: applyFlags{ - skipPhases: []skipPhase{skipInfrastructurePhase, skipHelmPhase, skipK8sPhase, skipImagePhase}, - yes: true, + skipPhases: skipPhases{ + skipInfrastructurePhase: struct{}{}, skipHelmPhase: struct{}{}, + skipK8sPhase: struct{}{}, skipImagePhase: struct{}{}, + }, + yes: true, }, fh: fsWithStateFileAndTfState, }, @@ -200,8 +203,11 @@ func TestUpgradeApply(t *testing.T) { helmUpgrader: &mockApplier{}, // mocks ensure that no methods are called terraformUpgrader: &mockTerraformUpgrader{}, flags: applyFlags{ - skipPhases: []skipPhase{skipInfrastructurePhase, skipHelmPhase, skipK8sPhase}, - yes: true, + skipPhases: skipPhases{ + skipInfrastructurePhase: struct{}{}, skipHelmPhase: struct{}{}, + skipK8sPhase: struct{}{}, + }, + yes: true, }, fh: fsWithStateFileAndTfState, }, @@ -288,11 +294,13 @@ func TestUpgradeApplyFlagsForSkipPhases(t *testing.T) { cmd.Flags().Bool("merge-kubeconfig", false, "") require.NoError(cmd.Flags().Set("skip-phases", "infrastructure,helm,k8s,image")) + wantPhases := skipPhases{} + wantPhases.add(skipInfrastructurePhase, skipHelmPhase, skipK8sPhase, skipImagePhase) var flags applyFlags err := flags.parse(cmd.Flags()) require.NoError(err) - assert.ElementsMatch(t, []skipPhase{skipInfrastructurePhase, skipHelmPhase, skipK8sPhase, skipImagePhase}, flags.skipPhases) + assert.Equal(t, wantPhases, flags.skipPhases) } type stubKubernetesUpgrader struct {