diff --git a/pkg/cli/bool_flag.go b/pkg/cli/bool_flag.go index 0730c6bf2b..1b3ffc69b3 100644 --- a/pkg/cli/bool_flag.go +++ b/pkg/cli/bool_flag.go @@ -3,6 +3,7 @@ package cli import ( libflag "flag" + "github.com/gruntwork-io/go-commons/errors" "github.com/urfave/cli/v2" ) @@ -42,11 +43,22 @@ func (flag *BoolFlag) Apply(set *libflag.FlagSet) error { flag.Destination = new(bool) } - var err error + var ( + err error + envValue *string + ) valType := FlagType[bool](&boolFlagType{negative: flag.Negative}) - if flag.FlagValue, err = newGenericValue(valType, flag.LookupEnv(flag.EnvVar), flag.Destination); err != nil { + if val := flag.LookupEnv(flag.EnvVar); val != nil && *val != "" { + envValue = val + } + + if flag.FlagValue, err = newGenericValue(valType, envValue, flag.Destination); err != nil { + if envValue != nil { + return errors.Errorf("invalid boolean value %q for %s: %w", *envValue, flag.EnvVar, err) + } + return err } diff --git a/pkg/cli/bool_flag_test.go b/pkg/cli/bool_flag_test.go index 9de3ed49dd..582682f263 100644 --- a/pkg/cli/bool_flag_test.go +++ b/pkg/cli/bool_flag_test.go @@ -11,6 +11,7 @@ import ( "github.com/gruntwork-io/terragrunt/pkg/cli" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" ) func TestBoolFlagApply(t *testing.T) { @@ -86,6 +87,20 @@ func TestBoolFlagApply(t *testing.T) { false, errors.New(`invalid boolean flag foo: setting the flag multiple times`), }, + { + cli.BoolFlag{Name: "foo", EnvVar: "FOO"}, + nil, + map[string]string{"FOO": ""}, + false, + nil, + }, + { + cli.BoolFlag{Name: "foo", EnvVar: "FOO"}, + nil, + map[string]string{"FOO": "monkey"}, + false, + errors.New(`invalid boolean value "monkey" for FOO: must be one of: "0", "1", "f", "t", "false", "true"`), + }, } for i, testCase := range testCases { @@ -130,11 +145,12 @@ func testBoolFlagApply(t *testing.T, flag *cli.BoolFlag, args []string, envs map flagSet.SetOutput(io.Discard) err := flag.Apply(flagSet) - require.NoError(t, err) + if err == nil { + err = flagSet.Parse(args) + } - err = flagSet.Parse(args) if expectedErr != nil { - require.Equal(t, expectedErr, err) + require.ErrorContains(t, expectedErr, err.Error()) return } require.NoError(t, err) @@ -148,6 +164,8 @@ func testBoolFlagApply(t *testing.T, flag *cli.BoolFlag, args []string, envs map assert.Equal(t, strconv.FormatBool(expectedValue), flag.GetValue(), "GetValue()") } + maps.DeleteFunc(envs, func(k, v string) bool { return v == "" }) + assert.Equal(t, len(args) > 0 || len(envs) > 0, flag.Value().IsSet(), "IsSet()") assert.Equal(t, expectedDefaultValue, flag.Value().GetDefaultText(), "GetDefaultText()") diff --git a/pkg/cli/errors.go b/pkg/cli/errors.go index 1a922a9294..aad91324aa 100644 --- a/pkg/cli/errors.go +++ b/pkg/cli/errors.go @@ -88,3 +88,17 @@ func handleExitCoder(err error, osExiter func(code int)) error { return err } + +// InvalidValueError is used to wrap errors from `strconv` to make the error message more user friendly. +type InvalidValueError struct { + underlyingError error + msg string +} + +func (err InvalidValueError) Error() string { + return err.msg +} + +func (err InvalidValueError) Unwrap() error { + return err.underlyingError +} diff --git a/pkg/cli/generic_flag.go b/pkg/cli/generic_flag.go index b86dd0cb87..6628c9c826 100644 --- a/pkg/cli/generic_flag.go +++ b/pkg/cli/generic_flag.go @@ -47,11 +47,22 @@ func (flag *GenericFlag[T]) Apply(set *libflag.FlagSet) error { flag.Destination = new(T) } - var err error + var ( + err error + envValue *string + ) valType := FlagType[T](new(genericType[T])) - if flag.FlagValue, err = newGenericValue(valType, flag.LookupEnv(flag.EnvVar), flag.Destination); err != nil { + if val := flag.LookupEnv(flag.EnvVar); val != nil { + envValue = val + } + + if flag.FlagValue, err = newGenericValue(valType, envValue, flag.Destination); err != nil { + if envValue != nil { + return errors.Errorf("invalid value %q for %s: %w", *envValue, flag.EnvVar, err) + } + return err } @@ -199,7 +210,7 @@ func (val *genericType[T]) Set(str string) error { case *bool: v, err := strconv.ParseBool(str) if err != nil { - return errors.Errorf("error parse: %w", err) + return errors.WithStackTrace(InvalidValueError{underlyingError: err, msg: `must be one of: "0", "1", "f", "t", "false", "true"`}) } *dest = v @@ -207,7 +218,7 @@ func (val *genericType[T]) Set(str string) error { case *int: v, err := strconv.ParseInt(str, 0, strconv.IntSize) if err != nil { - return errors.Errorf("error parse: %w", err) + return errors.WithStackTrace(InvalidValueError{underlyingError: err, msg: "must be 32-bit integer"}) } *dest = int(v) @@ -215,7 +226,7 @@ func (val *genericType[T]) Set(str string) error { case *uint: v, err := strconv.ParseUint(str, 10, 64) if err != nil { - return errors.Errorf("error parse: %w", err) + return errors.WithStackTrace(InvalidValueError{underlyingError: err, msg: "must be 32-bit unsigned integer"}) } *dest = uint(v) @@ -223,7 +234,7 @@ func (val *genericType[T]) Set(str string) error { case *int64: v, err := strconv.ParseInt(str, 0, 64) if err != nil { - return errors.Errorf("error parse: %w", err) + return errors.WithStackTrace(InvalidValueError{underlyingError: err, msg: "must be 64-bit integer"}) } *dest = v diff --git a/pkg/cli/generic_flag_test.go b/pkg/cli/generic_flag_test.go index 2138274be1..ac864e72ea 100644 --- a/pkg/cli/generic_flag_test.go +++ b/pkg/cli/generic_flag_test.go @@ -101,6 +101,13 @@ func TestGenericFlagIntApply(t *testing.T) { 20, nil, }, + { + cli.GenericFlag[int]{Name: "foo", EnvVar: "FOO"}, + []string{}, + map[string]string{"FOO": "monkey"}, + 0, + errors.New(`invalid value "monkey" for FOO: must be 32-bit integer`), + }, { cli.GenericFlag[int]{Name: "foo", Destination: mockDestValue(55)}, nil, @@ -145,6 +152,13 @@ func TestGenericFlagInt64Apply(t *testing.T) { 20, nil, }, + { + cli.GenericFlag[int64]{Name: "foo", EnvVar: "FOO"}, + []string{}, + map[string]string{"FOO": "monkey"}, + 0, + errors.New(`invalid value "monkey" for FOO: must be 64-bit integer`), + }, { cli.GenericFlag[int64]{Name: "foo", Destination: mockDestValue(int64(55))}, nil, @@ -196,11 +210,12 @@ func testGenericFlagApply[T cli.GenericType](t *testing.T, flag *cli.GenericFlag flagSet.SetOutput(io.Discard) err := flag.Apply(flagSet) - require.NoError(t, err) + if err == nil { + err = flagSet.Parse(args) + } - err = flagSet.Parse(args) if expectedErr != nil { - require.Equal(t, expectedErr, err) + require.ErrorContains(t, expectedErr, err.Error()) return } require.NoError(t, err) diff --git a/pkg/cli/map_flag.go b/pkg/cli/map_flag.go index 4a7ac691e1..dd28504921 100644 --- a/pkg/cli/map_flag.go +++ b/pkg/cli/map_flag.go @@ -76,12 +76,23 @@ func (flag *MapFlag[K, V]) Apply(set *libflag.FlagSet) error { flag.KeyValSep = MapFlagKeyValSep } - var err error + var ( + err error + envValue *string + ) keyType := FlagType[K](new(genericType[K])) valType := FlagType[V](new(genericType[V])) - if flag.FlagValue, err = newMapValue(keyType, valType, flag.LookupEnv(flag.EnvVar), flag.EnvVarSep, flag.KeyValSep, flag.Splitter, flag.Destination); err != nil { + if val := flag.LookupEnv(flag.EnvVar); val != nil { + envValue = val + } + + if flag.FlagValue, err = newMapValue(keyType, valType, envValue, flag.EnvVarSep, flag.KeyValSep, flag.Splitter, flag.Destination); err != nil { + if envValue != nil { + return errors.Errorf("invalid value %q for %s: %w", *envValue, flag.EnvVar, err) + } + return err } diff --git a/pkg/cli/slice_flag.go b/pkg/cli/slice_flag.go index 3c355fe069..5f25ea1ec6 100644 --- a/pkg/cli/slice_flag.go +++ b/pkg/cli/slice_flag.go @@ -4,6 +4,7 @@ import ( libflag "flag" "strings" + "github.com/gruntwork-io/go-commons/errors" "github.com/urfave/cli/v2" ) @@ -62,11 +63,22 @@ func (flag *SliceFlag[T]) Apply(set *libflag.FlagSet) error { flag.EnvVarSep = SliceFlagEnvVarSep } - var err error + var ( + err error + envValue *string + ) valType := FlagType[T](new(genericType[T])) - if flag.FlagValue, err = newSliceValue(valType, flag.LookupEnv(flag.EnvVar), flag.EnvVarSep, flag.Splitter, flag.Destination); err != nil { + if val := flag.LookupEnv(flag.EnvVar); val != nil { + envValue = val + } + + if flag.FlagValue, err = newSliceValue(valType, envValue, flag.EnvVarSep, flag.Splitter, flag.Destination); err != nil { + if envValue != nil { + return errors.Errorf("invalid value %q for %s: %w", *envValue, flag.EnvVar, err) + } + return err }