diff --git a/cli/providers/providers.go b/cli/providers/providers.go index 3038dd194e..af5920ba65 100644 --- a/cli/providers/providers.go +++ b/cli/providers/providers.go @@ -34,7 +34,7 @@ func AttachCLIs(rootCmd *cobra.Command, commands ...*Command) error { return err } - connectorName, autoUpdate := detectConnectorName(os.Args, commands) + connectorName, autoUpdate := detectConnectorName(os.Args, commands, existing) if connectorName != "" { if _, err := providers.EnsureProvider(existing, connectorName, "", autoUpdate); err != nil { return err @@ -49,7 +49,7 @@ func AttachCLIs(rootCmd *cobra.Command, commands ...*Command) error { return nil } -func detectConnectorName(args []string, commands []*Command) (string, bool) { +func detectConnectorName(args []string, commands []*Command, providers providers.Providers) (string, bool) { autoUpdate := true config.InitViperConfig() @@ -63,7 +63,7 @@ func detectConnectorName(args []string, commands []*Command) (string, bool) { builtins := genBuiltinFlags() for i := range builtins { - addFlagToSet(flags, builtins[i]) + attachFlag(flags, builtins[i]) } for i := range commands { @@ -75,6 +75,19 @@ func detectConnectorName(args []string, commands []*Command) (string, bool) { }) } + for i := range providers { + provider := providers[i] + for j := range provider.Connectors { + conn := provider.Connectors[j] + for k := range conn.Flags { + flag := conn.Flags[k] + if found := flags.Lookup(flag.Long); found == nil { + attachFlag(flags, flag) + } + } + } + } + err := flags.Parse(args) if err != nil { log.Warn().Err(err).Msg("CLI pre-processing encountered an issue") @@ -225,40 +238,51 @@ var skipFlags = map[string]struct{}{ "use-recording": {}, } -func addFlagToSet(set *pflag.FlagSet, flag plugin.Flag) { +func attachFlag(flagset *pflag.FlagSet, flag plugin.Flag) { switch flag.Type { case plugin.FlagType_Bool: if flag.Short != "" { - set.BoolP(flag.Long, flag.Short, false, flag.Desc) + flagset.BoolP(flag.Long, flag.Short, json2T(flag.Default, false), flag.Desc) } else { - set.Bool(flag.Long, false, flag.Desc) + flagset.Bool(flag.Long, json2T(flag.Default, false), flag.Desc) } case plugin.FlagType_Int: if flag.Short != "" { - set.IntP(flag.Long, flag.Short, 0, flag.Desc) + flagset.IntP(flag.Long, flag.Short, json2T(flag.Default, 0), flag.Desc) } else { - set.Int(flag.Long, 0, flag.Desc) + flagset.Int(flag.Long, json2T(flag.Default, 0), flag.Desc) } case plugin.FlagType_String: if flag.Short != "" { - set.StringP(flag.Long, flag.Short, "", flag.Desc) + flagset.StringP(flag.Long, flag.Short, flag.Default, flag.Desc) } else { - set.String(flag.Long, "", flag.Desc) + flagset.String(flag.Long, flag.Default, flag.Desc) } case plugin.FlagType_List: if flag.Short != "" { - set.StringArrayP(flag.Long, flag.Short, []string{}, flag.Desc) + flagset.StringSliceP(flag.Long, flag.Short, json2T(flag.Default, []string{}), flag.Desc) } else { - set.StringArray(flag.Long, []string{}, flag.Desc) + flagset.StringSlice(flag.Long, json2T(flag.Default, []string{}), flag.Desc) } case plugin.FlagType_KeyValue: if flag.Short != "" { - set.StringToStringP(flag.Long, flag.Short, map[string]string{}, flag.Desc) + flagset.StringToStringP(flag.Long, flag.Short, json2T(flag.Default, map[string]string{}), flag.Desc) } else { - set.StringToString(flag.Long, map[string]string{}, flag.Desc) + flagset.StringToString(flag.Long, json2T(flag.Default, map[string]string{}), flag.Desc) } - default: - log.Warn().Msg("unknown flag type for " + flag.Long) + } + + if flag.Option&plugin.FlagOption_Hidden != 0 { + flagset.MarkHidden(flag.Long) + } + if flag.Option&plugin.FlagOption_Deprecated != 0 { + flagset.MarkDeprecated(flag.Long, "has been deprecated") + } +} + +func attachFlags(flagset *pflag.FlagSet, flags []plugin.Flag) { + for i := range flags { + attachFlag(flagset, flags[i]) } } @@ -424,48 +448,7 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu providers.Coordinator.Shutdown() } - for i := range allFlags { - flag := allFlags[i] - switch flag.Type { - case plugin.FlagType_Bool: - if flag.Short != "" { - cmd.Flags().BoolP(flag.Long, flag.Short, json2T(flag.Default, false), flag.Desc) - } else { - cmd.Flags().Bool(flag.Long, json2T(flag.Default, false), flag.Desc) - } - case plugin.FlagType_Int: - if flag.Short != "" { - cmd.Flags().IntP(flag.Long, flag.Short, json2T(flag.Default, 0), flag.Desc) - } else { - cmd.Flags().Int(flag.Long, json2T(flag.Default, 0), flag.Desc) - } - case plugin.FlagType_String: - if flag.Short != "" { - cmd.Flags().StringP(flag.Long, flag.Short, flag.Default, flag.Desc) - } else { - cmd.Flags().String(flag.Long, flag.Default, flag.Desc) - } - case plugin.FlagType_List: - if flag.Short != "" { - cmd.Flags().StringSliceP(flag.Long, flag.Short, json2T(flag.Default, []string{}), flag.Desc) - } else { - cmd.Flags().StringSlice(flag.Long, json2T(flag.Default, []string{}), flag.Desc) - } - case plugin.FlagType_KeyValue: - if flag.Short != "" { - cmd.Flags().StringToStringP(flag.Long, flag.Short, json2T(flag.Default, map[string]string{}), flag.Desc) - } else { - cmd.Flags().StringToString(flag.Long, json2T(flag.Default, map[string]string{}), flag.Desc) - } - } - - if flag.Option&plugin.FlagOption_Hidden != 0 { - cmd.Flags().MarkHidden(flag.Long) - } - if flag.Option&plugin.FlagOption_Deprecated != 0 { - cmd.Flags().MarkDeprecated(flag.Long, "has been deprecated") - } - } + attachFlags(cmd.Flags(), allFlags) } func json2T[T any](s string, empty T) T {