Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 pre-process known provider flags #1868

Merged
merged 1 commit into from
Sep 24, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 41 additions & 58 deletions cli/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func AttachCLIs(rootCmd *cobra.Command, commands ...*Command) error {
return err
}

connectorName, autoUpdate := detectConnectorName(os.Args, commands)
connectorName, autoUpdate := detectConnectorName(os.Args, commands, existing)
if connectorName != "" {
if _, err := providers.EnsureProvider(existing, connectorName, "", autoUpdate); err != nil {
return err
Expand All @@ -49,7 +49,7 @@ func AttachCLIs(rootCmd *cobra.Command, commands ...*Command) error {
return nil
}

func detectConnectorName(args []string, commands []*Command) (string, bool) {
func detectConnectorName(args []string, commands []*Command, providers providers.Providers) (string, bool) {
autoUpdate := true

config.InitViperConfig()
Expand All @@ -63,7 +63,7 @@ func detectConnectorName(args []string, commands []*Command) (string, bool) {

builtins := genBuiltinFlags()
for i := range builtins {
addFlagToSet(flags, builtins[i])
attachFlag(flags, builtins[i])
}

for i := range commands {
Expand All @@ -75,6 +75,19 @@ func detectConnectorName(args []string, commands []*Command) (string, bool) {
})
}

for i := range providers {
provider := providers[i]
for j := range provider.Connectors {
conn := provider.Connectors[j]
for k := range conn.Flags {
flag := conn.Flags[k]
if found := flags.Lookup(flag.Long); found == nil {
attachFlag(flags, flag)
}
}
}
}

err := flags.Parse(args)
if err != nil {
log.Warn().Err(err).Msg("CLI pre-processing encountered an issue")
Expand Down Expand Up @@ -225,40 +238,51 @@ var skipFlags = map[string]struct{}{
"use-recording": {},
}

func addFlagToSet(set *pflag.FlagSet, flag plugin.Flag) {
func attachFlag(flagset *pflag.FlagSet, flag plugin.Flag) {
switch flag.Type {
case plugin.FlagType_Bool:
if flag.Short != "" {
set.BoolP(flag.Long, flag.Short, false, flag.Desc)
flagset.BoolP(flag.Long, flag.Short, json2T(flag.Default, false), flag.Desc)
} else {
set.Bool(flag.Long, false, flag.Desc)
flagset.Bool(flag.Long, json2T(flag.Default, false), flag.Desc)
}
case plugin.FlagType_Int:
if flag.Short != "" {
set.IntP(flag.Long, flag.Short, 0, flag.Desc)
flagset.IntP(flag.Long, flag.Short, json2T(flag.Default, 0), flag.Desc)
} else {
set.Int(flag.Long, 0, flag.Desc)
flagset.Int(flag.Long, json2T(flag.Default, 0), flag.Desc)
}
case plugin.FlagType_String:
if flag.Short != "" {
set.StringP(flag.Long, flag.Short, "", flag.Desc)
flagset.StringP(flag.Long, flag.Short, flag.Default, flag.Desc)
} else {
set.String(flag.Long, "", flag.Desc)
flagset.String(flag.Long, flag.Default, flag.Desc)
}
case plugin.FlagType_List:
if flag.Short != "" {
set.StringArrayP(flag.Long, flag.Short, []string{}, flag.Desc)
flagset.StringSliceP(flag.Long, flag.Short, json2T(flag.Default, []string{}), flag.Desc)
} else {
set.StringArray(flag.Long, []string{}, flag.Desc)
flagset.StringSlice(flag.Long, json2T(flag.Default, []string{}), flag.Desc)
}
case plugin.FlagType_KeyValue:
if flag.Short != "" {
set.StringToStringP(flag.Long, flag.Short, map[string]string{}, flag.Desc)
flagset.StringToStringP(flag.Long, flag.Short, json2T(flag.Default, map[string]string{}), flag.Desc)
} else {
set.StringToString(flag.Long, map[string]string{}, flag.Desc)
flagset.StringToString(flag.Long, json2T(flag.Default, map[string]string{}), flag.Desc)
}
default:
log.Warn().Msg("unknown flag type for " + flag.Long)
}

if flag.Option&plugin.FlagOption_Hidden != 0 {
flagset.MarkHidden(flag.Long)
}
if flag.Option&plugin.FlagOption_Deprecated != 0 {
flagset.MarkDeprecated(flag.Long, "has been deprecated")
}
}

func attachFlags(flagset *pflag.FlagSet, flags []plugin.Flag) {
for i := range flags {
attachFlag(flagset, flags[i])
}
}

Expand Down Expand Up @@ -424,48 +448,7 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu
providers.Coordinator.Shutdown()
}

for i := range allFlags {
flag := allFlags[i]
switch flag.Type {
case plugin.FlagType_Bool:
if flag.Short != "" {
cmd.Flags().BoolP(flag.Long, flag.Short, json2T(flag.Default, false), flag.Desc)
} else {
cmd.Flags().Bool(flag.Long, json2T(flag.Default, false), flag.Desc)
}
case plugin.FlagType_Int:
if flag.Short != "" {
cmd.Flags().IntP(flag.Long, flag.Short, json2T(flag.Default, 0), flag.Desc)
} else {
cmd.Flags().Int(flag.Long, json2T(flag.Default, 0), flag.Desc)
}
case plugin.FlagType_String:
if flag.Short != "" {
cmd.Flags().StringP(flag.Long, flag.Short, flag.Default, flag.Desc)
} else {
cmd.Flags().String(flag.Long, flag.Default, flag.Desc)
}
case plugin.FlagType_List:
if flag.Short != "" {
cmd.Flags().StringSliceP(flag.Long, flag.Short, json2T(flag.Default, []string{}), flag.Desc)
} else {
cmd.Flags().StringSlice(flag.Long, json2T(flag.Default, []string{}), flag.Desc)
}
case plugin.FlagType_KeyValue:
if flag.Short != "" {
cmd.Flags().StringToStringP(flag.Long, flag.Short, json2T(flag.Default, map[string]string{}), flag.Desc)
} else {
cmd.Flags().StringToString(flag.Long, json2T(flag.Default, map[string]string{}), flag.Desc)
}
}

if flag.Option&plugin.FlagOption_Hidden != 0 {
cmd.Flags().MarkHidden(flag.Long)
}
if flag.Option&plugin.FlagOption_Deprecated != 0 {
cmd.Flags().MarkDeprecated(flag.Long, "has been deprecated")
}
}
attachFlags(cmd.Flags(), allFlags)
}

func json2T[T any](s string, empty T) T {
Expand Down