From f5626a8142cbef11982df339542cb53cb01c5698 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Rold=C3=A1n=20Betancort?= Date: Wed, 18 Sep 2024 18:31:03 +0100 Subject: [PATCH] fixes context override A customer reported flags being ignored even when they were being provided via CLI arguments. The behavior of overriding current context via CLI arguments was inconsistent across the codebase. This centralizes the logic so all the various commands use the same logic. --- internal/client/client.go | 116 ++++++++++++++++++++++++++----- internal/client/client_test.go | 62 +++++++++++++++++ internal/cmd/schema.go | 22 +----- internal/cmd/version.go | 22 +----- internal/storage/config.go | 55 +++++++++------ internal/storage/config_test.go | 48 +++++++++++++ internal/storage/secrets.go | 8 +++ internal/storage/secrets_test.go | 19 +++++ internal/testing/test_helpers.go | 12 ++++ 9 files changed, 286 insertions(+), 78 deletions(-) create mode 100644 internal/client/client_test.go create mode 100644 internal/storage/config_test.go create mode 100644 internal/storage/secrets_test.go diff --git a/internal/client/client.go b/internal/client/client.go index 7fd8a617..b2362679 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -12,7 +12,7 @@ import ( "github.com/mitchellh/go-homedir" "github.com/rs/zerolog/log" "github.com/spf13/cobra" - grpc "google.golang.org/grpc" + "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" zgrpcutil "github.com/authzed/zed/internal/grpcutil" @@ -28,20 +28,17 @@ type Client interface { } // NewClient defines an (overridable) means of creating a new client. -var NewClient = newGRPCClient +var ( + NewClient = newClientForCurrentContext + NewClientForContext = newClientForContext +) -func newGRPCClient(cmd *cobra.Command) (Client, error) { +func newClientForCurrentContext(cmd *cobra.Command) (Client, error) { configStore, secretStore := DefaultStorage() - token, err := storage.DefaultToken( - cobrautil.MustGetString(cmd, "endpoint"), - cobrautil.MustGetString(cmd, "token"), - configStore, - secretStore, - ) + token, err := GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore) if err != nil { return nil, err } - log.Trace().Interface("token", token).Send() dialOpts, err := DialOptsFromFlags(cmd, token) if err != nil { @@ -56,28 +53,115 @@ func newGRPCClient(cmd *cobra.Command) (Client, error) { return client, err } +func newClientForContext(cmd *cobra.Command, contextName string, secretStore storage.SecretStore) (*authzed.Client, error) { + currentToken, err := storage.GetToken(contextName, secretStore) + if err != nil { + return nil, err + } + + token, err := GetTokenWithCLIOverride(cmd, currentToken) + if err != nil { + return nil, err + } + + dialOpts, err := DialOptsFromFlags(cmd, token) + if err != nil { + return nil, err + } + + return authzed.NewClient(token.Endpoint, dialOpts...) +} + +// GetCurrentTokenWithCLIOverride returns the current token, but overridden by any parameter specified via CLI args +func GetCurrentTokenWithCLIOverride(cmd *cobra.Command, configStore storage.ConfigStore, secretStore storage.SecretStore) (storage.Token, error) { + token, err := storage.CurrentToken( + configStore, + secretStore, + ) + if err != nil { + return storage.Token{}, err + } + + return GetTokenWithCLIOverride(cmd, token) +} + +// GetTokenWithCLIOverride returns the provided token, but overridden by any parameter specified explicitly via command +// flags +func GetTokenWithCLIOverride(cmd *cobra.Command, token storage.Token) (storage.Token, error) { + overrideToken, err := tokenFromCli(cmd) + if err != nil { + return storage.Token{}, err + } + + result, err := storage.TokenWithOverride( + overrideToken, + token, + ) + if err != nil { + return storage.Token{}, err + } + + log.Trace().Bool("context-override-via-cli", overrideToken.AnyValue()).Interface("context", result).Send() + return result, nil +} + +func tokenFromCli(cmd *cobra.Command) (storage.Token, error) { + certPath := cobrautil.MustGetStringExpanded(cmd, "certificate-path") + var certBytes []byte + var err error + if certPath != "" { + certBytes, err = os.ReadFile(certPath) + if err != nil { + return storage.Token{}, fmt.Errorf("failed to read ceritficate: %w", err) + } + } + + explicitInsecure := cmd.Flags().Changed("insecure") + var notSecure *bool + if explicitInsecure { + i := cobrautil.MustGetBool(cmd, "insecure") + notSecure = &i + } + + explicitNoVerifyCA := cmd.Flags().Changed("no-verify-ca") + var notVerifyCA *bool + if explicitNoVerifyCA { + nvc := cobrautil.MustGetBool(cmd, "no-verify-ca") + notVerifyCA = &nvc + } + overrideToken := storage.Token{ + APIToken: cobrautil.MustGetString(cmd, "token"), + Endpoint: cobrautil.MustGetString(cmd, "endpoint"), + Insecure: notSecure, + NoVerifyCA: notVerifyCA, + CACert: certBytes, + } + return overrideToken, nil +} + // DefaultStorage returns the default configured config store and secret store. func DefaultStorage() (storage.ConfigStore, storage.SecretStore) { var home string if xdg := os.Getenv("XDG_CONFIG_HOME"); xdg != "" { home = filepath.Join(xdg, "zed") } else { - homedir, _ := homedir.Dir() - home = filepath.Join(homedir, ".zed") + hmdir, _ := homedir.Dir() + home = filepath.Join(hmdir, ".zed") } return &storage.JSONConfigStore{ConfigPath: home}, &storage.KeychainSecretStore{ConfigPath: home} } -func certOption(cmd *cobra.Command, token storage.Token) (opt grpc.DialOption, err error) { +func certOption(token storage.Token) (opt grpc.DialOption, err error) { verification := grpcutil.VerifyCA - if cobrautil.MustGetBool(cmd, "no-verify-ca") || token.HasNoVerifyCA() { + if token.HasNoVerifyCA() { verification = grpcutil.SkipVerifyCA } if certBytes, ok := token.Certificate(); ok { return grpcutil.WithCustomCertBytes(verification, certBytes) } + return grpcutil.WithSystemCerts(verification) } @@ -96,12 +180,12 @@ func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOpti grpc.WithChainStreamInterceptor(zgrpcutil.StreamLogDispatchTrailers), } - if cobrautil.MustGetBool(cmd, "insecure") || (token.IsInsecure()) { + if token.IsInsecure() { opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) opts = append(opts, grpcutil.WithInsecureBearerToken(token.APIToken)) } else { opts = append(opts, grpcutil.WithBearerToken(token.APIToken)) - certOpt, err := certOption(cmd, token) + certOpt, err := certOption(token) if err != nil { return nil, fmt.Errorf("failed to configure TLS cert: %w", err) } diff --git a/internal/client/client_test.go b/internal/client/client_test.go new file mode 100644 index 00000000..e1a0d92d --- /dev/null +++ b/internal/client/client_test.go @@ -0,0 +1,62 @@ +package client_test + +import ( + "os" + "testing" + + "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/storage" + zedtesting "github.com/authzed/zed/internal/testing" + + "github.com/stretchr/testify/require" +) + +func TestGetTokenWithCLIOverride(t *testing.T) { + testCert, err := os.CreateTemp("", "") + require.NoError(t, err) + _, err = testCert.Write([]byte("hi")) + require.NoError(t, err) + cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t, + zedtesting.StringFlag{FlagName: "token", FlagValue: "t1", Changed: true}, + zedtesting.StringFlag{FlagName: "certificate-path", FlagValue: testCert.Name(), Changed: true}, + zedtesting.StringFlag{FlagName: "endpoint", FlagValue: "e1", Changed: true}, + zedtesting.BoolFlag{FlagName: "insecure", FlagValue: true, Changed: true}, + zedtesting.BoolFlag{FlagName: "no-verify-ca", FlagValue: true, Changed: true}, + ) + + bTrue := true + bFalse := false + + // cli args take precedence when defined + to, err := client.GetTokenWithCLIOverride(cmd, storage.Token{}) + require.NoError(t, err) + require.True(t, to.AnyValue()) + require.Equal(t, "t1", to.APIToken) + require.Equal(t, "e1", to.Endpoint) + require.Equal(t, []byte("hi"), to.CACert) + require.Equal(t, &bTrue, to.Insecure) + require.Equal(t, &bTrue, to.NoVerifyCA) + + // storage token takes precedence when defined + cmd = zedtesting.CreateTestCobraCommandWithFlagValue(t, + zedtesting.StringFlag{FlagName: "token", FlagValue: "", Changed: false}, + zedtesting.StringFlag{FlagName: "certificate-path", FlagValue: "", Changed: false}, + zedtesting.StringFlag{FlagName: "endpoint", FlagValue: "", Changed: false}, + zedtesting.BoolFlag{FlagName: "insecure", FlagValue: true, Changed: false}, + zedtesting.BoolFlag{FlagName: "no-verify-ca", FlagValue: true, Changed: false}, + ) + to, err = client.GetTokenWithCLIOverride(cmd, storage.Token{ + APIToken: "t2", + Endpoint: "e2", + CACert: []byte("bye"), + Insecure: &bFalse, + NoVerifyCA: &bFalse, + }) + require.NoError(t, err) + require.True(t, to.AnyValue()) + require.Equal(t, "t2", to.APIToken) + require.Equal(t, "e2", to.Endpoint) + require.Equal(t, []byte("bye"), to.CACert) + require.Equal(t, &bFalse, to.Insecure) + require.Equal(t, &bFalse, to.NoVerifyCA) +} diff --git a/internal/cmd/schema.go b/internal/cmd/schema.go index dc8439b3..28b11862 100644 --- a/internal/cmd/schema.go +++ b/internal/cmd/schema.go @@ -9,7 +9,6 @@ import ( "strings" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" - "github.com/authzed/authzed-go/v1" "github.com/authzed/spicedb/pkg/schemadsl/compiler" "github.com/authzed/spicedb/pkg/schemadsl/generator" "github.com/authzed/spicedb/pkg/schemadsl/input" @@ -23,7 +22,6 @@ import ( "github.com/authzed/zed/internal/client" "github.com/authzed/zed/internal/commands" "github.com/authzed/zed/internal/console" - "github.com/authzed/zed/internal/storage" ) func registerAdditionalSchemaCmds(schemaCmd *cobra.Command) { @@ -52,28 +50,14 @@ var schemaCopyCmd = &cobra.Command{ RunE: schemaCopyCmdFunc, } -// TODO(jschorr): support this in the client package -func clientForContext(cmd *cobra.Command, contextName string, secretStore storage.SecretStore) (*authzed.Client, error) { - token, err := storage.GetToken(contextName, secretStore) - if err != nil { - return nil, err - } - log.Trace().Interface("token", token).Send() - - dialOpts, err := client.DialOptsFromFlags(cmd, token) - if err != nil { - return nil, err - } - return authzed.NewClient(token.Endpoint, dialOpts...) -} - func schemaCopyCmdFunc(cmd *cobra.Command, args []string) error { _, secretStore := client.DefaultStorage() - srcClient, err := clientForContext(cmd, args[0], secretStore) + srcClient, err := client.NewClientForContext(cmd, args[0], secretStore) if err != nil { return err } - destClient, err := clientForContext(cmd, args[1], secretStore) + + destClient, err := client.NewClientForContext(cmd, args[1], secretStore) if err != nil { return err } diff --git a/internal/cmd/version.go b/internal/cmd/version.go index 8c67b16d..95d30d8d 100644 --- a/internal/cmd/version.go +++ b/internal/cmd/version.go @@ -9,14 +9,12 @@ import ( "github.com/gookit/color" "github.com/jzelinskie/cobrautil/v2" "github.com/mattn/go-isatty" - "github.com/rs/zerolog/log" "github.com/spf13/cobra" "google.golang.org/grpc" "google.golang.org/grpc/metadata" "github.com/authzed/zed/internal/client" "github.com/authzed/zed/internal/console" - "github.com/authzed/zed/internal/storage" ) func versionCmdFunc(cmd *cobra.Command, _ []string) error { @@ -26,14 +24,9 @@ func versionCmdFunc(cmd *cobra.Command, _ []string) error { includeRemoteVersion := cobrautil.MustGetBool(cmd, "include-remote-version") hasContext := false - configStore, secretStore := client.DefaultStorage() if includeRemoteVersion { - _, err := storage.DefaultToken( - cobrautil.MustGetString(cmd, "endpoint"), - cobrautil.MustGetString(cmd, "token"), - configStore, - secretStore, - ) + configStore, secretStore := client.DefaultStorage() + _, err := client.GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore) hasContext = err == nil } @@ -45,17 +38,6 @@ func versionCmdFunc(cmd *cobra.Command, _ []string) error { console.Println(cobrautil.UsageVersion("zed", cobrautil.MustGetBool(cmd, "include-deps"))) if hasContext && includeRemoteVersion { - token, err := storage.DefaultToken( - cobrautil.MustGetString(cmd, "endpoint"), - cobrautil.MustGetString(cmd, "token"), - configStore, - secretStore, - ) - if err != nil { - return err - } - log.Trace().Interface("token", token).Send() - client, err := client.NewClient(cmd) if err != nil { return err diff --git a/internal/storage/config.go b/internal/storage/config.go index ef379d07..739079eb 100644 --- a/internal/storage/config.go +++ b/internal/storage/config.go @@ -27,38 +27,47 @@ type ConfigStore interface { Put(Config) error } -var ErrMissingToken = errors.New("could not find token") +// TokenWithOverride returns a Token that retrieves its values from the reference Token, and has its values overridden +// any of the non-empty/non-nil values of the overrideToken. +func TokenWithOverride(overrideToken Token, referenceToken Token) (Token, error) { + insecure := referenceToken.Insecure + if overrideToken.Insecure != nil { + insecure = overrideToken.Insecure + } -// DefaultToken creates a Token from input, filling any missing values in -// with the current context's defaults. -func DefaultToken(overrideEndpoint, overrideAPIToken string, cs ConfigStore, ss SecretStore) (Token, error) { - if overrideEndpoint != "" && overrideAPIToken != "" { - return Token{ - Name: "env", - Endpoint: overrideEndpoint, - APIToken: overrideAPIToken, - }, nil + // done so that logging messages don't show nil for the resulting context + if insecure == nil { + bFalse := false + insecure = &bFalse } - token, err := CurrentToken(cs, ss) - if err != nil { - if errors.Is(err, ErrConfigNotFound) { - return Token{}, errors.New("no context found: see `zed context set --help` to setup a context or make sure to specifiy *all* context flags (--endpoint, --token and --insecure if necessary) to run without context") - } - return Token{}, err + noVerifyCA := referenceToken.NoVerifyCA + if overrideToken.NoVerifyCA != nil { + noVerifyCA = overrideToken.NoVerifyCA + } + + // done so that logging messages don't show nil for the resulting context + if noVerifyCA == nil { + bFalse := false + noVerifyCA = &bFalse + } + + caCert := referenceToken.CACert + if overrideToken.CACert != nil { + caCert = overrideToken.CACert } return Token{ - Name: token.Name, - Endpoint: stringz.DefaultEmpty(overrideEndpoint, token.Endpoint), - APIToken: stringz.DefaultEmpty(overrideAPIToken, token.APIToken), - Insecure: token.Insecure, - NoVerifyCA: token.NoVerifyCA, - CACert: token.CACert, + Name: referenceToken.Name, + Endpoint: stringz.DefaultEmpty(overrideToken.Endpoint, referenceToken.Endpoint), + APIToken: stringz.DefaultEmpty(overrideToken.APIToken, referenceToken.APIToken), + Insecure: insecure, + NoVerifyCA: noVerifyCA, + CACert: caCert, }, nil } -// CurrentToken is convenient way to obtain the CurrentToken field from the +// CurrentToken is a convenient way to obtain the CurrentToken field from the // current Config. func CurrentToken(cs ConfigStore, ss SecretStore) (Token, error) { cfg, err := cs.Get() diff --git a/internal/storage/config_test.go b/internal/storage/config_test.go new file mode 100644 index 00000000..aa65d6ed --- /dev/null +++ b/internal/storage/config_test.go @@ -0,0 +1,48 @@ +package storage + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTokenWithOverride(t *testing.T) { + bTrue := true + referenceToken := Token{ + Name: "n1", + Endpoint: "e1", + APIToken: "a1", + Insecure: &bTrue, + NoVerifyCA: &bTrue, + CACert: []byte("c1"), + } + + bFalse := false + override := Token{ + Name: "n2", + Endpoint: "e2", + APIToken: "a2", + Insecure: &bFalse, + NoVerifyCA: &bFalse, + CACert: []byte("c2"), + } + + result, err := TokenWithOverride(override, referenceToken) + require.NoError(t, err) + require.Equal(t, "n1", result.Name) + require.Equal(t, "e2", result.Endpoint) + require.Equal(t, "a2", result.APIToken) + require.Equal(t, false, *result.Insecure) + require.Equal(t, false, *result.NoVerifyCA) + require.Equal(t, 0, bytes.Compare([]byte("c2"), result.CACert)) + + result, err = TokenWithOverride(Token{}, referenceToken) + require.NoError(t, err) + require.Equal(t, "n1", result.Name) + require.Equal(t, "e1", result.Endpoint) + require.Equal(t, "a1", result.APIToken) + require.Equal(t, true, *result.Insecure) + require.Equal(t, true, *result.NoVerifyCA) + require.Equal(t, 0, bytes.Compare([]byte("c1"), result.CACert)) +} diff --git a/internal/storage/secrets.go b/internal/storage/secrets.go index d29d9a98..90d1ebb7 100644 --- a/internal/storage/secrets.go +++ b/internal/storage/secrets.go @@ -26,6 +26,14 @@ type Token struct { CACert []byte } +func (t Token) AnyValue() bool { + if t.Endpoint != "" || t.APIToken != "" || t.Insecure != nil || t.NoVerifyCA != nil || len(t.CACert) > 0 { + return true + } + + return false +} + func (t Token) Certificate() (cert []byte, ok bool) { if len(t.CACert) > 0 { return t.CACert, true diff --git a/internal/storage/secrets_test.go b/internal/storage/secrets_test.go new file mode 100644 index 00000000..da8dd154 --- /dev/null +++ b/internal/storage/secrets_test.go @@ -0,0 +1,19 @@ +package storage + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTokenAnyValue(t *testing.T) { + b := false + + require.False(t, Token{}.AnyValue()) + require.False(t, Token{}.AnyValue()) + require.True(t, Token{Endpoint: "foo"}.AnyValue()) + require.True(t, Token{APIToken: "foo"}.AnyValue()) + require.True(t, Token{Insecure: &b}.AnyValue()) + require.True(t, Token{NoVerifyCA: &b}.AnyValue()) + require.True(t, Token{CACert: []byte("a")}.AnyValue()) +} diff --git a/internal/testing/test_helpers.go b/internal/testing/test_helpers.go index 451b2ff7..8535ba0f 100644 --- a/internal/testing/test_helpers.go +++ b/internal/testing/test_helpers.go @@ -64,31 +64,37 @@ func NewTestServer(ctx context.Context, t *testing.T) server.RunnableServer { type StringFlag struct { FlagName string FlagValue string + Changed bool } type BoolFlag struct { FlagName string FlagValue bool + Changed bool } type IntFlag struct { FlagName string FlagValue int + Changed bool } type UintFlag struct { FlagName string FlagValue uint + Changed bool } type UintFlag32 struct { FlagName string FlagValue uint32 + Changed bool } type DurationFlag struct { FlagName string FlagValue time.Duration + Changed bool } func CreateTestCobraCommandWithFlagValue(t *testing.T, flagAndValues ...any) *cobra.Command { @@ -99,16 +105,22 @@ func CreateTestCobraCommandWithFlagValue(t *testing.T, flagAndValues ...any) *co switch f := flagAndValue.(type) { case StringFlag: c.Flags().String(f.FlagName, f.FlagValue, "") + c.Flag(f.FlagName).Changed = f.Changed case BoolFlag: c.Flags().Bool(f.FlagName, f.FlagValue, "") + c.Flag(f.FlagName).Changed = f.Changed case IntFlag: c.Flags().Int(f.FlagName, f.FlagValue, "") + c.Flag(f.FlagName).Changed = f.Changed case UintFlag: c.Flags().Uint(f.FlagName, f.FlagValue, "") + c.Flag(f.FlagName).Changed = f.Changed case UintFlag32: c.Flags().Uint32(f.FlagName, f.FlagValue, "") + c.Flag(f.FlagName).Changed = f.Changed case DurationFlag: c.Flags().Duration(f.FlagName, f.FlagValue, "") + c.Flag(f.FlagName).Changed = f.Changed default: t.Fatalf("unknown flag type: %T", f) }