diff --git a/cli/internal/cmd/verify.go b/cli/internal/cmd/verify.go index 68f787641e..856de315e2 100644 --- a/cli/internal/cmd/verify.go +++ b/cli/internal/cmd/verify.go @@ -26,7 +26,6 @@ import ( tpmProto "github.com/google/go-tpm-tools/proto/tpm" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" - "github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix" "github.com/edgelesssys/constellation/v2/cli/internal/state" "github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi" "github.com/edgelesssys/constellation/v2/internal/atls" @@ -45,6 +44,7 @@ import ( "github.com/google/go-sev-guest/kds" "github.com/spf13/afero" "github.com/spf13/cobra" + "github.com/spf13/pflag" "google.golang.org/grpc" ) @@ -64,8 +64,39 @@ func NewVerifyCmd() *cobra.Command { return cmd } +type verifyFlags struct { + rootFlags + endpoint string + ownerID string + clusterID string + output string +} + +func (f *verifyFlags) parse(flags *pflag.FlagSet) error { + if err := f.rootFlags.parse(flags); err != nil { + return err + } + + var err error + f.output, err = flags.GetString("output") + if err != nil { + return fmt.Errorf("getting 'output' flag: %w", err) + } + f.endpoint, err = flags.GetString("node-endpoint") + if err != nil { + return fmt.Errorf("getting 'node-endpoint' flag: %w", err) + } + f.clusterID, err = flags.GetString("cluster-id") + if err != nil { + return fmt.Errorf("getting 'cluster-id' flag: %w", err) + } + return nil +} + type verifyCmd struct { - log debugLog + fileHandler file.Handler + flags verifyFlags + log debugLog } func runVerify(cmd *cobra.Command, _ []string) error { @@ -95,22 +126,23 @@ func runVerify(cmd *cobra.Command, _ []string) error { return nil, fmt.Errorf("invalid output value for formatter: %s", output) } } - v := &verifyCmd{log: log} + v := &verifyCmd{ + fileHandler: fileHandler, + log: log, + } + if err := v.flags.parse(cmd.Flags()); err != nil { + return err + } + v.log.Debugf("Using flags: %+v", v.flags) fetcher := attestationconfigapi.NewFetcher() - return v.verify(cmd, fileHandler, verifyClient, formatterFactory, fetcher) + return v.verify(cmd, verifyClient, formatterFactory, fetcher) } type formatterFactory func(output string, provider cloudprovider.Provider, log debugLog) (attestationDocFormatter, error) -func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyClient verifyClient, factory formatterFactory, configFetcher attestationconfigapi.Fetcher) error { - flags, err := c.parseVerifyFlags(cmd, fileHandler) - if err != nil { - return fmt.Errorf("parsing flags: %w", err) - } - c.log.Debugf("Using flags: %+v", flags) - - c.log.Debugf("Loading configuration file from %q", flags.pf.PrefixPrintablePath(constants.ConfigFilename)) - conf, err := config.New(fileHandler, constants.ConfigFilename, configFetcher, flags.force) +func (c *verifyCmd) verify(cmd *cobra.Command, verifyClient verifyClient, factory formatterFactory, configFetcher attestationconfigapi.Fetcher) error { + c.log.Debugf("Loading configuration file from %q", c.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename)) + conf, err := config.New(c.fileHandler, constants.ConfigFilename, configFetcher, c.flags.force) var configValidationErr *config.ValidationError if errors.As(err, &configValidationErr) { cmd.PrintErrln(configValidationErr.LongMessage()) @@ -119,10 +151,29 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC return fmt.Errorf("loading config file: %w", err) } - conf.UpdateMAAURL(flags.maaURL) + stateFile, err := state.ReadFromFile(c.fileHandler, constants.StateFilename) + if err != nil { + return fmt.Errorf("reading state file: %w", err) + } + + ownerID, clusterID, err := c.validateIDFlags(cmd, stateFile) + if err != nil { + return err + } + endpoint, err := c.validateEndpointFlag(cmd, stateFile) + if err != nil { + return err + } + + var maaURL string + if stateFile.Infrastructure.Azure != nil { + maaURL = stateFile.Infrastructure.Azure.AttestationURL + } + conf.UpdateMAAURL(maaURL) + c.log.Debugf("Updating expected PCRs") attConfig := conf.GetAttestationConfig() - if err := cloudcmd.UpdateInitMeasurements(attConfig, flags.ownerID, flags.clusterID); err != nil { + if err := cloudcmd.UpdateInitMeasurements(attConfig, ownerID, clusterID); err != nil { return fmt.Errorf("updating expected PCRs: %w", err) } @@ -140,7 +191,7 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC rawAttestationDoc, err := verifyClient.Verify( cmd.Context(), - flags.endpoint, + endpoint, &verifyproto.GetAttestationRequest{ Nonce: nonce, }, @@ -151,7 +202,7 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC } // certificates are only available for Azure - formatter, err := factory(flags.output, conf.GetProvider(), c.log) + formatter, err := factory(c.flags.output, conf.GetProvider(), c.log) if err != nil { return fmt.Errorf("creating formatter: %w", err) } @@ -160,7 +211,7 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC rawAttestationDoc, conf.Provider.Azure == nil, attConfig.GetMeasurements(), - flags.maaURL, + maaURL, ) if err != nil { return fmt.Errorf("printing attestation document: %w", err) @@ -171,114 +222,37 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC return nil } -func (c *verifyCmd) parseVerifyFlags(cmd *cobra.Command, fileHandler file.Handler) (verifyFlags, error) { - workDir, err := cmd.Flags().GetString("workspace") - if err != nil { - return verifyFlags{}, fmt.Errorf("parsing config path argument: %w", err) - } - c.log.Debugf("Flag 'workspace' set to %q", workDir) - pf := pathprefix.New(workDir) - - ownerID := "" - clusterID, err := cmd.Flags().GetString("cluster-id") - if err != nil { - return verifyFlags{}, fmt.Errorf("parsing cluster-id argument: %w", err) - } - c.log.Debugf("Flag 'cluster-id' set to %q", clusterID) - - endpoint, err := cmd.Flags().GetString("node-endpoint") - if err != nil { - return verifyFlags{}, fmt.Errorf("parsing node-endpoint argument: %w", err) - } - c.log.Debugf("Flag 'node-endpoint' set to %q", endpoint) - - force, err := cmd.Flags().GetBool("force") - if err != nil { - return verifyFlags{}, fmt.Errorf("parsing force argument: %w", err) - } - c.log.Debugf("Flag 'force' set to %t", force) - - output, err := cmd.Flags().GetString("output") - if err != nil { - return verifyFlags{}, fmt.Errorf("parsing raw argument: %w", err) - } - c.log.Debugf("Flag 'output' set to %t", output) - - // Get empty values from state file - stateFile, err := state.ReadFromFile(fileHandler, constants.StateFilename) - isFileNotFound := errors.Is(err, afero.ErrFileNotFound) - if isFileNotFound { - c.log.Debugf("State file %q not found, using empty state", pf.PrefixPrintablePath(constants.StateFilename)) - stateFile = state.New() // error compat - } else if err != nil { - return verifyFlags{}, fmt.Errorf("reading state file: %w", err) - } - - emptyEndpoint := endpoint == "" - emptyIDs := ownerID == "" && clusterID == "" - if emptyEndpoint || emptyIDs { - c.log.Debugf("Trying to supplement empty flag values from %q", pf.PrefixPrintablePath(constants.StateFilename)) - if emptyEndpoint { - cmd.PrintErrf("Using endpoint from %q. Specify --node-endpoint to override this.\n", pf.PrefixPrintablePath(constants.StateFilename)) - endpoint = stateFile.Infrastructure.ClusterEndpoint - } - if emptyIDs { - cmd.PrintErrf("Using ID from %q. Specify --cluster-id to override this.\n", pf.PrefixPrintablePath(constants.StateFilename)) - ownerID = stateFile.ClusterValues.OwnerID - clusterID = stateFile.ClusterValues.ClusterID - } +func (c *verifyCmd) validateIDFlags(cmd *cobra.Command, stateFile *state.State) (ownerID, clusterID string, err error) { + ownerID, clusterID = c.flags.ownerID, c.flags.clusterID + if c.flags.clusterID == "" { + cmd.PrintErrf("Using ID from %q. Specify --cluster-id to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename)) + clusterID = stateFile.ClusterValues.ClusterID } - - var attestationURL string - if stateFile.Infrastructure.Azure != nil { - attestationURL = stateFile.Infrastructure.Azure.AttestationURL + if ownerID == "" { + // We don't want to print warnings until this is implemented again + // cmd.PrintErrf("Using ID from %q. Specify --owner-id to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename)) + ownerID = stateFile.ClusterValues.OwnerID } // Validate if ownerID == "" && clusterID == "" { - return verifyFlags{}, errors.New("cluster-id not provided to verify the cluster") - } - endpoint, err = addPortIfMissing(endpoint, constants.VerifyServiceNodePortGRPC) - if err != nil { - return verifyFlags{}, fmt.Errorf("validating endpoint argument: %w", err) + return "", "", errors.New("cluster-id not provided to verify the cluster") } - return verifyFlags{ - endpoint: endpoint, - pf: pf, - ownerID: ownerID, - clusterID: clusterID, - output: output, - maaURL: attestationURL, - force: force, - }, nil -} - -type verifyFlags struct { - endpoint string - ownerID string - clusterID string - maaURL string - output string - force bool - pf pathprefix.PathPrefixer + return ownerID, clusterID, nil } -func addPortIfMissing(endpoint string, defaultPort int) (string, error) { +func (c *verifyCmd) validateEndpointFlag(cmd *cobra.Command, stateFile *state.State) (string, error) { + endpoint := c.flags.endpoint if endpoint == "" { - return "", errors.New("endpoint is empty") - } - - _, _, err := net.SplitHostPort(endpoint) - if err == nil { - return endpoint, nil + cmd.PrintErrf("Using endpoint from %q. Specify --node-endpoint to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename)) + endpoint = stateFile.Infrastructure.ClusterEndpoint } - - if strings.Contains(err.Error(), "missing port in address") { - return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil + endpoint, err := addPortIfMissing(endpoint, constants.VerifyServiceNodePortGRPC) + if err != nil { + return "", fmt.Errorf("validating endpoint argument: %w", err) } - - return "", err + return endpoint, nil } // an attestationDocFormatter formats the attestation document. @@ -869,3 +843,20 @@ func extractAzureInstanceInfo(docString string) (azureInstanceInfo, error) { } return instanceInfo, nil } + +func addPortIfMissing(endpoint string, defaultPort int) (string, error) { + if endpoint == "" { + return "", errors.New("endpoint is empty") + } + + _, _, err := net.SplitHostPort(endpoint) + if err == nil { + return endpoint, nil + } + + if strings.Contains(err.Error(), "missing port in address") { + return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil + } + + return "", err +} diff --git a/cli/internal/cmd/verify_test.go b/cli/internal/cmd/verify_test.go index f2da96a257..a9a44c4515 100644 --- a/cli/internal/cmd/verify_test.go +++ b/cli/internal/cmd/verify_test.go @@ -58,6 +58,7 @@ func TestVerify(t *testing.T) { nodeEndpointFlag: "192.0.2.1:1234", clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{}, + stateFile: state.New(), wantEndpoint: "192.0.2.1:1234", formatter: &stubAttDocFormatter{}, }, @@ -66,6 +67,7 @@ func TestVerify(t *testing.T) { nodeEndpointFlag: "192.0.2.1:1234", clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{}, + stateFile: state.New(), wantEndpoint: "192.0.2.1:1234", formatter: &stubAttDocFormatter{}, }, @@ -74,6 +76,7 @@ func TestVerify(t *testing.T) { nodeEndpointFlag: "192.0.2.1", clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{}, + stateFile: state.New(), wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC), formatter: &stubAttDocFormatter{}, }, @@ -81,6 +84,7 @@ func TestVerify(t *testing.T) { provider: cloudprovider.GCP, clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{}, + stateFile: state.New(), formatter: &stubAttDocFormatter{}, wantErr: true, }, @@ -106,12 +110,14 @@ func TestVerify(t *testing.T) { nodeEndpointFlag: ":::::", clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{}, + stateFile: state.New(), formatter: &stubAttDocFormatter{}, wantErr: true, }, "neither owner id nor cluster id set": { provider: cloudprovider.GCP, nodeEndpointFlag: "192.0.2.1:1234", + stateFile: state.New(), formatter: &stubAttDocFormatter{}, wantErr: true, }, @@ -127,6 +133,7 @@ func TestVerify(t *testing.T) { provider: cloudprovider.GCP, clusterIDFlag: zeroBase64, nodeEndpointFlag: "192.0.2.1:1234", + stateFile: state.New(), formatter: &stubAttDocFormatter{}, skipConfigCreation: true, wantErr: true, @@ -136,6 +143,7 @@ func TestVerify(t *testing.T) { nodeEndpointFlag: "192.0.2.1:1234", clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{verifyErr: rpcStatus.Error(codes.Internal, "failed")}, + stateFile: state.New(), formatter: &stubAttDocFormatter{}, wantErr: true, }, @@ -144,6 +152,7 @@ func TestVerify(t *testing.T) { nodeEndpointFlag: "192.0.2.1:1234", clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{verifyErr: someErr}, + stateFile: state.New(), formatter: &stubAttDocFormatter{}, wantErr: true, }, @@ -152,6 +161,7 @@ func TestVerify(t *testing.T) { nodeEndpointFlag: "192.0.2.1:1234", clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{}, + stateFile: state.New(), wantEndpoint: "192.0.2.1:1234", formatter: &stubAttDocFormatter{formatErr: someErr}, wantErr: true, @@ -164,31 +174,28 @@ func TestVerify(t *testing.T) { require := require.New(t) cmd := NewVerifyCmd() - cmd.Flags().String("workspace", "", "") // register persistent flag manually - cmd.Flags().Bool("force", true, "") // register persistent flag manually out := &bytes.Buffer{} cmd.SetErr(out) - if tc.clusterIDFlag != "" { - require.NoError(cmd.Flags().Set("cluster-id", tc.clusterIDFlag)) - } - if tc.nodeEndpointFlag != "" { - require.NoError(cmd.Flags().Set("node-endpoint", tc.nodeEndpointFlag)) - } fileHandler := file.NewHandler(afero.NewMemMapFs()) if !tc.skipConfigCreation { cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), tc.provider) require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg)) } - if tc.stateFile != nil { - require.NoError(tc.stateFile.WriteToFile(fileHandler, constants.StateFilename)) + require.NoError(tc.stateFile.WriteToFile(fileHandler, constants.StateFilename)) + + v := &verifyCmd{ + fileHandler: fileHandler, + log: logger.NewTest(t), + flags: verifyFlags{ + clusterID: tc.clusterIDFlag, + endpoint: tc.nodeEndpointFlag, + }, } - - v := &verifyCmd{log: logger.NewTest(t)} formatterFac := func(_ string, _ cloudprovider.Provider, _ debugLog) (attestationDocFormatter, error) { return tc.formatter, nil } - err := v.verify(cmd, fileHandler, tc.protoClient, formatterFac, stubAttestationFetcher{}) + err := v.verify(cmd, tc.protoClient, formatterFac, stubAttestationFetcher{}) if tc.wantErr { assert.Error(err) } else {