From f9008993421ace348e1f643b9af870ec32bc50ee Mon Sep 17 00:00:00 2001 From: mars315 <57739378+mars315@users.noreply.github.com> Date: Thu, 4 Jan 2024 14:20:39 +0800 Subject: [PATCH] auto bind server node flags --- example/full/cmd.go | 10 +- flag.go | 351 +++++++++++++++++------------------------ lib/builtin/builtin.go | 2 + 3 files changed, 157 insertions(+), 206 deletions(-) diff --git a/example/full/cmd.go b/example/full/cmd.go index c1ae242..57246cf 100644 --- a/example/full/cmd.go +++ b/example/full/cmd.go @@ -1,10 +1,12 @@ -// `go run cmd.go --a2.f2 "KK"` +// `go run cmd.go --f2 "KK"` // output: -// // { // "F1": "x86xm", // "F2": "KK", // "F3": 87, +// "A4": { +// "F4": 99 +// }, // "DBUrl": ":27071", // "LogFile": "stdout", // "Debug": true, @@ -18,7 +20,6 @@ // // `go run cmd.go -h` // output: -// // Flags: // --a3.f3 int f3 (default 87) // --age int age (default 18) @@ -26,12 +27,15 @@ // --dburl string dburl (default ":27071") // --debug enable debug model,false to disable; ,please (default true) // --f2 string f2 (default "ZH") +// --f4 int f4 (default 99) // -h, --help help for test // --keep duration (default 1s) // --logfile string udp|udp:UdpAddr|FilePath|redirect:x (default "stdout") // --name string name (default "test") // --short string short (default "s") // --usage string usage (default "usage") +// + package main import ( diff --git a/flag.go b/flag.go index ee7629f..60ebcb5 100644 --- a/flag.go +++ b/flag.go @@ -70,6 +70,7 @@ import ( "strings" "time" + "github.com/mars315/autoflags/lib/builtin" "github.com/mars315/autoflags/lib/stringx" "github.com/mitchellh/mapstructure" "github.com/spf13/cobra" @@ -117,7 +118,7 @@ type ( ) // BindAndExecute automatically bind flag and execute -func BindAndExecute(cmd *cobra.Command, v0 any, opts ...FlagOption) error { +func BindAndExecute(cmd *cobra.Command, v0 builtin.Any, opts ...FlagOption) error { if err := BindFlags(cmd, v0, opts...); err != nil { return err } @@ -129,9 +130,8 @@ func BindAndExecute(cmd *cobra.Command, v0 any, opts ...FlagOption) error { // supported type: string, bool, int, int32, int64, float32, float64, []string, []int time.Duration // // struct and struct pointer -func BindFlags(cmd *cobra.Command, v0 any, opts ...FlagOption) error { +func BindFlags(cmd *cobra.Command, v0 builtin.Any, opts ...FlagOption) error { autoMarshalOption(cmd, v0, opts...) - if err := bindFlags(cmd, v0, defaultFlagConfig(opts...)); err != nil { return err } @@ -143,14 +143,14 @@ func BindFlags(cmd *cobra.Command, v0 any, opts ...FlagOption) error { // supported type: string, bool, int, int32, int64, float32, float64, []string, []int time.Duration // // struct and struct pointer -func ReadFlags(v0 any, opts ...FlagOption) error { +func ReadFlags(v0 builtin.Any, opts ...FlagOption) error { cfg := defaultFlagConfig(opts...) return readFlags(v0, cfg) } // UnmarshalFlags unmarshal flag value from viper // use `mapstructure` to unmarshal -func UnmarshalFlags(v0 any, opts ...FlagOption) error { +func UnmarshalFlags(v0 builtin.Any, opts ...FlagOption) error { defaultOpts := castConfigOptions(defaultFlagConfig(opts...)) return viper.Unmarshal(v0, defaultOpts...) } @@ -207,6 +207,101 @@ func WithPreAutoUnMarshalEOption(preE func(cmd *cobra.Command, args []string) er } } +/////////////////////////////////////////////////////// implement /////////////////////////////////////////////////////// + +func bindFlags(cmd *cobra.Command, v0 builtin.Any, cfg *FlagConfig) error { + if reflect.TypeOf(v0).Kind() != reflect.Pointer { + return fmt.Errorf("v0 must be pointer") + } + + v := reflect.ValueOf(v0).Elem() + t := v.Type() + flagSet := cmd.Flags() + for i := 0; i < v.NumField(); i++ { + fValue := v.Field(i) + field := t.Field(i) + tag := parseTag(field, cfg) + if tag == nil { + continue + } + switch fValue.Kind() { + case reflect.String: + flagSet.StringVarP(fValue.Addr().Interface().(*string), tag.Name, tag.Short, tag.Default, tag.Desc) + case reflect.Bool: + flagSet.BoolVarP(fValue.Addr().Interface().(*bool), tag.Name, tag.Short, stringx.ToBool(tag.Default), tag.Desc) + case reflect.Float32: + flagSet.Float32VarP(fValue.Addr().Interface().(*float32), tag.Name, tag.Short, stringx.Atof[float32](tag.Default), tag.Desc) + case reflect.Float64: + flagSet.Float64VarP(fValue.Addr().Interface().(*float64), tag.Name, tag.Short, stringx.Atof[float64](tag.Default), tag.Desc) + case reflect.Int: + flagSet.IntVarP(fValue.Addr().Interface().(*int), tag.Name, tag.Short, stringx.Atoi[int](tag.Default), tag.Desc) + case reflect.Int32: + flagSet.Int32VarP(fValue.Addr().Interface().(*int32), tag.Name, tag.Short, stringx.Atoi[int32](tag.Default), tag.Desc) + case reflect.Int64: + bindInt64(flagSet, fValue, tag) + case reflect.Slice: + if err := bindSlice(flagSet, fValue, field, tag); err != nil { + return err + } + case reflect.Struct: + if err := bindStruct(cmd, fValue, field, cfg); err != nil { + return err + } + case reflect.Pointer: + if err := bindPointer(cmd, fValue, field, cfg); err != nil { + return err + } + default: + return fmt.Errorf("unsupported type: %s|%s", field.Name, fValue.Kind()) + } + } + return nil +} + +func readFlags(v0 builtin.Any, cfg *FlagConfig) error { + v := reflect.ValueOf(v0).Elem() + t := v.Type() + for i := 0; i < v.NumField(); i++ { + fValue := v.Field(i) + field := t.Field(i) + tag := getTag(field, cfg) + if tag == nil { + continue + } + switch fValue.Kind() { + case reflect.String: + fValue.Set(reflect.ValueOf(viper.GetString(tag.Name))) + case reflect.Bool: + fValue.Set(reflect.ValueOf(viper.GetBool(tag.Name))) + case reflect.Float32: + fValue.Set(reflect.ValueOf(float32(viper.GetFloat64(tag.Name)))) + case reflect.Float64: + fValue.Set(reflect.ValueOf(viper.GetFloat64(tag.Name))) + case reflect.Int: + fValue.Set(reflect.ValueOf(viper.GetInt(tag.Name))) + case reflect.Int32: + fValue.Set(reflect.ValueOf(viper.GetInt32(tag.Name))) + case reflect.Int64: + readInt64(fValue, tag) + case reflect.Slice: + if err := readSlice(fValue, tag); err != nil { + return err + } + case reflect.Struct: + if err := readStruct(fValue, field, cfg); err != nil { + return err + } + case reflect.Pointer: + if err := readPointer(fValue, field, cfg); err != nil { + return err + } + default: + return fmt.Errorf("unsupported type: %s|%s", field.Name, fValue.Kind()) + } + } + return nil +} + /////////////////////////////////////////////////////// cast /////////////////////////////////////////////////////// // alias @@ -240,10 +335,10 @@ func withIgnoreUntaggedFieldsOption(ignore bool) decoderConfigOption { } } -// ///////////////////////////////////////////////////// helper /////////////////////////////////////////////////////// +/////////////////////////////////////////////////////// helper /////////////////////////////////////////////////////// // set auto marshal function -func autoMarshalOption(cmd *cobra.Command, v0 any, opts ...FlagOption) { +func autoMarshalOption(cmd *cobra.Command, v0 builtin.Any, opts ...FlagOption) { cfg := defaultFlagConfig(opts...) if !cfg.autoUnMarshalFlag { return @@ -298,8 +393,8 @@ func tryStepOut(field reflect.StructField, cfg *FlagConfig) { return } - tagInfo := getTag(field, cfg) - squash := cfg.squash || tagInfo != nil && tagInfo.squash + tag := getTag(field, cfg) + squash := cfg.squash || tag != nil && tag.squash if squash { return } @@ -307,99 +402,7 @@ func tryStepOut(field reflect.StructField, cfg *FlagConfig) { cfg.parent = cfg.parent[:len(cfg.parent)-1] } -// ///////////////////////////////////////////////////// implement /////////////////////////////////////////////////////// - -func bindFlags(cmd *cobra.Command, v0 any, cfg *FlagConfig) error { - if reflect.TypeOf(v0).Kind() != reflect.Pointer { - return fmt.Errorf("v0 must be pointer") - } - - v := reflect.ValueOf(v0).Elem() - t := v.Type() - flagSet := cmd.Flags() - for i := 0; i < v.NumField(); i++ { - fValue := v.Field(i) - field := t.Field(i) - if parseTag(field, cfg) == nil { - continue - } - switch fValue.Kind() { - case reflect.String: - bindString(flagSet, fValue, field, cfg) - case reflect.Bool: - bindBool(flagSet, fValue, field, cfg) - case reflect.Int: - bindInt(flagSet, fValue, field, cfg) - case reflect.Int32: - bindInt32(flagSet, fValue, field, cfg) - case reflect.Int64: - bindInt64(flagSet, fValue, field, cfg) - case reflect.Float32: - bindFloat32(flagSet, fValue, field, cfg) - case reflect.Float64: - bindFloat64(flagSet, fValue, field, cfg) - case reflect.Slice: - if err := bindSlice(flagSet, fValue, field, cfg); err != nil { - return err - } - case reflect.Struct: - if err := bindStruct(cmd, fValue, field, cfg); err != nil { - return err - } - case reflect.Pointer: - if err := bindPointer(cmd, fValue, field, cfg); err != nil { - return err - } - default: - return fmt.Errorf("unsupported type: %s|%s", field.Name, fValue.Kind()) - } - } - return nil -} - -func readFlags(v0 any, cfg *FlagConfig) error { - v := reflect.ValueOf(v0).Elem() - t := v.Type() - for i := 0; i < v.NumField(); i++ { - fValue := v.Field(i) - field := t.Field(i) - if getTag(field, cfg) == nil { - continue - } - switch fValue.Kind() { - case reflect.String: - readString(fValue, field, cfg) - case reflect.Bool: - readBool(fValue, field, cfg) - case reflect.Int: - readInt(fValue, field, cfg) - case reflect.Int32: - readInt32(fValue, field, cfg) - case reflect.Int64: - readInt64(fValue, field, cfg) - case reflect.Float32: - readFloat32(fValue, field, cfg) - case reflect.Float64: - readFloat64(fValue, field, cfg) - case reflect.Slice: - if err := readSlice(fValue, field, cfg); err != nil { - return err - } - case reflect.Struct: - if err := readStruct(fValue, field, cfg); err != nil { - return err - } - case reflect.Pointer: - if err := readPointer(fValue, field, cfg); err != nil { - return err - } - default: - return fmt.Errorf("unsupported type: %s|%s", field.Name, fValue.Kind()) - } - } - return nil -} - +// ///////////////////////////////////////////////////// tag /////////////////////////////////////////////////////// // private type tagData struct { origin string @@ -414,7 +417,6 @@ func parseTag(field reflect.StructField, cfg *FlagConfig) *tagData { if !field.IsExported() { return nil } - data := getTag(field, cfg) if data == nil { return nil @@ -473,13 +475,13 @@ func getTag(field reflect.StructField, cfg *FlagConfig) *tagData { if settings[cfg.tagName] == TagLabelSkip { return nil } - data := &tagData{ Name: settings[cfg.tagName], Short: settings[TagLabelShort], Desc: settings[TagLabelDesc], Default: settings[TagLabelDefault], } + // untagged field use field name as the flag name if len(data.Name) == 0 { data.Name = strings.ToLower(field.Name) @@ -487,6 +489,7 @@ func getTag(field reflect.StructField, cfg *FlagConfig) *tagData { _, squashLabel := settings[TagLabelSquash] data.squash = squashLabel && isStepInto(field) + data.origin = data.Name // add prefix @@ -504,6 +507,8 @@ func getTag(field reflect.StructField, cfg *FlagConfig) *tagData { return data } +/////////////////////////////////////////////////////// struct /////////////////////////////////////////////////////// + func bindStruct(cmd *cobra.Command, fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) error { defer tryStepOut(field, cfg) return bindFlags(cmd, fValue.Addr().Interface(), cfg) @@ -514,14 +519,16 @@ func readStruct(fValue reflect.Value, field reflect.StructField, cfg *FlagConfig return readFlags(fValue.Addr().Interface(), cfg) } +/////////////////////////////////////////////////////// pointer /////////////////////////////////////////////////////// + func bindPointer(cmd *cobra.Command, fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) error { if fValue.IsNil() { return fmt.Errorf("nil value of *%s", field.Name) } + if fValue.Elem().Kind() != reflect.Struct { return fmt.Errorf("unsupported type: %s|%s(%s)", field.Name, fValue.Kind(), fValue.Elem().Kind()) } - defer tryStepOut(field, cfg) return bindFlags(cmd, fValue.Interface(), cfg) } @@ -530,6 +537,7 @@ func readPointer(fValue reflect.Value, field reflect.StructField, cfg *FlagConfi if fValue.IsNil() { return fmt.Errorf("nil value of *%s", field.Name) } + if fValue.Elem().Kind() != reflect.Struct { return fmt.Errorf("unsupported type: %s|%s(%s)", field.Name, fValue.Kind(), fValue.Elem().Kind()) } @@ -538,84 +546,19 @@ func readPointer(fValue reflect.Value, field reflect.StructField, cfg *FlagConfi return readFlags(fValue.Interface(), cfg) } -func bindSlice(set *flag.FlagSet, fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) error { - switch fValue.Type().Elem().Kind() { - case reflect.String: - bindStringSlice(set, fValue, field, cfg) - case reflect.Int: - bindIntSlice(set, fValue, field, cfg) - default: - return fmt.Errorf("unsupported slice type: %s|%s", field.Name, fValue.Type().Elem().Kind()) - } - return nil -} - -func readSlice(fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) error { - switch fValue.Type().Elem().Kind() { - case reflect.String: - readStringSlice(fValue, field, cfg) - case reflect.Int: - readIntSlice(fValue, field, cfg) - default: - return fmt.Errorf("unsupported slice type: %s|%s", field.Name, fValue.Type().Elem().Kind()) - } - return nil -} - -func bindFloat32(set *flag.FlagSet, fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) { - tag := getTag(field, cfg) - set.Float32VarP(fValue.Addr().Interface().(*float32), tag.Name, tag.Short, stringx.Atof[float32](tag.Default), tag.Desc) -} - -func readFloat32(fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) { - tag := getTag(field, cfg) - fValue.Set(reflect.ValueOf(float32(viper.GetFloat64(tag.Name)))) -} - -func bindFloat64(set *flag.FlagSet, fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) { - tag := getTag(field, cfg) - set.Float64VarP(fValue.Addr().Interface().(*float64), tag.Name, tag.Short, stringx.Atof[float64](tag.Default), tag.Desc) -} - -func readFloat64(fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) { - tag := getTag(field, cfg) - fValue.Set(reflect.ValueOf(viper.GetFloat64(tag.Name))) -} - -func bindInt(set *flag.FlagSet, fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) { - tag := getTag(field, cfg) - set.IntVarP(fValue.Addr().Interface().(*int), tag.Name, tag.Short, stringx.Atoi[int](tag.Default), tag.Desc) -} - -func readInt(fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) { - tag := getTag(field, cfg) - fValue.Set(reflect.ValueOf(viper.GetInt(tag.Name))) -} - -func bindInt32(set *flag.FlagSet, fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) { - tag := getTag(field, cfg) - set.Int32VarP(fValue.Addr().Interface().(*int32), tag.Name, tag.Short, stringx.Atoi[int32](tag.Default), tag.Desc) -} - -func readInt32(fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) { - tag := getTag(field, cfg) - fValue.Set(reflect.ValueOf(viper.GetInt32(tag.Name))) -} +/////////////////////////////////////////////////////// int64 /////////////////////////////////////////////////////// -func bindInt64(set *flag.FlagSet, fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) { - tag := getTag(field, cfg) - i := fValue.Addr().Interface() - switch i.(type) { +func bindInt64(flagSet *flag.FlagSet, fValue reflect.Value, tag *tagData) { + switch fValue.Addr().Interface().(type) { case *time.Duration: duration, _ := time.ParseDuration(tag.Default) - set.DurationVarP(fValue.Addr().Interface().(*time.Duration), tag.Name, tag.Short, duration, tag.Desc) + flagSet.DurationVarP(fValue.Addr().Interface().(*time.Duration), tag.Name, tag.Short, duration, tag.Desc) default: - set.Int64VarP(fValue.Addr().Interface().(*int64), tag.Name, tag.Short, stringx.Atoi[int64](tag.Default), tag.Desc) + flagSet.Int64VarP(fValue.Addr().Interface().(*int64), tag.Name, tag.Short, stringx.Atoi[int64](tag.Default), tag.Desc) } } -func readInt64(fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) { - tag := getTag(field, cfg) +func readInt64(fValue reflect.Value, tag *tagData) { i := fValue.Addr().Interface() switch i.(type) { case *time.Duration: @@ -625,42 +568,44 @@ func readInt64(fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) } } -func bindIntSlice(set *flag.FlagSet, fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) { - tag := getTag(field, cfg) - set.IntSliceVarP(fValue.Addr().Interface().(*[]int), tag.Name, tag.Short, stringx.AtoSlice[int](tag.Default, ","), tag.Desc) -} - -func readIntSlice(fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) { - tag := getTag(field, cfg) - fValue.Set(reflect.ValueOf(viper.GetIntSlice(tag.Name))) -} +/////////////////////////////////////////////////////// slice /////////////////////////////////////////////////////// -func bindString(set *flag.FlagSet, fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) { - tag := getTag(field, cfg) - set.StringVarP(fValue.Addr().Interface().(*string), tag.Name, tag.Short, tag.Default, tag.Desc) +func bindSlice(flagSet *flag.FlagSet, fValue reflect.Value, field reflect.StructField, tag *tagData) error { + switch fValue.Type().Elem().Kind() { + case reflect.String: + bindStringSlice(flagSet, fValue, tag) + case reflect.Int: + bindIntSlice(flagSet, fValue, tag) + default: + return fmt.Errorf("field `%s` unsupported slice type %s", field.Name, fValue.Type().Elem().Kind()) + } + return nil } -func readString(fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) { - tag := getTag(field, cfg) - fValue.Set(reflect.ValueOf(viper.GetString(tag.Name))) +func readSlice(fValue reflect.Value, tag *tagData) error { + switch fValue.Type().Elem().Kind() { + case reflect.String: + readStringSlice(fValue, tag) + case reflect.Int: + readIntSlice(fValue, tag) + default: + return fmt.Errorf("unsupported slice type: %s|%s", fValue.Type().Elem().Name(), fValue.Type().Elem().Kind()) + } + return nil } -func bindStringSlice(set *flag.FlagSet, fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) { - tag := getTag(field, cfg) - set.StringSliceVarP(fValue.Addr().Interface().(*[]string), tag.Name, tag.Short, stringx.Split(tag.Default, ","), tag.Desc) +func bindIntSlice(flagSet *flag.FlagSet, fValue reflect.Value, tag *tagData) { + flagSet.IntSliceVarP(fValue.Addr().Interface().(*[]int), tag.Name, tag.Short, stringx.AtoSlice[int](tag.Default, ","), tag.Desc) } -func readStringSlice(fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) { - tag := getTag(field, cfg) - fValue.Set(reflect.ValueOf(viper.GetStringSlice(tag.Name))) +func readIntSlice(fValue reflect.Value, tag *tagData) { + fValue.Set(reflect.ValueOf(viper.GetIntSlice(tag.Name))) } -func bindBool(set *flag.FlagSet, fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) { - tag := getTag(field, cfg) - set.BoolVarP(fValue.Addr().Interface().(*bool), tag.Name, tag.Short, stringx.ToBool(tag.Default), tag.Desc) +func bindStringSlice(flagSet *flag.FlagSet, fValue reflect.Value, tag *tagData) { + flagSet.StringSliceVarP(fValue.Addr().Interface().(*[]string), tag.Name, tag.Short, stringx.Split(tag.Default, ","), tag.Desc) } -func readBool(fValue reflect.Value, field reflect.StructField, cfg *FlagConfig) { - tag := getTag(field, cfg) - fValue.Set(reflect.ValueOf(viper.GetBool(tag.Name))) +func readStringSlice(fValue reflect.Value, tag *tagData) { + fValue.Set(reflect.ValueOf(viper.GetStringSlice(tag.Name))) } diff --git a/lib/builtin/builtin.go b/lib/builtin/builtin.go index 223e5e6..2795905 100644 --- a/lib/builtin/builtin.go +++ b/lib/builtin/builtin.go @@ -26,4 +26,6 @@ type ( Float interface { ~float32 | ~float64 } + + Any = any )