Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cli: fix upgrades when using outdated Kubernetes patch version #2718

Merged
merged 2 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions cli/internal/cmd/apply.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
xsemver "golang.org/x/mod/semver"
apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1"
k8serrors "k8s.io/apimachinery/pkg/api/errors"
)
Expand Down Expand Up @@ -545,9 +546,19 @@ func (a *applyCmd) validateInputs(cmd *cobra.Command, configFetcher attestationc
return nil, nil, fmt.Errorf("aborted by user")
}
}

a.flags.skipPhases.add(skipK8sPhase)
a.log.Debugf("Outdated Kubernetes version accepted, Kubernetes upgrade will be skipped")
}

validVersionString, err := versions.ResolveK8sPatchVersion(xsemver.MajorMinor(string(conf.KubernetesVersion)))
if err != nil {
return nil, nil, fmt.Errorf("resolving Kubernetes patch version: %w", err)
}
validVersion, err = versions.NewValidK8sVersion(validVersionString, true)
if err != nil {
return nil, nil, fmt.Errorf("parsing Kubernetes version: %w", err)
}
}
if versions.IsPreviewK8sVersion(validVersion) {
cmd.PrintErrf("Warning: Constellation with Kubernetes %s is still in preview. Use only for evaluation purposes.\n", validVersion)
Expand Down
30 changes: 29 additions & 1 deletion cli/internal/cmd/apply_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/kms/uri"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/versions"
"github.com/spf13/afero"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -291,6 +292,7 @@ func TestValidateInputs(t *testing.T) {
stdin string
flags applyFlags
wantPhases skipPhases
assert func(require *require.Assertions, assert *assert.Assertions, conf *config.Config, stateFile *state.State)
wantErr bool
}{
"[upgrade] gcp: all files exist": {
Expand Down Expand Up @@ -396,6 +398,28 @@ func TestValidateInputs(t *testing.T) {
},
wantPhases: newPhases(skipInfrastructurePhase, skipImagePhase, skipK8sPhase),
},
"[upgrade] k8s patch version no longer supported, user confirms to skip k8s and continue upgrade. Valid K8s patch version is used in config afterwards": {
createConfig: func(require *require.Assertions, fh file.Handler) {
cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.GCP)

// use first version in list (oldest) as it should never have a patch version
versionParts := strings.Split(versions.SupportedK8sVersions()[0], ".")
versionParts[len(versionParts)-1] = "0"
cfg.KubernetesVersion = versions.ValidK8sVersion(strings.Join(versionParts, "."))
require.NoError(fh.WriteYAML(constants.ConfigFilename, cfg))
},
createState: postInitState(cloudprovider.GCP),
createMasterSecret: defaultMasterSecret,
createAdminConfig: defaultAdminConfig,
createTfState: defaultTfState,
stdin: "y\n",
wantPhases: newPhases(skipInitPhase, skipK8sPhase),
assert: func(require *require.Assertions, assert *assert.Assertions, conf *config.Config, stateFile *state.State) {
assert.NotEmpty(conf.KubernetesVersion)
_, err := versions.NewValidK8sVersion(string(conf.KubernetesVersion), true)
assert.NoError(err)
},
},
}

for name, tc := range testCases {
Expand Down Expand Up @@ -423,7 +447,7 @@ func TestValidateInputs(t *testing.T) {
flags: tc.flags,
}

_, _, err := a.validateInputs(cmd, &stubAttestationFetcher{})
conf, state, err := a.validateInputs(cmd, &stubAttestationFetcher{})
if tc.wantErr {
assert.Error(err)
return
Expand All @@ -434,6 +458,10 @@ func TestValidateInputs(t *testing.T) {
t.Log(cfgErr.LongMessage())
}
assert.Equal(tc.wantPhases, a.flags.skipPhases)

if tc.assert != nil {
tc.assert(require, assert, conf, state)
}
})
}
}
Expand Down
11 changes: 11 additions & 0 deletions internal/versions/components/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
load("@rules_proto//proto:defs.bzl", "proto_library")
load("//bazel/go:go_test.bzl", "go_test")
load("//bazel/proto:rules.bzl", "write_go_proto_srcs")

go_library(
Expand Down Expand Up @@ -30,3 +31,13 @@ write_go_proto_srcs(
go_proto_library = ":components_go_proto",
visibility = ["//visibility:public"],
)

go_test(
name = "components_test",
srcs = ["components_test.go"],
embed = [":components"],
deps = [
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",
],
)
48 changes: 48 additions & 0 deletions internal/versions/components/components.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package components

import (
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"strings"
Expand All @@ -16,6 +17,53 @@ import (
// Components is a list of Kubernetes components.
type Components []*Component

type legacyComponent struct {
URL string `json:"URL,omitempty"`
Hash string `json:"Hash,omitempty"`
InstallPath string `json:"InstallPath,omitempty"`
Extract bool `json:"Extract,omitempty"`
}

// UnmarshalJSON implements a custom JSON unmarshaler to ensure backwards compatibility
// with older components lists which had a different format for all keys.
func (c *Components) UnmarshalJSON(b []byte) error {
var legacyComponents []*legacyComponent
if err := json.Unmarshal(b, &legacyComponents); err != nil {
return err
}
var components []*Component
if err := json.Unmarshal(b, &components); err != nil {
return err
}

if len(legacyComponents) != len(components) {
return errors.New("failed to unmarshal data: inconsistent number of components in list") // just a check, should never happen
}

// If a value is not set in the new format,
// it might have been set in the old format.
// In this case, we copy the value from the old format.
comps := make(Components, len(components))
for idx := 0; idx < len(components); idx++ {
comps[idx] = components[idx]
if comps[idx].Url == "" {
comps[idx].Url = legacyComponents[idx].URL
}
if comps[idx].Hash == "" {
comps[idx].Hash = legacyComponents[idx].Hash
}
if comps[idx].InstallPath == "" {
comps[idx].InstallPath = legacyComponents[idx].InstallPath
}
if !comps[idx].Extract {
comps[idx].Extract = legacyComponents[idx].Extract
}
}

*c = comps
return nil
}

// GetHash returns the hash over all component hashes.
func (c Components) GetHash() string {
sha := sha256.New()
Expand Down
31 changes: 31 additions & 0 deletions internal/versions/components/components_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/

package components

import (
"encoding/json"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestUnmarshalComponents(t *testing.T) {
assert := assert.New(t)
require := require.New(t)

legacyFormat := `[{"URL":"https://example.com/foo.tar.gz","Hash":"1234567890","InstallPath":"/foo","Extract":true}]`
newFormat := `[{"url":"https://example.com/foo.tar.gz","hash":"1234567890","install_path":"/foo","extract":true}]`

var fromLegacy Components
require.NoError(json.Unmarshal([]byte(legacyFormat), &fromLegacy))

var fromNew Components
require.NoError(json.Unmarshal([]byte(newFormat), &fromNew))

assert.Equal(fromLegacy, fromNew)
}
Loading