diff --git a/parse.go b/parse.go index 10c6841..1416223 100644 --- a/parse.go +++ b/parse.go @@ -159,7 +159,7 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) { } // Check whether this field is supported. It's good to do this here rather than - // wait until setScalar because it means that a program with invalid argument + // wait until ParseValue because it means that a program with invalid argument // fields will always fail regardless of whether the arguments it received // exercised those fields. var parseable bool @@ -275,7 +275,7 @@ func process(specs []*spec, args []string) error { } if spec.env != "" { if value, found := os.LookupEnv(spec.env); found { - err := setScalar(spec.dest, value) + err := scalar.ParseValue(spec.dest, value) if err != nil { return fmt.Errorf("error processing environment variable %s: %v", spec.env, err) } @@ -355,7 +355,7 @@ func process(specs []*spec, args []string) error { i++ } - err := setScalar(spec.dest, value) + err := scalar.ParseValue(spec.dest, value) if err != nil { return fmt.Errorf("error processing %s: %v", arg, err) } @@ -374,7 +374,7 @@ func process(specs []*spec, args []string) error { } positionals = nil } else if len(positionals) > 0 { - err := setScalar(spec.dest, positionals[0]) + err := scalar.ParseValue(spec.dest, positionals[0]) if err != nil { return fmt.Errorf("error processing %s: %v", spec.long, err) } @@ -426,7 +426,7 @@ func setSlice(dest reflect.Value, values []string, trunc bool) error { var ptr bool elem := dest.Type().Elem() - if elem.Kind() == reflect.Ptr { + if elem.Kind() == reflect.Ptr && !elem.Implements(textUnmarshalerType) { ptr = true elem = elem.Elem() } @@ -438,7 +438,7 @@ func setSlice(dest reflect.Value, values []string, trunc bool) error { for _, s := range values { v := reflect.New(elem) - if err := setScalar(v.Elem(), s); err != nil { + if err := scalar.ParseValue(v.Elem(), s); err != nil { return err } if !ptr { @@ -451,7 +451,8 @@ func setSlice(dest reflect.Value, values []string, trunc bool) error { // canParse returns true if the type can be parsed from a string func canParse(t reflect.Type) (parseable, boolean, multiple bool) { - parseable, boolean = isScalar(t) + parseable = scalar.CanParse(t) + boolean = isBoolean(t) if parseable { return } @@ -466,7 +467,8 @@ func canParse(t reflect.Type) (parseable, boolean, multiple bool) { t = t.Elem() } - parseable, boolean = isScalar(t) + parseable = scalar.CanParse(t) + boolean = isBoolean(t) if parseable { return } @@ -476,7 +478,8 @@ func canParse(t reflect.Type) (parseable, boolean, multiple bool) { t = t.Elem() } - parseable, boolean = isScalar(t) + parseable = scalar.CanParse(t) + boolean = isBoolean(t) if parseable { return } @@ -486,22 +489,16 @@ func canParse(t reflect.Type) (parseable, boolean, multiple bool) { var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem() -// isScalar returns true if the type can be parsed from a single string -func isScalar(t reflect.Type) (parseable, boolean bool) { - parseable = scalar.CanParse(t) +// isBoolean returns true if the type can be parsed from a single string +func isBoolean(t reflect.Type) bool { switch { case t.Implements(textUnmarshalerType): - return parseable, false + return false case t.Kind() == reflect.Bool: - return parseable, true + return true case t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Bool: - return parseable, true + return true default: - return parseable, false + return false } } - -// set a value from a string -func setScalar(v reflect.Value, s string) error { - return scalar.ParseValue(v, s) -} diff --git a/parse_test.go b/parse_test.go index 925a23e..1461c02 100644 --- a/parse_test.go +++ b/parse_test.go @@ -599,6 +599,32 @@ func TestTextUnmarshaler(t *testing.T) { assert.Equal(t, 3, args.Foo.val) } +func TestRepeatedTextUnmarshaler(t *testing.T) { + // fields that implement TextUnmarshaler should be parsed using that interface + var args struct { + Foo []*textUnmarshaler + } + err := parse("--foo abc d ef", &args) + require.NoError(t, err) + require.Len(t, args.Foo, 3) + assert.Equal(t, 3, args.Foo[0].val) + assert.Equal(t, 1, args.Foo[1].val) + assert.Equal(t, 2, args.Foo[2].val) +} + +func TestPositionalTextUnmarshaler(t *testing.T) { + // fields that implement TextUnmarshaler should be parsed using that interface + var args struct { + Foo []*textUnmarshaler `arg:"positional"` + } + err := parse("abc d ef", &args) + require.NoError(t, err) + require.Len(t, args.Foo, 3) + assert.Equal(t, 3, args.Foo[0].val) + assert.Equal(t, 1, args.Foo[1].val) + assert.Equal(t, 2, args.Foo[2].val) +} + type boolUnmarshaler bool func (p *boolUnmarshaler) UnmarshalText(b []byte) error {