diff --git a/flag.go b/flag.go index a1eb3c3..a21ed31 100644 --- a/flag.go +++ b/flag.go @@ -114,6 +114,8 @@ type ( tagName string // The tag label separator, default is "," tagLabelSep string + // persist flags `cmd.PersistentFlags()` default cmd.Flags() + persist bool } ) @@ -136,7 +138,7 @@ func BindFlags(cmd *cobra.Command, v0 builtin.Any, opts ...FlagOption) error { return err } - return viper.BindPFlags(cmd.Flags()) + return viper.BindPFlags(getFlagSet(cmd, defaultFlagConfig(opts...))) } // ReadFlags read flag value from viper @@ -157,6 +159,13 @@ func UnmarshalFlags(v0 builtin.Any, opts ...FlagOption) error { /////////////////////////////////////////////////////// option /////////////////////////////////////////////////////// +// WithPersistFlagSetOption persist flags +func WithPersistFlagSetOption() FlagOption { + return func(cfg *FlagConfig) { + cfg.persist = true + } +} + // WithTagNameOption custom tag name func WithTagNameOption(tag string) FlagOption { return func(cfg *FlagConfig) { @@ -216,7 +225,7 @@ func bindFlags(cmd *cobra.Command, v0 builtin.Any, cfg *FlagConfig) error { v := reflect.ValueOf(v0).Elem() t := v.Type() - flagSet := cmd.Flags() + flagSet := getFlagSet(cmd, cfg) for i := 0; i < v.NumField(); i++ { fValue := v.Field(i) field := t.Field(i) @@ -401,6 +410,15 @@ func tryStepOut(field reflect.StructField, cfg *FlagConfig) { cfg.parent = cfg.parent[:len(cfg.parent)-1] } +func getFlagSet(cmd *cobra.Command, cfg *FlagConfig) *flag.FlagSet { + switch cfg.persist { + case true: + return cmd.PersistentFlags() + default: + return cmd.Flags() + } +} + // ///////////////////////////////////////////////////// tag /////////////////////////////////////////////////////// // private