Skip to content

Commit

Permalink
Update CLI verify command flag parsing
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Weiße <[email protected]>
  • Loading branch information
daniel-weisse committed Oct 12, 2023
1 parent 8fb2bdb commit 0cdf912
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 127 deletions.
219 changes: 105 additions & 114 deletions cli/internal/cmd/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
tpmProto "github.com/google/go-tpm-tools/proto/tpm"

"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
"github.com/edgelesssys/constellation/v2/internal/atls"
Expand All @@ -45,6 +44,7 @@ import (
"github.com/google/go-sev-guest/kds"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"google.golang.org/grpc"
)

Expand All @@ -64,8 +64,39 @@ func NewVerifyCmd() *cobra.Command {
return cmd
}

type verifyFlags struct {
rootFlags
endpoint string
ownerID string
clusterID string
output string
}

func (f *verifyFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}

var err error
f.output, err = flags.GetString("output")
if err != nil {
return fmt.Errorf("getting 'output' flag: %w", err)
}
f.endpoint, err = flags.GetString("node-endpoint")
if err != nil {
return fmt.Errorf("getting 'node-endpoint' flag: %w", err)
}
f.clusterID, err = flags.GetString("cluster-id")
if err != nil {
return fmt.Errorf("getting 'cluster-id' flag: %w", err)
}
return nil
}

type verifyCmd struct {
log debugLog
fileHandler file.Handler
flags verifyFlags
log debugLog
}

func runVerify(cmd *cobra.Command, _ []string) error {
Expand Down Expand Up @@ -95,22 +126,23 @@ func runVerify(cmd *cobra.Command, _ []string) error {
return nil, fmt.Errorf("invalid output value for formatter: %s", output)
}
}
v := &verifyCmd{log: log}
v := &verifyCmd{
fileHandler: fileHandler,
log: log,
}
if err := v.flags.parse(cmd.Flags()); err != nil {
return err
}
v.log.Debugf("Using flags: %+v", v.flags)
fetcher := attestationconfigapi.NewFetcher()
return v.verify(cmd, fileHandler, verifyClient, formatterFactory, fetcher)
return v.verify(cmd, verifyClient, formatterFactory, fetcher)
}

type formatterFactory func(output string, provider cloudprovider.Provider, log debugLog) (attestationDocFormatter, error)

