diff --git a/cli/internal/cmd/BUILD.bazel b/cli/internal/cmd/BUILD.bazel index fbdf723d50..e7937a9558 100644 --- a/cli/internal/cmd/BUILD.bazel +++ b/cli/internal/cmd/BUILD.bazel @@ -117,6 +117,7 @@ go_library( go_test( name = "cmd_test", srcs = [ + "apply_test.go", "cloud_test.go", "configfetchmeasurements_test.go", "configgenerate_test.go", @@ -173,12 +174,15 @@ go_test( "@com_github_google_go_tpm_tools//proto/tpm", "@com_github_spf13_afero//:afero", "@com_github_spf13_cobra//:cobra", + "@com_github_spf13_pflag//:pflag", "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//mock", "@com_github_stretchr_testify//require", "@io_k8s_api//core/v1:core", "@io_k8s_apiextensions_apiserver//pkg/apis/apiextensions/v1:apiextensions", + "@io_k8s_apimachinery//pkg/api/errors", "@io_k8s_apimachinery//pkg/apis/meta/v1:meta", + "@io_k8s_apimachinery//pkg/runtime/schema", "@io_k8s_client_go//tools/clientcmd", "@io_k8s_client_go//tools/clientcmd/api", "@org_golang_google_grpc//:go_default_library", diff --git a/cli/internal/cmd/apply.go b/cli/internal/cmd/apply.go index b0f1accb0d..0d032740f4 100644 --- a/cli/internal/cmd/apply.go +++ b/cli/internal/cmd/apply.go @@ -29,6 +29,7 @@ import ( "github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi" "github.com/edgelesssys/constellation/v2/internal/atls" "github.com/edgelesssys/constellation/v2/internal/attestation/variant" + "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/compatibility" "github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/constants" @@ -64,8 +65,7 @@ func (f *applyFlags) parse(flags *pflag.FlagSet) error { rawSkipPhases, err := flags.GetStringSlice("skip-phases") if err != nil { - rawSkipPhases = []string{} - // return fmt.Errorf("getting 'skip-phases' flag: %w", err) + return fmt.Errorf("getting 'skip-phases' flag: %w", err) } var skipPhases []skipPhase for _, phase := range rawSkipPhases { @@ -80,14 +80,12 @@ func (f *applyFlags) parse(flags *pflag.FlagSet) error { f.yes, err = flags.GetBool("yes") if err != nil { - f.yes = false - // return fmt.Errorf("getting 'yes' flag: %w", err) + return fmt.Errorf("getting 'yes' flag: %w", err) } f.upgradeTimeout, err = flags.GetDuration("timeout") if err != nil { - f.upgradeTimeout = time.Hour - // return fmt.Errorf("getting 'timeout' flag: %w", err) + return fmt.Errorf("getting 'timeout' flag: %w", err) } f.conformance, err = flags.GetBool("conformance") @@ -106,8 +104,7 @@ func (f *applyFlags) parse(flags *pflag.FlagSet) error { f.mergeConfigs, err = flags.GetBool("merge-kubeconfig") if err != nil { - f.mergeConfigs = false - // return fmt.Errorf("getting 'merge-kubeconfig' flag: %w", err) + return fmt.Errorf("getting 'merge-kubeconfig' flag: %w", err) } return nil } @@ -339,6 +336,17 @@ func (a *applyCmd) apply(cmd *cobra.Command, configFetcher attestationconfigapi. } } + // Constellation on QEMU or OpenStack don't support upgrades + // If using one of those providers, make sure the command is only used to initialize a cluster + if !(conf.GetProvider() == cloudprovider.AWS || conf.GetProvider() == cloudprovider.Azure || conf.GetProvider() == cloudprovider.GCP) { + if !initRequired { + return fmt.Errorf("upgrades are not supported for provider %s", conf.GetProvider()) + } + // Skip Terraform phase + a.log.Debugf("Skipping Infrastructure phase for provider %s", conf.GetProvider()) + a.flags.skipPhases = append(a.flags.skipPhases, skipInfrastructurePhase) + } + // Print warning about AWS attestation // TODO(derpsteb): remove once AWS fixes SEV-SNP attestation provisioning issues if initRequired && conf.GetAttestationConfig().GetVariant().Equal(variant.AWSSEVSNP{}) { diff --git a/cli/internal/cmd/apply_test.go b/cli/internal/cmd/apply_test.go new file mode 100644 index 0000000000..65953ca63e --- /dev/null +++ b/cli/internal/cmd/apply_test.go @@ -0,0 +1,153 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +package cmd + +import ( + "context" + "fmt" + "testing" + + "github.com/edgelesssys/constellation/v2/cli/internal/helm" + "github.com/edgelesssys/constellation/v2/internal/file" + "github.com/edgelesssys/constellation/v2/internal/logger" + "github.com/spf13/afero" + "github.com/spf13/pflag" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseApplyFlags(t *testing.T) { + require := require.New(t) + // TODO: Use flags := applyCmd().Flags() once we have a separate apply command + defaultFlags := func() *pflag.FlagSet { + flags := pflag.NewFlagSet("test", pflag.ContinueOnError) + flags.String("workspace", "", "") + flags.String("tf-log", "NONE", "") + flags.Bool("force", false, "") + flags.Bool("debug", false, "") + flags.Bool("merge-kubeconfig", false, "") + flags.Bool("conformance", false, "") + flags.Bool("skip-helm-wait", false, "") + flags.Bool("yes", false, "") + flags.StringSlice("skip-phases", []string{}, "") + flags.Duration("timeout", 0, "") + return flags + } + + testCases := map[string]struct { + flags *pflag.FlagSet + wantFlags applyFlags + wantErr bool + }{ + "default flags": { + flags: defaultFlags(), + wantFlags: applyFlags{ + helmWaitMode: helm.WaitModeAtomic, + }, + }, + "skip phases": { + flags: func() *pflag.FlagSet { + flags := defaultFlags() + require.NoError(flags.Set("skip-phases", fmt.Sprintf("%s,%s", skipHelmPhase, skipK8sPhase))) + return flags + }(), + wantFlags: applyFlags{ + skipPhases: []skipPhase{skipHelmPhase, skipK8sPhase}, + helmWaitMode: helm.WaitModeAtomic, + }, + }, + "skip helm wait": { + flags: func() *pflag.FlagSet { + flags := defaultFlags() + require.NoError(flags.Set("skip-helm-wait", "true")) + return flags + }(), + wantFlags: applyFlags{ + helmWaitMode: helm.WaitModeNone, + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + var flags applyFlags + + err := flags.parse(tc.flags) + if tc.wantErr { + assert.Error(err) + return + } + assert.NoError(err) + assert.Equal(tc.wantFlags, flags) + }) + } +} + +func TestBackupHelmCharts(t *testing.T) { + testCases := map[string]struct { + helmApplier helm.Applier + backupClient *stubKubernetesUpgrader + includesUpgrades bool + wantErr bool + }{ + "success, no upgrades": { + helmApplier: &stubRunner{}, + backupClient: &stubKubernetesUpgrader{}, + }, + "success with upgrades": { + helmApplier: &stubRunner{}, + backupClient: &stubKubernetesUpgrader{}, + includesUpgrades: true, + }, + "saving charts fails": { + helmApplier: &stubRunner{ + saveChartsErr: assert.AnError, + }, + backupClient: &stubKubernetesUpgrader{}, + wantErr: true, + }, + "backup CRDs fails": { + helmApplier: &stubRunner{}, + backupClient: &stubKubernetesUpgrader{ + backupCRDsErr: assert.AnError, + }, + includesUpgrades: true, + wantErr: true, + }, + "backup CRs fails": { + helmApplier: &stubRunner{}, + backupClient: &stubKubernetesUpgrader{ + backupCRsErr: assert.AnError, + }, + includesUpgrades: true, + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + a := applyCmd{ + fileHandler: file.NewHandler(afero.NewMemMapFs()), + log: logger.NewTest(t), + } + + err := a.backupHelmCharts(context.Background(), tc.backupClient, tc.helmApplier, tc.includesUpgrades, "") + if tc.wantErr { + assert.Error(err) + return + } + assert.NoError(err) + if tc.includesUpgrades { + assert.True(tc.backupClient.backupCRDsCalled) + assert.True(tc.backupClient.backupCRsCalled) + } + }) + } +} diff --git a/cli/internal/cmd/init.go b/cli/internal/cmd/init.go index f056280323..113c2a14cd 100644 --- a/cli/internal/cmd/init.go +++ b/cli/internal/cmd/init.go @@ -7,28 +7,15 @@ SPDX-License-Identifier: AGPL-3.0-only package cmd import ( - "bytes" "context" - "encoding/hex" "errors" "fmt" "io" - "net" - "net/url" "os" - "path/filepath" - "strconv" "sync" - "text/tabwriter" "time" - "github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi" - "github.com/edgelesssys/constellation/v2/internal/atls" - "github.com/edgelesssys/constellation/v2/internal/attestation/variant" - - "github.com/spf13/afero" "github.com/spf13/cobra" - "github.com/spf13/pflag" "google.golang.org/grpc" "k8s.io/apimachinery/pkg/runtime" "k8s.io/client-go/tools/clientcmd" @@ -36,21 +23,13 @@ import ( "sigs.k8s.io/yaml" "github.com/edgelesssys/constellation/v2/bootstrapper/initproto" - "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/cli/internal/helm" - "github.com/edgelesssys/constellation/v2/cli/internal/kubecmd" "github.com/edgelesssys/constellation/v2/cli/internal/state" "github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/constants" - "github.com/edgelesssys/constellation/v2/internal/crypto" "github.com/edgelesssys/constellation/v2/internal/file" - "github.com/edgelesssys/constellation/v2/internal/grpc/dialer" "github.com/edgelesssys/constellation/v2/internal/grpc/grpclog" - grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry" "github.com/edgelesssys/constellation/v2/internal/kms/uri" - "github.com/edgelesssys/constellation/v2/internal/license" - "github.com/edgelesssys/constellation/v2/internal/retry" - "github.com/edgelesssys/constellation/v2/internal/versions" ) // NewInitCmd returns a new cobra.Command for the init command. @@ -61,7 +40,15 @@ func NewInitCmd() *cobra.Command { Long: "Initialize the Constellation cluster.\n\n" + "Start your confidential Kubernetes.", Args: cobra.ExactArgs(0), - RunE: runApply, + RunE: func(cmd *cobra.Command, args []string) error { + // Define flags for apply backend that are not set by init + cmd.Flags().Bool("yes", false, "") + // Don't skip any phases + // The apply backend should handle init calls correctly + cmd.Flags().StringSlice("skip-phases", []string{}, "") + cmd.Flags().Duration("timeout", time.Hour, "") + return runApply(cmd, args) + }, } cmd.Flags().Bool("conformance", false, "enable conformance mode") cmd.Flags().Bool("skip-helm-wait", false, "install helm charts without waiting for deployments to be ready") @@ -69,270 +56,6 @@ func NewInitCmd() *cobra.Command { return cmd } -// initFlags are flags used by the init command. -type initFlags struct { - rootFlags - conformance bool - helmWaitMode helm.WaitMode - mergeConfigs bool -} - -func (f *initFlags) parse(flags *pflag.FlagSet) error { - if err := f.rootFlags.parse(flags); err != nil { - return err - } - - skipHelmWait, err := flags.GetBool("skip-helm-wait") - if err != nil { - return fmt.Errorf("getting 'skip-helm-wait' flag: %w", err) - } - f.helmWaitMode = helm.WaitModeAtomic - if skipHelmWait { - f.helmWaitMode = helm.WaitModeNone - } - - f.conformance, err = flags.GetBool("conformance") - if err != nil { - return fmt.Errorf("getting 'conformance' flag: %w", err) - } - f.mergeConfigs, err = flags.GetBool("merge-kubeconfig") - if err != nil { - return fmt.Errorf("getting 'merge-kubeconfig' flag: %w", err) - } - return nil -} - -type initCmd struct { - log debugLog - merger configMerger - spinner spinnerInterf - fileHandler file.Handler - flags initFlags -} - -func newInitCmd(fileHandler file.Handler, spinner spinnerInterf, merger configMerger, log debugLog) *initCmd { - return &initCmd{ - log: log, - merger: merger, - spinner: spinner, - fileHandler: fileHandler, - } -} - -// runInitialize runs the initialize command. -func runInitialize(cmd *cobra.Command, _ []string) error { - log, err := newCLILogger(cmd) - if err != nil { - return fmt.Errorf("creating logger: %w", err) - } - defer log.Sync() - fileHandler := file.NewHandler(afero.NewOsFs()) - newDialer := func(validator atls.Validator) *dialer.Dialer { - return dialer.New(nil, validator, &net.Dialer{}) - } - - spinner, err := newSpinnerOrStderr(cmd) - if err != nil { - return err - } - defer spinner.Stop() - - ctx, cancel := context.WithTimeout(cmd.Context(), time.Hour) - defer cancel() - cmd.SetContext(ctx) - - i := newInitCmd(fileHandler, spinner, &kubeconfigMerger{log: log}, log) - if err := i.flags.parse(cmd.Flags()); err != nil { - return err - } - i.log.Debugf("Using flags: %+v", i.flags) - - fetcher := attestationconfigapi.NewFetcher() - newAttestationApplier := func(w io.Writer, kubeConfig string, log debugLog) (attestationConfigApplier, error) { - return kubecmd.New(w, kubeConfig, fileHandler, log) - } - newHelmClient := func(kubeConfigPath string, log debugLog) (helmApplier, error) { - return helm.NewClient(kubeConfigPath, log) - } // need to defer helm client instantiation until kubeconfig is available - - return i.initialize(cmd, newDialer, license.NewClient(), fetcher, newAttestationApplier, newHelmClient) -} - -// initialize initializes a Constellation. -func (i *initCmd) initialize( - cmd *cobra.Command, newDialer func(validator atls.Validator) *dialer.Dialer, - quotaChecker license.QuotaChecker, configFetcher attestationconfigapi.Fetcher, - newAttestationApplier func(io.Writer, string, debugLog) (attestationConfigApplier, error), - newHelmClient func(kubeConfigPath string, log debugLog) (helmApplier, error), -) error { - i.log.Debugf("Loading configuration file from %q", i.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename)) - conf, err := config.New(i.fileHandler, constants.ConfigFilename, configFetcher, i.flags.force) - var configValidationErr *config.ValidationError - if errors.As(err, &configValidationErr) { - cmd.PrintErrln(configValidationErr.LongMessage()) - } - if err != nil { - return err - } - // cfg validation does not check k8s patch version since upgrade may accept an outdated patch version. - k8sVersion, err := versions.NewValidK8sVersion(string(conf.KubernetesVersion), true) - if err != nil { - return err - } - if !i.flags.force { - if err := validateCLIandConstellationVersionAreEqual(constants.BinaryVersion(), conf.Image, conf.MicroserviceVersion); err != nil { - return err - } - } - if conf.GetAttestationConfig().GetVariant().Equal(variant.AWSSEVSNP{}) { - cmd.PrintErrln("WARNING: Attestation temporarily relies on AWS nitroTPM. See https://docs.edgeless.systems/constellation/workflows/config#choosing-a-vm-type for more information.") - } - - stateFile, err := state.ReadFromFile(i.fileHandler, constants.StateFilename) - if err != nil { - return fmt.Errorf("reading state file: %w", err) - } - - i.log.Debugf("Validated k8s version as %s", k8sVersion) - if versions.IsPreviewK8sVersion(k8sVersion) { - cmd.PrintErrf("Warning: Constellation with Kubernetes %v is still in preview. Use only for evaluation purposes.\n", k8sVersion) - } - - provider := conf.GetProvider() - i.log.Debugf("Got provider %s", provider.String()) - checker := license.NewChecker(quotaChecker, i.fileHandler) - if err := checker.CheckLicense(cmd.Context(), provider, conf.Provider, cmd.Printf); err != nil { - cmd.PrintErrf("License check failed: %v", err) - } - i.log.Debugf("Checked license") - - if stateFile.Infrastructure.Azure != nil { - conf.UpdateMAAURL(stateFile.Infrastructure.Azure.AttestationURL) - } - - i.log.Debugf("Creating aTLS Validator for %s", conf.GetAttestationConfig().GetVariant()) - validator, err := cloudcmd.NewValidator(cmd, conf.GetAttestationConfig(), i.log) - if err != nil { - return fmt.Errorf("creating new validator: %w", err) - } - i.log.Debugf("Created a new validator") - serviceAccURI, err := cloudcmd.GetMarshaledServiceAccountURI(conf, i.fileHandler) - if err != nil { - return err - } - i.log.Debugf("Successfully marshaled service account URI") - - i.log.Debugf("Generating master secret") - masterSecret, err := i.generateMasterSecret(cmd.OutOrStdout()) - if err != nil { - return fmt.Errorf("generating master secret: %w", err) - } - - i.log.Debugf("Generating measurement salt") - measurementSalt, err := crypto.GenerateRandomBytes(crypto.RNGLengthDefault) - if err != nil { - return fmt.Errorf("generating measurement salt: %w", err) - } - - i.log.Debugf("Setting cluster name to %s", stateFile.Infrastructure.Name) - - cmd.PrintErrln("Note: If you just created the cluster, it can take a few minutes to connect.") - i.spinner.Start("Connecting ", false) - req := &initproto.InitRequest{ - KmsUri: masterSecret.EncodeToURI(), - StorageUri: uri.NoStoreURI, - MeasurementSalt: measurementSalt, - KubernetesVersion: versions.VersionConfigs[k8sVersion].ClusterVersion, - KubernetesComponents: versions.VersionConfigs[k8sVersion].KubernetesComponents.ToInitProto(), - ConformanceMode: i.flags.conformance, - InitSecret: stateFile.Infrastructure.InitSecret, - ClusterName: stateFile.Infrastructure.Name, - ApiserverCertSans: stateFile.Infrastructure.APIServerCertSANs, - } - i.log.Debugf("Sending initialization request") - resp, err := i.initCall(cmd.Context(), newDialer(validator), stateFile.Infrastructure.ClusterEndpoint, req) - i.spinner.Stop() - - if err != nil { - var nonRetriable *nonRetriableError - if errors.As(err, &nonRetriable) { - cmd.PrintErrln("Cluster initialization failed. This error is not recoverable.") - cmd.PrintErrln("Terminate your cluster and try again.") - if nonRetriable.logCollectionErr != nil { - cmd.PrintErrf("Failed to collect logs from bootstrapper: %s\n", nonRetriable.logCollectionErr) - } else { - cmd.PrintErrf("Fetched bootstrapper logs are stored in %q\n", i.flags.pathPrefixer.PrefixPrintablePath(constants.ErrorLog)) - } - } - return err - } - i.log.Debugf("Initialization request succeeded") - - bufferedOutput := &bytes.Buffer{} - if err := i.writeOutput(stateFile, resp, i.flags.mergeConfigs, bufferedOutput, measurementSalt); err != nil { - return err - } - - attestationApplier, err := newAttestationApplier(cmd.OutOrStdout(), constants.AdminConfFilename, i.log) - if err != nil { - return err - } - if err := attestationApplier.ApplyJoinConfig(cmd.Context(), conf.GetAttestationConfig(), measurementSalt); err != nil { - return fmt.Errorf("applying attestation config: %w", err) - } - - i.spinner.Start("Installing Kubernetes components ", false) - options := helm.Options{ - Force: i.flags.force, - Conformance: i.flags.conformance, - HelmWaitMode: i.flags.helmWaitMode, - AllowDestructive: helm.DenyDestructive, - } - helmApplier, err := newHelmClient(constants.AdminConfFilename, i.log) - if err != nil { - return fmt.Errorf("creating Helm client: %w", err) - } - executor, includesUpgrades, err := helmApplier.PrepareApply(conf, stateFile, options, serviceAccURI, masterSecret) - if err != nil { - return fmt.Errorf("getting Helm chart executor: %w", err) - } - if includesUpgrades { - return errors.New("init: helm tried to upgrade charts instead of installing them") - } - if err := executor.Apply(cmd.Context()); err != nil { - return fmt.Errorf("applying Helm charts: %w", err) - } - i.spinner.Stop() - i.log.Debugf("Helm deployment installation succeeded") - cmd.Println(bufferedOutput.String()) - return nil -} - -func (i *initCmd) initCall(ctx context.Context, dialer grpcDialer, ip string, req *initproto.InitRequest) (*initproto.InitSuccessResponse, error) { - doer := &initDoer{ - dialer: dialer, - endpoint: net.JoinHostPort(ip, strconv.Itoa(constants.BootstrapperPort)), - req: req, - log: i.log, - spinner: i.spinner, - fh: file.NewHandler(afero.NewOsFs()), - } - - // Create a wrapper function that allows logging any returned error from the retrier before checking if it's the expected retriable one. - serviceIsUnavailable := func(err error) bool { - isServiceUnavailable := grpcRetry.ServiceIsUnavailable(err) - i.log.Debugf("Encountered error (retriable: %t): %s", isServiceUnavailable, err) - return isServiceUnavailable - } - - i.log.Debugf("Making initialization call, doer is %+v", doer) - retrier := retry.NewIntervalRetrier(doer, 30*time.Second, serviceIsUnavailable) - if err := retrier.Do(ctx); err != nil { - return nil, err - } - return doer.resp, nil -} - type initDoer struct { dialer grpcDialer endpoint string @@ -469,122 +192,10 @@ func (d *initDoer) handleGRPCStateChanges(ctx context.Context, wg *sync.WaitGrou }) } -// writeOutput writes the output of a cluster initialization to the -// state- / id- / kubeconfig-file and saves it to disk. -func (i *initCmd) writeOutput( - stateFile *state.State, - initResp *initproto.InitSuccessResponse, - mergeConfig bool, wr io.Writer, - measurementSalt []byte, -) error { - fmt.Fprint(wr, "Your Constellation cluster was successfully initialized.\n\n") - - ownerID := hex.EncodeToString(initResp.GetOwnerId()) - clusterID := hex.EncodeToString(initResp.GetClusterId()) - - stateFile.SetClusterValues(state.ClusterValues{ - MeasurementSalt: measurementSalt, - OwnerID: ownerID, - ClusterID: clusterID, - }) - - tw := tabwriter.NewWriter(wr, 0, 0, 2, ' ', 0) - writeRow(tw, "Constellation cluster identifier", clusterID) - writeRow(tw, "Kubernetes configuration", i.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename)) - tw.Flush() - fmt.Fprintln(wr) - - i.log.Debugf("Rewriting cluster server address in kubeconfig to %s", stateFile.Infrastructure.ClusterEndpoint) - kubeconfig, err := clientcmd.Load(initResp.GetKubeconfig()) - if err != nil { - return fmt.Errorf("loading kubeconfig: %w", err) - } - if len(kubeconfig.Clusters) != 1 { - return fmt.Errorf("expected exactly one cluster in kubeconfig, got %d", len(kubeconfig.Clusters)) - } - for _, cluster := range kubeconfig.Clusters { - kubeEndpoint, err := url.Parse(cluster.Server) - if err != nil { - return fmt.Errorf("parsing kubeconfig server URL: %w", err) - } - kubeEndpoint.Host = net.JoinHostPort(stateFile.Infrastructure.ClusterEndpoint, kubeEndpoint.Port()) - cluster.Server = kubeEndpoint.String() - } - kubeconfigBytes, err := clientcmd.Write(*kubeconfig) - if err != nil { - return fmt.Errorf("marshaling kubeconfig: %w", err) - } - - if err := i.fileHandler.Write(constants.AdminConfFilename, kubeconfigBytes, file.OptNone); err != nil { - return fmt.Errorf("writing kubeconfig: %w", err) - } - i.log.Debugf("Kubeconfig written to %s", i.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename)) - - if mergeConfig { - if err := i.merger.mergeConfigs(constants.AdminConfFilename, i.fileHandler); err != nil { - writeRow(tw, "Failed to automatically merge kubeconfig", err.Error()) - mergeConfig = false // Set to false so we don't print the wrong message below. - } else { - writeRow(tw, "Kubernetes configuration merged with default config", "") - } - } - - if err := stateFile.WriteToFile(i.fileHandler, constants.StateFilename); err != nil { - return fmt.Errorf("writing Constellation state file: %w", err) - } - - i.log.Debugf("Constellation state file written to %s", i.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename)) - - if !mergeConfig { - fmt.Fprintln(wr, "You can now connect to your cluster by executing:") - - exportPath, err := filepath.Abs(constants.AdminConfFilename) - if err != nil { - return fmt.Errorf("getting absolute path to kubeconfig: %w", err) - } - - fmt.Fprintf(wr, "\texport KUBECONFIG=%q\n", exportPath) - } else { - fmt.Fprintln(wr, "Constellation kubeconfig merged with default config.") - - if i.merger.kubeconfigEnvVar() != "" { - fmt.Fprintln(wr, "Warning: KUBECONFIG environment variable is set.") - fmt.Fprintln(wr, "You may need to unset it to use the default config and connect to your cluster.") - } else { - fmt.Fprintln(wr, "You can now connect to your cluster.") - } - } - return nil -} - func writeRow(wr io.Writer, col1 string, col2 string) { fmt.Fprint(wr, col1, "\t", col2, "\n") } -// generateMasterSecret reads a base64 encoded master secret from file or generates a new 32 byte secret. -func (i *initCmd) generateMasterSecret(outWriter io.Writer) (uri.MasterSecret, error) { - // No file given, generate a new secret, and save it to disk - i.log.Debugf("Generating new master secret") - key, err := crypto.GenerateRandomBytes(crypto.MasterSecretLengthDefault) - if err != nil { - return uri.MasterSecret{}, err - } - salt, err := crypto.GenerateRandomBytes(crypto.RNGLengthDefault) - if err != nil { - return uri.MasterSecret{}, err - } - secret := uri.MasterSecret{ - Key: key, - Salt: salt, - } - i.log.Debugf("Generated master secret key and salt values") - if err := i.fileHandler.WriteJSON(constants.MasterSecretFilename, secret, file.OptNone); err != nil { - return uri.MasterSecret{}, err - } - fmt.Fprintf(outWriter, "Your Constellation master secret was successfully written to %q\n", i.flags.pathPrefixer.PrefixPrintablePath(constants.MasterSecretFilename)) - return secret, nil -} - type configMerger interface { mergeConfigs(configPath string, fileHandler file.Handler) error kubeconfigEnvVar() string @@ -657,10 +268,6 @@ func (e *nonRetriableError) Unwrap() error { return e.err } -type attestationConfigApplier interface { - ApplyJoinConfig(ctx context.Context, newAttestConfig config.AttestationCfg, measurementSalt []byte) error -} - type helmApplier interface { PrepareApply(conf *config.Config, stateFile *state.State, flags helm.Options, serviceAccURI string, masterSecret uri.MasterSecret) ( diff --git a/cli/internal/cmd/init_test.go b/cli/internal/cmd/init_test.go index 7e8603894d..1b353d1123 100644 --- a/cli/internal/cmd/init_test.go +++ b/cli/internal/cmd/init_test.go @@ -44,6 +44,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" + k8serrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/tools/clientcmd" k8sclientapi "k8s.io/client-go/tools/clientcmd/api" ) @@ -58,8 +60,6 @@ func TestInitArgumentValidation(t *testing.T) { } func TestInitialize(t *testing.T) { - require := require.New(t) - respKubeconfig := k8sclientapi.Config{ Clusters: map[string]*k8sclientapi.Cluster{ "cluster": { @@ -68,7 +68,7 @@ func TestInitialize(t *testing.T) { }, } respKubeconfigBytes, err := clientcmd.Write(respKubeconfig) - require.NoError(err) + require.NoError(t, err) gcpServiceAccKey := &gcpshared.ServiceAccountKey{ Type: "service_account", @@ -149,31 +149,46 @@ func TestInitialize(t *testing.T) { masterSecretShouldExist: true, wantErr: true, }, - "state file with only version": { - provider: cloudprovider.GCP, - stateFile: &state.State{Version: state.Version1}, - initServerAPI: &stubInitServer{}, - retriable: true, - wantErr: true, - }, - "empty state file": { + /* + Tests currently disabled since we don't actually have validation for the state file yet + These tests cases only passed in the past because of unrelated errors in the test setup + + "state file with only version": { + provider: cloudprovider.GCP, + stateFile: &state.State{Version: state.Version1}, + configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath }, + serviceAccKey: gcpServiceAccKey, + initServerAPI: &stubInitServer{}, + retriable: true, + wantErr: true, + }, + + "empty state file": { + provider: cloudprovider.GCP, + stateFile: &state.State{}, + configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath }, + serviceAccKey: gcpServiceAccKey, + initServerAPI: &stubInitServer{}, + retriable: true, + wantErr: true, + }, + */ + "no state file": { provider: cloudprovider.GCP, - stateFile: &state.State{}, - initServerAPI: &stubInitServer{}, + configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath }, + serviceAccKey: gcpServiceAccKey, retriable: true, wantErr: true, }, - "no state file": { - provider: cloudprovider.GCP, - retriable: true, - wantErr: true, - }, "init call fails": { - provider: cloudprovider.GCP, - stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}}, - initServerAPI: &stubInitServer{initErr: assert.AnError}, - retriable: true, - wantErr: true, + provider: cloudprovider.GCP, + configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath }, + stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}}, + serviceAccKey: gcpServiceAccKey, + initServerAPI: &stubInitServer{initErr: assert.AnError}, + retriable: false, + masterSecretShouldExist: true, + wantErr: true, }, "k8s version without v works": { provider: cloudprovider.Azure, @@ -181,7 +196,7 @@ func TestInitialize(t *testing.T) { initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}}, configMutator: func(c *config.Config) { res, err := versions.NewValidK8sVersion(strings.TrimPrefix(string(versions.Default), "v"), true) - require.NoError(err) + require.NoError(t, err) c.KubernetesVersion = res }, }, @@ -191,7 +206,7 @@ func TestInitialize(t *testing.T) { initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}}, configMutator: func(c *config.Config) { v, err := semver.New(versions.SupportedK8sVersions()[0]) - require.NoError(err) + require.NoError(t, err) outdatedPatchVer := semver.NewFromInt(v.Major(), v.Minor(), v.Patch()-1, "").String() c.KubernetesVersion = versions.ValidK8sVersion(outdatedPatchVer) }, @@ -203,6 +218,7 @@ func TestInitialize(t *testing.T) { for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) + require := require.New(t) // Networking netDialer := testdialer.NewBufconnDialer() newDialer := func(atls.Validator) *dialer.Dialer { @@ -231,8 +247,6 @@ func TestInitialize(t *testing.T) { tc.configMutator(config) } require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, config, file.OptNone)) - stateFile := state.New() - require.NoError(stateFile.WriteToFile(fileHandler, constants.StateFilename)) if tc.stateFile != nil { require.NoError(tc.stateFile.WriteToFile(fileHandler, constants.StateFilename)) } @@ -244,22 +258,31 @@ func TestInitialize(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 4*time.Second) defer cancel() cmd.SetContext(ctx) - i := newInitCmd(fileHandler, &nopSpinner{}, nil, logger.NewTest(t)) - i.flags.force = true - - err := i.initialize( - cmd, - newDialer, - &stubLicenseClient{}, - stubAttestationFetcher{}, - func(io.Writer, string, debugLog) (attestationConfigApplier, error) { - return &stubAttestationApplier{}, nil - }, - func(_ string, _ debugLog) (helmApplier, error) { + + i := &applyCmd{ + fileHandler: fileHandler, + flags: applyFlags{rootFlags: rootFlags{force: true}}, + log: logger.NewTest(t), + spinner: &nopSpinner{}, + merger: &stubMerger{}, + quotaChecker: &stubLicenseClient{}, + newHelmClient: func(string, debugLog) (helmApplier, error) { return &stubApplier{}, nil - }) + }, + newDialer: newDialer, + newKubeUpgrader: func(io.Writer, string, debugLog) (kubernetesUpgrader, error) { + return &stubKubernetesUpgrader{ + // On init, no attestation config exists yet + getClusterAttestationConfigErr: k8serrors.NewNotFound(schema.GroupResource{}, ""), + }, nil + }, + clusterUpgrader: stubTerraformUpgrader{}, + } + + err := i.apply(cmd, stubAttestationFetcher{}, "test") if tc.wantErr { + fmt.Println(err) assert.Error(err) if !tc.retriable { assert.Contains(errOut.String(), "This error is not recoverable") @@ -291,14 +314,17 @@ func (s stubApplier) PrepareApply(_ *config.Config, _ *state.State, _ helm.Optio return stubRunner{}, false, s.err } -type stubRunner struct{} +type stubRunner struct { + applyErr error + saveChartsErr error +} func (s stubRunner) Apply(_ context.Context) error { - return nil + return s.applyErr } func (s stubRunner) SaveCharts(_ string, _ file.Handler) error { - return nil + return s.saveChartsErr } func TestGetLogs(t *testing.T) { @@ -420,7 +446,12 @@ func TestWriteOutput(t *testing.T) { ClusterEndpoint: clusterEndpoint, }) - i := newInitCmd(fileHandler, &nopSpinner{}, &stubMerger{}, logger.NewTest(t)) + i := &applyCmd{ + fileHandler: fileHandler, + spinner: &nopSpinner{}, + merger: &stubMerger{}, + log: logger.NewTest(t), + } err = i.writeOutput(stateFile, resp.GetInitSuccess(), false, &out, measurementSalt) require.NoError(err) assert.Contains(out.String(), clusterID) @@ -508,7 +539,10 @@ func TestGenerateMasterSecret(t *testing.T) { require.NoError(tc.createFileFunc(fileHandler)) var out bytes.Buffer - i := newInitCmd(fileHandler, nil, nil, logger.NewTest(t)) + i := &applyCmd{ + fileHandler: fileHandler, + log: logger.NewTest(t), + } secret, err := i.generateMasterSecret(&out) if tc.wantErr { @@ -601,13 +635,17 @@ func TestAttestation(t *testing.T) { defer cancel() cmd.SetContext(ctx) - i := newInitCmd(fileHandler, &nopSpinner{}, nil, logger.NewTest(t)) - err := i.initialize(cmd, newDialer, &stubLicenseClient{}, stubAttestationFetcher{}, - func(io.Writer, string, debugLog) (attestationConfigApplier, error) { - return &stubAttestationApplier{}, nil - }, func(_ string, _ debugLog) (helmApplier, error) { - return &stubApplier{}, nil - }) + i := &applyCmd{ + fileHandler: fileHandler, + spinner: &nopSpinner{}, + merger: &stubMerger{}, + log: logger.NewTest(t), + newKubeUpgrader: func(io.Writer, string, debugLog) (kubernetesUpgrader, error) { + return &stubKubernetesUpgrader{}, nil + }, + newDialer: newDialer, + } + _, err := i.runInit(cmd, cfg, existingStateFile) assert.Error(err) // make sure the error is actually a TLS handshake error assert.Contains(err.Error(), "transport: authentication handshake failed") @@ -770,11 +808,3 @@ func (c stubInitClient) Recv() (*initproto.InitResponse, error) { return res, err } - -type stubAttestationApplier struct { - applyErr error -} - -func (a *stubAttestationApplier) ApplyJoinConfig(context.Context, config.AttestationCfg, []byte) error { - return a.applyErr -} diff --git a/cli/internal/cmd/miniup.go b/cli/internal/cmd/miniup.go index f2612a5c6d..c5dc3457f5 100644 --- a/cli/internal/cmd/miniup.go +++ b/cli/internal/cmd/miniup.go @@ -10,23 +10,17 @@ import ( "context" "errors" "fmt" - "io" - "net" + "time" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/cli/internal/featureset" - "github.com/edgelesssys/constellation/v2/cli/internal/helm" - "github.com/edgelesssys/constellation/v2/cli/internal/kubecmd" "github.com/edgelesssys/constellation/v2/cli/internal/libvirt" "github.com/edgelesssys/constellation/v2/cli/internal/state" "github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi" - "github.com/edgelesssys/constellation/v2/internal/atls" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/file" - "github.com/edgelesssys/constellation/v2/internal/grpc/dialer" - "github.com/edgelesssys/constellation/v2/internal/license" "github.com/spf13/afero" "github.com/spf13/cobra" ) @@ -105,7 +99,7 @@ func (m *miniUpCmd) up(cmd *cobra.Command, creator cloudCreator, spinner spinner cmd.Printf("\tvirsh -c %s\n\n", connectURI) // initialize cluster - if err := m.initializeMiniCluster(cmd, spinner); err != nil { + if err := m.initializeMiniCluster(cmd); err != nil { return fmt.Errorf("initializing cluster: %w", err) } m.log.Debugf("Initialized cluster") @@ -188,7 +182,7 @@ func (m *miniUpCmd) createMiniCluster(ctx context.Context, creator cloudCreator, } // initializeMiniCluster initializes a QEMU cluster. -func (m *miniUpCmd) initializeMiniCluster(cmd *cobra.Command, spinner spinnerInterf) (retErr error) { +func (m *miniUpCmd) initializeMiniCluster(cmd *cobra.Command) (retErr error) { m.log.Debugf("Initializing mini cluster") // clean up cluster resources if initialization fails defer func() { @@ -199,34 +193,19 @@ func (m *miniUpCmd) initializeMiniCluster(cmd *cobra.Command, spinner spinnerInt cmd.PrintErrf("Rollback succeeded.\n\n") } }() - newDialer := func(validator atls.Validator) *dialer.Dialer { - return dialer.New(nil, validator, &net.Dialer{}) - } - m.log.Debugf("Created new dialer") - cmd.Flags().String("endpoint", "", "") - cmd.Flags().Bool("conformance", false, "") - cmd.Flags().Bool("skip-helm-wait", false, "install helm charts without waiting for deployments to be ready") - log, err := newCLILogger(cmd) - if err != nil { - return fmt.Errorf("creating logger: %w", err) - } - m.log.Debugf("Created new logger") - defer log.Sync() - - newAttestationApplier := func(w io.Writer, kubeConfig string, log debugLog) (attestationConfigApplier, error) { - return kubecmd.New(w, kubeConfig, m.fileHandler, log) - } - newHelmClient := func(kubeConfigPath string, log debugLog) (helmApplier, error) { - return helm.NewClient(kubeConfigPath, log) - } // need to defer helm client instantiation until kubeconfig is available - i := newInitCmd(m.fileHandler, spinner, &kubeconfigMerger{log: log}, log) - if err := i.flags.parse(cmd.Flags()); err != nil { - return err - } + // Define flags for apply backend that are not set by mini up + cmd.Flags().StringSlice( + "skip-phases", + []string{string(skipInfrastructurePhase), string(skipK8sPhase), string(skipImagePhase)}, + "", + ) + cmd.Flags().Bool("yes", false, "") + cmd.Flags().Bool("skip-helm-wait", false, "") + cmd.Flags().Bool("conformance", false, "") + cmd.Flags().Duration("timeout", time.Hour, "") - if err := i.initialize(cmd, newDialer, license.NewClient(), m.configFetcher, - newAttestationApplier, newHelmClient); err != nil { + if err := runApply(cmd, nil); err != nil { return err } m.log.Debugf("Initialized mini cluster") diff --git a/cli/internal/cmd/upgradeapply.go b/cli/internal/cmd/upgradeapply.go index d34b10298b..9f98f8da2f 100644 --- a/cli/internal/cmd/upgradeapply.go +++ b/cli/internal/cmd/upgradeapply.go @@ -8,31 +8,18 @@ package cmd import ( "context" - "errors" "fmt" "io" - "path/filepath" "strings" "time" - "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" - "github.com/edgelesssys/constellation/v2/cli/internal/helm" - "github.com/edgelesssys/constellation/v2/cli/internal/kubecmd" "github.com/edgelesssys/constellation/v2/cli/internal/state" "github.com/edgelesssys/constellation/v2/cli/internal/terraform" - "github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi" "github.com/edgelesssys/constellation/v2/internal/attestation/variant" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" - "github.com/edgelesssys/constellation/v2/internal/compatibility" "github.com/edgelesssys/constellation/v2/internal/config" - "github.com/edgelesssys/constellation/v2/internal/constants" - "github.com/edgelesssys/constellation/v2/internal/file" - "github.com/edgelesssys/constellation/v2/internal/kms/uri" - "github.com/edgelesssys/constellation/v2/internal/versions" "github.com/rogpeppe/go-internal/diff" - "github.com/spf13/afero" "github.com/spf13/cobra" - "github.com/spf13/pflag" "gopkg.in/yaml.v3" apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" ) @@ -59,7 +46,11 @@ func newUpgradeApplyCmd() *cobra.Command { Short: "Apply an upgrade to a Constellation cluster", Long: "Apply an upgrade to a Constellation cluster by applying the chosen configuration.", Args: cobra.NoArgs, - RunE: runApply, + RunE: func(cmd *cobra.Command, args []string) error { + // Define flags for apply backend that are not set by upgrade-apply + cmd.Flags().Bool("merge-kubeconfig", false, "") + return runApply(cmd, args) + }, } cmd.Flags().BoolP("yes", "y", false, "run upgrades without further confirmation\n"+ @@ -71,238 +62,11 @@ func newUpgradeApplyCmd() *cobra.Command { cmd.Flags().Bool("skip-helm-wait", false, "install helm charts without waiting for deployments to be ready") cmd.Flags().StringSlice("skip-phases", nil, "comma-separated list of upgrade phases to skip\n"+ "one or multiple of { infrastructure | helm | image | k8s }") - if err := cmd.Flags().MarkHidden("timeout"); err != nil { - panic(err) - } + must(cmd.Flags().MarkHidden("timeout")) return cmd } -type upgradeApplyFlags struct { - rootFlags - yes bool - upgradeTimeout time.Duration - conformance bool - helmWaitMode helm.WaitMode - skipPhases skipPhases -} - -func (f *upgradeApplyFlags) parse(flags *pflag.FlagSet) error { - if err := f.rootFlags.parse(flags); err != nil { - return err - } - - rawSkipPhases, err := flags.GetStringSlice("skip-phases") - if err != nil { - return fmt.Errorf("parsing skip-phases flag: %w", err) - } - var skipPhases []skipPhase - for _, phase := range rawSkipPhases { - switch skipPhase(phase) { - case skipInfrastructurePhase, skipHelmPhase, skipImagePhase, skipK8sPhase: - skipPhases = append(skipPhases, skipPhase(phase)) - default: - return fmt.Errorf("invalid phase %s", phase) - } - } - f.skipPhases = skipPhases - - f.yes, err = flags.GetBool("yes") - if err != nil { - return fmt.Errorf("getting 'yes' flag: %w", err) - } - - f.upgradeTimeout, err = flags.GetDuration("timeout") - if err != nil { - return fmt.Errorf("getting 'timeout' flag: %w", err) - } - - f.conformance, err = flags.GetBool("conformance") - if err != nil { - return fmt.Errorf("getting 'conformance' flag: %w", err) - } - skipHelmWait, err := flags.GetBool("skip-helm-wait") - if err != nil { - return fmt.Errorf("getting 'skip-helm-wait' flag: %w", err) - } - f.helmWaitMode = helm.WaitModeAtomic - if skipHelmWait { - f.helmWaitMode = helm.WaitModeNone - } - - return nil -} - -func runUpgradeApply(cmd *cobra.Command, _ []string) error { - log, err := newCLILogger(cmd) - if err != nil { - return fmt.Errorf("creating logger: %w", err) - } - defer log.Sync() - - fileHandler := file.NewHandler(afero.NewOsFs()) - upgradeID := generateUpgradeID(upgradeCmdKindApply) - - kubeUpgrader, err := kubecmd.New(cmd.OutOrStdout(), constants.AdminConfFilename, fileHandler, log) - if err != nil { - return err - } - - configFetcher := attestationconfigapi.NewFetcher() - - var flags upgradeApplyFlags - if err := flags.parse(cmd.Flags()); err != nil { - return err - } - - // Set up terraform upgrader - upgradeDir := filepath.Join(constants.UpgradeDir, upgradeID) - clusterUpgrader, err := cloudcmd.NewClusterUpgrader( - cmd.Context(), - constants.TerraformWorkingDir, - upgradeDir, - flags.tfLogLevel, - fileHandler, - ) - if err != nil { - return fmt.Errorf("setting up cluster upgrader: %w", err) - } - - helmClient, err := helm.NewClient(constants.AdminConfFilename, log) - if err != nil { - return fmt.Errorf("creating Helm client: %w", err) - } - - applyCmd := upgradeApplyCmd{ - kubeUpgrader: kubeUpgrader, - helmApplier: helmClient, - clusterUpgrader: clusterUpgrader, - configFetcher: configFetcher, - fileHandler: fileHandler, - flags: flags, - log: log, - } - return applyCmd.upgradeApply(cmd, upgradeDir) -} - -type upgradeApplyCmd struct { - helmApplier helmApplier - kubeUpgrader kubernetesUpgrader - clusterUpgrader clusterUpgrader - configFetcher attestationconfigapi.Fetcher - fileHandler file.Handler - flags upgradeApplyFlags - log debugLog -} - -func (u *upgradeApplyCmd) upgradeApply(cmd *cobra.Command, upgradeDir string) error { - conf, err := config.New(u.fileHandler, constants.ConfigFilename, u.configFetcher, u.flags.force) - var configValidationErr *config.ValidationError - if errors.As(err, &configValidationErr) { - cmd.PrintErrln(configValidationErr.LongMessage()) - } - if err != nil { - return err - } - if cloudcmd.UpgradeRequiresIAMMigration(conf.GetProvider()) { - cmd.Println("WARNING: This upgrade requires an IAM migration. Please make sure you have applied the IAM migration using `iam upgrade apply` before continuing.") - if !u.flags.yes { - yes, err := askToConfirm(cmd, "Did you upgrade the IAM resources?") - if err != nil { - return fmt.Errorf("asking for confirmation: %w", err) - } - if !yes { - cmd.Println("Skipping upgrade.") - return nil - } - } - } - conf.KubernetesVersion, err = validK8sVersion(cmd, string(conf.KubernetesVersion), u.flags.yes) - if err != nil { - return err - } - - stateFile, err := state.ReadFromFile(u.fileHandler, constants.StateFilename) - if err != nil { - return fmt.Errorf("reading state file: %w", err) - } - - if err := u.confirmAndUpgradeAttestationConfig(cmd, conf.GetAttestationConfig(), stateFile.ClusterValues.MeasurementSalt); err != nil { - return fmt.Errorf("upgrading measurements: %w", err) - } - - // If infrastructure phase is skipped, we expect the new infrastructure - // to be in the Terraform configuration already. Otherwise, perform - // the Terraform migrations. - if !u.flags.skipPhases.contains(skipInfrastructurePhase) { - migrationRequired, err := u.planTerraformMigration(cmd, conf) - if err != nil { - return fmt.Errorf("planning Terraform migrations: %w", err) - } - - if migrationRequired { - postMigrationInfraState, err := u.migrateTerraform(cmd, conf, upgradeDir) - if err != nil { - return fmt.Errorf("performing Terraform migrations: %w", err) - } - - // Merge the pre-upgrade state with the post-migration infrastructure values - if _, err := stateFile.Merge( - // temporary state with post-migration infrastructure values - state.New().SetInfrastructure(postMigrationInfraState), - ); err != nil { - return fmt.Errorf("merging pre-upgrade state with post-migration infrastructure values: %w", err) - } - - // Write the post-migration state to disk - if err := stateFile.WriteToFile(u.fileHandler, constants.StateFilename); err != nil { - return fmt.Errorf("writing state file: %w", err) - } - } - } - - // extend the clusterConfig cert SANs with any of the supported endpoints: - // - (legacy) public IP - // - fallback endpoint - // - custom (user-provided) endpoint - sans := append([]string{stateFile.Infrastructure.ClusterEndpoint, conf.CustomEndpoint}, stateFile.Infrastructure.APIServerCertSANs...) - if err := u.kubeUpgrader.ExtendClusterConfigCertSANs(cmd.Context(), sans); err != nil { - return fmt.Errorf("extending cert SANs: %w", err) - } - - if conf.GetProvider() != cloudprovider.Azure && conf.GetProvider() != cloudprovider.GCP && conf.GetProvider() != cloudprovider.AWS { - cmd.PrintErrln("WARNING: Skipping service and image upgrades, which are currently only supported for AWS, Azure, and GCP.") - return nil - } - - var upgradeErr *compatibility.InvalidUpgradeError - if !u.flags.skipPhases.contains(skipHelmPhase) { - err = u.handleServiceUpgrade(cmd, conf, stateFile, upgradeDir) - switch { - case errors.As(err, &upgradeErr): - cmd.PrintErrln(err) - case err == nil: - cmd.Println("Successfully upgraded Constellation services.") - case err != nil: - return fmt.Errorf("upgrading services: %w", err) - } - } - skipImageUpgrade := u.flags.skipPhases.contains(skipImagePhase) - skipK8sUpgrade := u.flags.skipPhases.contains(skipK8sPhase) - if !(skipImageUpgrade && skipK8sUpgrade) { - err = u.kubeUpgrader.UpgradeNodeVersion(cmd.Context(), conf, u.flags.force, skipImageUpgrade, skipK8sUpgrade) - switch { - case errors.Is(err, kubecmd.ErrInProgress): - cmd.PrintErrln("Skipping image and Kubernetes upgrades. Another upgrade is in progress.") - case errors.As(err, &upgradeErr): - cmd.PrintErrln(err) - case err != nil: - return fmt.Errorf("upgrading NodeVersion: %w", err) - } - } - return nil -} - func diffAttestationCfg(currentAttestationCfg config.AttestationCfg, newAttestationCfg config.AttestationCfg) (string, error) { // cannot compare structs directly with go-cmp because of unexported fields in the attestation config currentYml, err := yaml.Marshal(currentAttestationCfg) @@ -317,209 +81,6 @@ func diffAttestationCfg(currentAttestationCfg config.AttestationCfg, newAttestat return diff, nil } -// planTerraformMigration checks if the Constellation version the cluster is being upgraded to requires a migration. -func (u *upgradeApplyCmd) planTerraformMigration(cmd *cobra.Command, conf *config.Config) (bool, error) { - u.log.Debugf("Planning Terraform migrations") - - vars, err := cloudcmd.TerraformUpgradeVars(conf) - if err != nil { - return false, fmt.Errorf("parsing upgrade variables: %w", err) - } - u.log.Debugf("Using Terraform variables:\n%v", vars) - - // Check if there are any Terraform migrations to apply - - // Add manual migrations here if required - // - // var manualMigrations []terraform.StateMigration - // for _, migration := range manualMigrations { - // u.log.Debugf("Adding manual Terraform migration: %s", migration.DisplayName) - // u.upgrader.AddManualStateMigration(migration) - // } - - return u.clusterUpgrader.PlanClusterUpgrade(cmd.Context(), cmd.OutOrStdout(), vars, conf.GetProvider()) -} - -// migrateTerraform checks if the Constellation version the cluster is being upgraded to requires a migration -// of cloud resources with Terraform. If so, the migration is performed and the post-migration infrastructure state is returned. -// If no migration is required, the current (pre-upgrade) infrastructure state is returned. -func (u *upgradeApplyCmd) migrateTerraform(cmd *cobra.Command, conf *config.Config, upgradeDir string, -) (state.Infrastructure, error) { - // If there are any Terraform migrations to apply, ask for confirmation - fmt.Fprintln(cmd.OutOrStdout(), "The upgrade requires a migration of Constellation cloud resources by applying an updated Terraform template. Please manually review the suggested changes below.") - if !u.flags.yes { - ok, err := askToConfirm(cmd, "Do you want to apply the Terraform migrations?") - if err != nil { - return state.Infrastructure{}, fmt.Errorf("asking for confirmation: %w", err) - } - if !ok { - cmd.Println("Aborting upgrade.") - // User doesn't expect to see any changes in his workspace after aborting an "upgrade apply", - // therefore, roll back to the backed up state. - if err := u.clusterUpgrader.RestoreClusterWorkspace(); err != nil { - return state.Infrastructure{}, fmt.Errorf( - "restoring Terraform workspace: %w, restore the Terraform workspace manually from %s ", - err, - filepath.Join(upgradeDir, constants.TerraformUpgradeBackupDir), - ) - } - return state.Infrastructure{}, fmt.Errorf("cluster upgrade aborted by user") - } - } - u.log.Debugf("Applying Terraform migrations") - - infraState, err := u.clusterUpgrader.ApplyClusterUpgrade(cmd.Context(), conf.GetProvider()) - if err != nil { - return state.Infrastructure{}, fmt.Errorf("applying terraform migrations: %w", err) - } - - cmd.Printf("Infrastructure migrations applied successfully and output written to: %s\n"+ - "A backup of the pre-upgrade state has been written to: %s\n", - u.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename), - u.flags.pathPrefixer.PrefixPrintablePath(filepath.Join(upgradeDir, constants.TerraformUpgradeBackupDir)), - ) - return infraState, nil -} - -// validK8sVersion checks if the Kubernetes patch version is supported and asks for confirmation if not. -func validK8sVersion(cmd *cobra.Command, version string, yes bool) (validVersion versions.ValidK8sVersion, err error) { - validVersion, err = versions.NewValidK8sVersion(version, true) - if versions.IsPreviewK8sVersion(validVersion) { - cmd.PrintErrf("Warning: Constellation with Kubernetes %v is still in preview. Use only for evaluation purposes.\n", validVersion) - } - valid := err == nil - - if !valid && !yes { - confirmed, err := askToConfirm(cmd, fmt.Sprintf("WARNING: The Kubernetes patch version %s is not supported. If you continue, Kubernetes upgrades will be skipped. Do you want to continue anyway?", version)) - if err != nil { - return validVersion, fmt.Errorf("asking for confirmation: %w", err) - } - if !confirmed { - return validVersion, fmt.Errorf("aborted by user") - } - } - - return validVersion, nil -} - -// confirmAndUpgradeAttestationConfig checks if the locally configured measurements are different from the cluster's measurements. -// If so the function will ask the user to confirm (if --yes is not set) and upgrade the cluster's config. -func (u *upgradeApplyCmd) confirmAndUpgradeAttestationConfig( - cmd *cobra.Command, newConfig config.AttestationCfg, measurementSalt []byte, -) error { - clusterAttestationConfig, err := u.kubeUpgrader.GetClusterAttestationConfig(cmd.Context(), newConfig.GetVariant()) - if err != nil { - return fmt.Errorf("getting cluster attestation config: %w", err) - } - - // If the current config is equal, or there is an error when comparing the configs, we skip the upgrade. - equal, err := newConfig.EqualTo(clusterAttestationConfig) - if err != nil { - return fmt.Errorf("comparing attestation configs: %w", err) - } - if equal { - return nil - } - cmd.Println("The configured attestation config is different from the attestation config in the cluster.") - diffStr, err := diffAttestationCfg(clusterAttestationConfig, newConfig) - if err != nil { - return fmt.Errorf("diffing attestation configs: %w", err) - } - cmd.Println("The following changes will be applied to the attestation config:") - cmd.Println(diffStr) - if !u.flags.yes { - ok, err := askToConfirm(cmd, "Are you sure you want to change your cluster's attestation config?") - if err != nil { - return fmt.Errorf("asking for confirmation: %w", err) - } - if !ok { - return errors.New("aborting upgrade since attestation config is different") - } - } - - if err := u.kubeUpgrader.ApplyJoinConfig(cmd.Context(), newConfig, measurementSalt); err != nil { - return fmt.Errorf("updating attestation config: %w", err) - } - cmd.Println("Successfully updated the cluster's attestation config") - return nil -} - -func (u *upgradeApplyCmd) handleServiceUpgrade( - cmd *cobra.Command, conf *config.Config, stateFile *state.State, upgradeDir string, -) error { - var secret uri.MasterSecret - if err := u.fileHandler.ReadJSON(constants.MasterSecretFilename, &secret); err != nil { - return fmt.Errorf("reading master secret: %w", err) - } - serviceAccURI, err := cloudcmd.GetMarshaledServiceAccountURI(conf, u.fileHandler) - if err != nil { - return fmt.Errorf("getting service account URI: %w", err) - } - options := helm.Options{ - Force: u.flags.force, - Conformance: u.flags.conformance, - HelmWaitMode: u.flags.helmWaitMode, - } - - prepareApply := func(allowDestructive bool) (helm.Applier, bool, error) { - options.AllowDestructive = allowDestructive - executor, includesUpgrades, err := u.helmApplier.PrepareApply(conf, stateFile, options, serviceAccURI, secret) - var upgradeErr *compatibility.InvalidUpgradeError - switch { - case errors.As(err, &upgradeErr): - cmd.PrintErrln(err) - case err != nil: - return nil, false, fmt.Errorf("getting chart executor: %w", err) - } - return executor, includesUpgrades, nil - } - - executor, includesUpgrades, err := prepareApply(helm.DenyDestructive) - if err != nil { - if !errors.Is(err, helm.ErrConfirmationMissing) { - return fmt.Errorf("upgrading charts with deny destructive mode: %w", err) - } - if !u.flags.yes { - cmd.PrintErrln("WARNING: Upgrading cert-manager will destroy all custom resources you have manually created that are based on the current version of cert-manager.") - ok, askErr := askToConfirm(cmd, "Do you want to upgrade cert-manager anyway?") - if askErr != nil { - return fmt.Errorf("asking for confirmation: %w", err) - } - if !ok { - cmd.Println("Skipping upgrade.") - return nil - } - } - executor, includesUpgrades, err = prepareApply(helm.AllowDestructive) - if err != nil { - return fmt.Errorf("upgrading charts with allow destructive mode: %w", err) - } - } - - // Save the Helm charts for the upgrade to disk - chartDir := filepath.Join(upgradeDir, "helm-charts") - if err := executor.SaveCharts(chartDir, u.fileHandler); err != nil { - return fmt.Errorf("saving Helm charts to disk: %w", err) - } - u.log.Debugf("Helm charts saved to %s", chartDir) - - if includesUpgrades { - u.log.Debugf("Creating backup of CRDs and CRs") - crds, err := u.kubeUpgrader.BackupCRDs(cmd.Context(), upgradeDir) - if err != nil { - return fmt.Errorf("creating CRD backup: %w", err) - } - if err := u.kubeUpgrader.BackupCRs(cmd.Context(), crds, upgradeDir); err != nil { - return fmt.Errorf("creating CR backup: %w", err) - } - } - if err := executor.Apply(cmd.Context()); err != nil { - return fmt.Errorf("applying Helm charts: %w", err) - } - - return nil -} - // skipPhases is a list of phases that can be skipped during the upgrade process. type skipPhases []skipPhase diff --git a/cli/internal/cmd/upgradeapply_test.go b/cli/internal/cmd/upgradeapply_test.go index 6f8662d9a4..18caef6cc4 100644 --- a/cli/internal/cmd/upgradeapply_test.go +++ b/cli/internal/cmd/upgradeapply_test.go @@ -55,14 +55,14 @@ func TestUpgradeApply(t *testing.T) { terraformUpgrader clusterUpgrader wantErr bool customK8sVersion string - flags upgradeApplyFlags + flags applyFlags stdin string }{ "success": { kubeUpgrader: &stubKubernetesUpgrader{currentConfig: config.DefaultForAzureSEVSNP()}, helmUpgrader: stubApplier{}, terraformUpgrader: &stubTerraformUpgrader{}, - flags: upgradeApplyFlags{yes: true}, + flags: applyFlags{yes: true}, fh: fsWithStateFile, fhAssertions: func(require *require.Assertions, assert *assert.Assertions, fh file.Handler) { gotState, err := state.ReadFromFile(fh, constants.StateFilename) @@ -71,11 +71,11 @@ func TestUpgradeApply(t *testing.T) { assert.Equal(defaultState, gotState) }, }, - "state file does not exist": { + "id file and state file do not exist": { kubeUpgrader: &stubKubernetesUpgrader{currentConfig: config.DefaultForAzureSEVSNP()}, helmUpgrader: stubApplier{}, terraformUpgrader: &stubTerraformUpgrader{}, - flags: upgradeApplyFlags{yes: true}, + flags: applyFlags{yes: true}, fh: func() file.Handler { return file.NewHandler(afero.NewMemMapFs()) }, @@ -89,7 +89,7 @@ func TestUpgradeApply(t *testing.T) { helmUpgrader: stubApplier{}, terraformUpgrader: &stubTerraformUpgrader{}, wantErr: true, - flags: upgradeApplyFlags{yes: true}, + flags: applyFlags{yes: true}, fh: fsWithStateFile, }, "nodeVersion in progress error": { @@ -99,7 +99,7 @@ func TestUpgradeApply(t *testing.T) { }, helmUpgrader: stubApplier{}, terraformUpgrader: &stubTerraformUpgrader{}, - flags: upgradeApplyFlags{yes: true}, + flags: applyFlags{yes: true}, fh: fsWithStateFile, }, "helm other error": { @@ -109,7 +109,7 @@ func TestUpgradeApply(t *testing.T) { helmUpgrader: stubApplier{err: assert.AnError}, terraformUpgrader: &stubTerraformUpgrader{}, wantErr: true, - flags: upgradeApplyFlags{yes: true}, + flags: applyFlags{yes: true}, fh: fsWithStateFile, }, "abort": { @@ -139,7 +139,7 @@ func TestUpgradeApply(t *testing.T) { helmUpgrader: stubApplier{}, terraformUpgrader: &stubTerraformUpgrader{planTerraformErr: assert.AnError}, wantErr: true, - flags: upgradeApplyFlags{yes: true}, + flags: applyFlags{yes: true}, fh: fsWithStateFile, }, "apply terraform error": { @@ -152,7 +152,7 @@ func TestUpgradeApply(t *testing.T) { terraformDiff: true, }, wantErr: true, - flags: upgradeApplyFlags{yes: true}, + flags: applyFlags{yes: true}, fh: fsWithStateFile, }, "outdated K8s patch version": { @@ -166,7 +166,7 @@ func TestUpgradeApply(t *testing.T) { require.NoError(t, err) return semver.NewFromInt(v.Major(), v.Minor(), v.Patch()-1, "").String() }(), - flags: upgradeApplyFlags{yes: true}, + flags: applyFlags{yes: true}, fh: fsWithStateFile, }, "outdated K8s version": { @@ -176,7 +176,7 @@ func TestUpgradeApply(t *testing.T) { helmUpgrader: stubApplier{}, terraformUpgrader: &stubTerraformUpgrader{}, customK8sVersion: "v1.20.0", - flags: upgradeApplyFlags{yes: true}, + flags: applyFlags{yes: true}, wantErr: true, fh: fsWithStateFile, }, @@ -186,7 +186,7 @@ func TestUpgradeApply(t *testing.T) { }, helmUpgrader: &mockApplier{}, // mocks ensure that no methods are called terraformUpgrader: &mockTerraformUpgrader{}, - flags: upgradeApplyFlags{ + flags: applyFlags{ skipPhases: []skipPhase{skipInfrastructurePhase, skipHelmPhase, skipK8sPhase, skipImagePhase}, yes: true, }, @@ -198,7 +198,7 @@ func TestUpgradeApply(t *testing.T) { }, helmUpgrader: &mockApplier{}, // mocks ensure that no methods are called terraformUpgrader: &mockTerraformUpgrader{}, - flags: upgradeApplyFlags{ + flags: applyFlags{ skipPhases: []skipPhase{skipInfrastructurePhase, skipHelmPhase, skipK8sPhase}, yes: true, }, @@ -218,20 +218,26 @@ func TestUpgradeApply(t *testing.T) { cfg.KubernetesVersion = versions.ValidK8sVersion(tc.customK8sVersion) } fh := tc.fh() + require.NoError(fh.Write(constants.AdminConfFilename, []byte{})) require.NoError(fh.WriteYAML(constants.ConfigFilename, cfg)) require.NoError(fh.WriteJSON(constants.MasterSecretFilename, uri.MasterSecret{})) - upgrader := upgradeApplyCmd{ - kubeUpgrader: tc.kubeUpgrader, - helmApplier: tc.helmUpgrader, + upgrader := &applyCmd{ + fileHandler: fh, + flags: tc.flags, + log: logger.NewTest(t), + spinner: &nopSpinner{}, + merger: &stubMerger{}, + quotaChecker: &stubLicenseClient{}, + newHelmClient: func(string, debugLog) (helmApplier, error) { + return tc.helmUpgrader, nil + }, + newKubeUpgrader: func(_ io.Writer, _ string, _ debugLog) (kubernetesUpgrader, error) { + return tc.kubeUpgrader, nil + }, clusterUpgrader: tc.terraformUpgrader, - log: logger.NewTest(t), - configFetcher: stubAttestationFetcher{}, - flags: tc.flags, - fileHandler: fh, } - - err := upgrader.upgradeApply(cmd, "test") + err := upgrader.apply(cmd, stubAttestationFetcher{}, "test") if tc.wantErr { assert.Error(err) return @@ -255,27 +261,35 @@ func TestUpgradeApplyFlagsForSkipPhases(t *testing.T) { cmd.Flags().Bool("force", true, "") cmd.Flags().String("tf-log", "NONE", "") cmd.Flags().Bool("debug", false, "") + cmd.Flags().Bool("merge-kubeconfig", false, "") require.NoError(cmd.Flags().Set("skip-phases", "infrastructure,helm,k8s,image")) - var flags upgradeApplyFlags + var flags applyFlags err := flags.parse(cmd.Flags()) require.NoError(err) assert.ElementsMatch(t, []skipPhase{skipInfrastructurePhase, skipHelmPhase, skipK8sPhase, skipImagePhase}, flags.skipPhases) } type stubKubernetesUpgrader struct { - nodeVersionErr error - currentConfig config.AttestationCfg - calledNodeUpgrade bool + nodeVersionErr error + currentConfig config.AttestationCfg + getClusterAttestationConfigErr error + calledNodeUpgrade bool + backupCRDsErr error + backupCRDsCalled bool + backupCRsErr error + backupCRsCalled bool } func (u *stubKubernetesUpgrader) BackupCRDs(_ context.Context, _ string) ([]apiextensionsv1.CustomResourceDefinition, error) { - return []apiextensionsv1.CustomResourceDefinition{}, nil + u.backupCRDsCalled = true + return []apiextensionsv1.CustomResourceDefinition{}, u.backupCRDsErr } func (u *stubKubernetesUpgrader) BackupCRs(_ context.Context, _ []apiextensionsv1.CustomResourceDefinition, _ string) error { - return nil + u.backupCRsCalled = true + return u.backupCRsErr } func (u *stubKubernetesUpgrader) UpgradeNodeVersion(_ context.Context, _ *config.Config, _, _, _ bool) error { @@ -288,7 +302,7 @@ func (u *stubKubernetesUpgrader) ApplyJoinConfig(_ context.Context, _ config.Att } func (u *stubKubernetesUpgrader) GetClusterAttestationConfig(_ context.Context, _ variant.Variant) (config.AttestationCfg, error) { - return u.currentConfig, nil + return u.currentConfig, u.getClusterAttestationConfigErr } func (u *stubKubernetesUpgrader) ExtendClusterConfigCertSANs(_ context.Context, _ []string) error {