func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyClient verifyClient, factory formatterFactory, configFetcher attestationconfigapi.Fetcher) error {
flags, err := c.parseVerifyFlags(cmd, fileHandler)
if err != nil {
return fmt.Errorf("parsing flags: %w", err)
}
c.log.Debugf("Using flags: %+v", flags)

c.log.Debugf("Loading configuration file from %q", flags.pf.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(fileHandler, constants.ConfigFilename, configFetcher, flags.force)
func (c *verifyCmd) verify(cmd *cobra.Command, verifyClient verifyClient, factory formatterFactory, configFetcher attestationconfigapi.Fetcher) error {
c.log.Debugf("Loading configuration file from %q", c.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(c.fileHandler, constants.ConfigFilename, configFetcher, c.flags.force)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage())
Expand All @@ -119,10 +151,29 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
return fmt.Errorf("loading config file: %w", err)
}

conf.UpdateMAAURL(flags.maaURL)
stateFile, err := state.ReadFromFile(c.fileHandler, constants.StateFilename)
if err != nil {
return fmt.Errorf("reading state file: %w", err)
}

ownerID, clusterID, err := c.validateIDFlags(cmd, stateFile)
if err != nil {
return err
}
endpoint, err := c.validateEndpointFlag(cmd, stateFile)
if err != nil {
return err
}

var maaURL string
if stateFile.Infrastructure.Azure != nil {
maaURL = stateFile.Infrastructure.Azure.AttestationURL
}
conf.UpdateMAAURL(maaURL)

c.log.Debugf("Updating expected PCRs")
attConfig := conf.GetAttestationConfig()
if err := cloudcmd.UpdateInitMeasurements(attConfig, flags.ownerID, flags.clusterID); err != nil {
if err := cloudcmd.UpdateInitMeasurements(attConfig, ownerID, clusterID); err != nil {
return fmt.Errorf("updating expected PCRs: %w", err)
}

Expand All @@ -140,7 +191,7 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC

rawAttestationDoc, err := verifyClient.Verify(
cmd.Context(),
flags.endpoint,
endpoint,
&verifyproto.GetAttestationRequest{
Nonce: nonce,
},
Expand All @@ -151,7 +202,7 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
}

// certificates are only available for Azure
formatter, err := factory(flags.output, conf.GetProvider(), c.log)
formatter, err := factory(c.flags.output, conf.GetProvider(), c.log)
if err != nil {
return fmt.Errorf("creating formatter: %w", err)
}
Expand All @@ -160,7 +211,7 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
rawAttestationDoc,
conf.Provider.Azure == nil,
attConfig.GetMeasurements(),
flags.maaURL,
maaURL,
)
if err != nil {
return fmt.Errorf("printing attestation document: %w", err)
Expand All @@ -171,114 +222,37 @@ func (c *verifyCmd) verify(cmd *cobra.Command, fileHandler file.Handler, verifyC
return nil
}

func (c *verifyCmd) parseVerifyFlags(cmd *cobra.Command, fileHandler file.Handler) (verifyFlags, error) {
workDir, err := cmd.Flags().GetString("workspace")
if err != nil {
return verifyFlags{}, fmt.Errorf("parsing config path argument: %w", err)
}
c.log.Debugf("Flag 'workspace' set to %q", workDir)
pf := pathprefix.New(workDir)

ownerID := ""
clusterID, err := cmd.Flags().GetString("cluster-id")
if err != nil {
return verifyFlags{}, fmt.Errorf("parsing cluster-id argument: %w", err)
}
c.log.Debugf("Flag 'cluster-id' set to %q", clusterID)

endpoint, err := cmd.Flags().GetString("node-endpoint")
if err != nil {
return verifyFlags{}, fmt.Errorf("parsing node-endpoint argument: %w", err)
}
c.log.Debugf("Flag 'node-endpoint' set to %q", endpoint)

force, err := cmd.Flags().GetBool("force")
if err != nil {
return verifyFlags{}, fmt.Errorf("parsing force argument: %w", err)
}
c.log.Debugf("Flag 'force' set to %t", force)

output, err := cmd.Flags().GetString("output")
if err != nil {
return verifyFlags{}, fmt.Errorf("parsing raw argument: %w", err)
}
c.log.Debugf("Flag 'output' set to %t", output)

// Get empty values from state file
stateFile, err := state.ReadFromFile(fileHandler, constants.StateFilename)
isFileNotFound := errors.Is(err, afero.ErrFileNotFound)
if isFileNotFound {
c.log.Debugf("State file %q not found, using empty state", pf.PrefixPrintablePath(constants.StateFilename))
stateFile = state.New() // error compat
} else if err != nil {
return verifyFlags{}, fmt.Errorf("reading state file: %w", err)
}

emptyEndpoint := endpoint == ""
emptyIDs := ownerID == "" && clusterID == ""
if emptyEndpoint || emptyIDs {
c.log.Debugf("Trying to supplement empty flag values from %q", pf.PrefixPrintablePath(constants.StateFilename))
if emptyEndpoint {
cmd.PrintErrf("Using endpoint from %q. Specify --node-endpoint to override this.\n", pf.PrefixPrintablePath(constants.StateFilename))
endpoint = stateFile.Infrastructure.ClusterEndpoint
}
if emptyIDs {
cmd.PrintErrf("Using ID from %q. Specify --cluster-id to override this.\n", pf.PrefixPrintablePath(constants.StateFilename))
ownerID = stateFile.ClusterValues.OwnerID
clusterID = stateFile.ClusterValues.ClusterID
}
func (c *verifyCmd) validateIDFlags(cmd *cobra.Command, stateFile *state.State) (ownerID, clusterID string, err error) {
ownerID, clusterID = c.flags.ownerID, c.flags.clusterID
if c.flags.clusterID == "" {
cmd.PrintErrf("Using ID from %q. Specify --cluster-id to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
clusterID = stateFile.ClusterValues.ClusterID
}

var attestationURL string
if stateFile.Infrastructure.Azure != nil {
attestationURL = stateFile.Infrastructure.Azure.AttestationURL
if ownerID == "" {
// We don't want to print warnings until this is implemented again
// cmd.PrintErrf("Using ID from %q. Specify --owner-id to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
ownerID = stateFile.ClusterValues.OwnerID
}

// Validate
if ownerID == "" && clusterID == "" {
return verifyFlags{}, errors.New("cluster-id not provided to verify the cluster")
}
endpoint, err = addPortIfMissing(endpoint, constants.VerifyServiceNodePortGRPC)
if err != nil {
return verifyFlags{}, fmt.Errorf("validating endpoint argument: %w", err)
return "", "", errors.New("cluster-id not provided to verify the cluster")
}

return verifyFlags{
endpoint: endpoint,
pf: pf,
ownerID: ownerID,
clusterID: clusterID,
output: output,
maaURL: attestationURL,
force: force,
}, nil
}

type verifyFlags struct {
endpoint string
ownerID string
clusterID string
maaURL string
output string
force bool
pf pathprefix.PathPrefixer
return ownerID, clusterID, nil
}

func addPortIfMissing(endpoint string, defaultPort int) (string, error) {
func (c *verifyCmd) validateEndpointFlag(cmd *cobra.Command, stateFile *state.State) (string, error) {
endpoint := c.flags.endpoint
if endpoint == "" {
return "", errors.New("endpoint is empty")
}

_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint, nil
cmd.PrintErrf("Using endpoint from %q. Specify --node-endpoint to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
endpoint = stateFile.Infrastructure.ClusterEndpoint
}

if strings.Contains(err.Error(), "missing port in address") {
return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil
endpoint, err := addPortIfMissing(endpoint, constants.VerifyServiceNodePortGRPC)
if err != nil {
return "", fmt.Errorf("validating endpoint argument: %w", err)
}

return "", err
return endpoint, nil
}

// an attestationDocFormatter formats the attestation document.
Expand Down Expand Up @@ -869,3 +843,20 @@ func extractAzureInstanceInfo(docString string) (azureInstanceInfo, error) {
}
return instanceInfo, nil
}

func addPortIfMissing(endpoint string, defaultPort int) (string, error) {
if endpoint == "" {
return "", errors.New("endpoint is empty")
}

_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint, nil
}

if strings.Contains(err.Error(), "missing port in address") {
return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil
}

return "", err
}
Loading

0 comments on commit 0cdf912

Please sign in to comment.