Skip to content

Commit

Permalink
🐛 bring back loading provider flags from env vars (#4863)
Browse files Browse the repository at this point in the history
* Revert "Revert "🐛 Load provider flags from environment variables (#4847)" (#4857)"

This reverts commit e2cc11e.

* 🐛 fix provider flags with `ConfigEntry="-"`

When a provider defines a flag with `ConfigEntry = "-"`, then we do not
Bind the Flag to the `viper` config. For those flags, we will continue
to fetch the value directly from the flag, that is, from `cobra`.

Signed-off-by: Salim Afiune Maya <[email protected]>

---------

Signed-off-by: Salim Afiune Maya <[email protected]>
  • Loading branch information
afiune authored Nov 25, 2024
1 parent 22c683f commit 9cf3198
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 11 deletions.
44 changes: 36 additions & 8 deletions cli/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ package providers

import (
"encoding/json"
"go.mondoo.com/cnquery/v11/utils/piped"
"go.mondoo.com/ranger-rpc/status"
"os"
"strings"

Expand All @@ -21,6 +19,8 @@ import (
"go.mondoo.com/cnquery/v11/providers-sdk/v1/plugin"
"go.mondoo.com/cnquery/v11/providers-sdk/v1/recording"
"go.mondoo.com/cnquery/v11/types"
"go.mondoo.com/cnquery/v11/utils/piped"
"go.mondoo.com/ranger-rpc/status"
)

type Command struct {
Expand Down Expand Up @@ -318,14 +318,31 @@ func attachFlags(flagset *pflag.FlagSet, flags []plugin.Flag) {
}
}

func getFlagValue(flag plugin.Flag, cmd *cobra.Command) *llx.Primitive {
func getFlagValueFromConfig(flag plugin.Flag) *llx.Primitive {
switch flag.Type {
case plugin.FlagType_Bool:
v, err := cmd.Flags().GetBool(flag.Long)
if err == nil {
return llx.BoolPrimitive(viper.GetBool(flag.Long))
case plugin.FlagType_Int:
return llx.IntPrimitive(viper.GetInt64(flag.Long))
case plugin.FlagType_String:
return llx.StringPrimitive(viper.GetString(flag.Long))
case plugin.FlagType_List:
return llx.ArrayPrimitiveT(viper.GetStringSlice(flag.Long), llx.StringPrimitive, types.String)
case plugin.FlagType_KeyValue:
return llx.MapPrimitiveT(viper.GetStringMapString(flag.Long), llx.StringPrimitive, types.String)
default:
log.Warn().Msg("unknown flag type for " + flag.Long)
return nil
}
}

func getFlagValueFromCobra(flag plugin.Flag, cmd *cobra.Command) *llx.Primitive {
var err error
switch flag.Type {
case plugin.FlagType_Bool:
if v, err := cmd.Flags().GetBool(flag.Long); err == nil {
return llx.BoolPrimitive(v)
}
log.Warn().Err(err).Msg("failed to get flag " + flag.Long)
case plugin.FlagType_Int:
if v, err := cmd.Flags().GetInt(flag.Long); err == nil {
return llx.IntPrimitive(int64(v))
Expand All @@ -346,6 +363,8 @@ func getFlagValue(flag plugin.Flag, cmd *cobra.Command) *llx.Primitive {
log.Warn().Msg("unknown flag type for " + flag.Long)
return nil
}

log.Warn().Err(err).Msg("failed to get flag " + flag.Long)
return nil
}

Expand All @@ -366,6 +385,7 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu
for i := range allFlags {
flag := allFlags[i]
if flag.ConfigEntry == "-" {
log.Debug().Msg("skipping config binding for " + flag.Long)
continue
}

Expand Down Expand Up @@ -421,8 +441,16 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu
continue
}

if v := getFlagValue(flag, cmd); v != nil {
flagVals[flag.Long] = v
// if the provider flag was configured to avoid using the config,
// we should instead fetch the flag value from `cobra` directly.
if flag.ConfigEntry == "-" {
if v := getFlagValueFromCobra(flag, cmd); v != nil {
flagVals[flag.Long] = v
}
} else {
if v := getFlagValueFromConfig(flag); v != nil {
flagVals[flag.Long] = v
}
}
}

Expand Down
74 changes: 71 additions & 3 deletions test/providers/os_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
package providers

import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.mondoo.com/cnquery/v11/test"
"log"
"os"
"os/exec"
"path/filepath"
"sync"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.mondoo.com/cnquery/v11/test"
)

var once sync.Once
Expand Down Expand Up @@ -185,3 +186,70 @@ func TestOsProviderSharedTests(t *testing.T) {
}
}
}

func TestProvidersEnvVarsLoading(t *testing.T) {
t.Run("command WITHOUT path should not find any package", func(t *testing.T) {
r := test.NewCliTestRunner("./cnquery", "run", "fs", "-c", mqlPackagesQuery, "-j")
err := r.Run()
require.NoError(t, err)
assert.Equal(t, 0, r.ExitCode())
assert.NotNil(t, r.Stdout())
assert.NotNil(t, r.Stderr())

var c mqlPackages
err = r.Json(&c)
assert.NoError(t, err)

// No packages
assert.Empty(t, c)
})
t.Run("command WITH path should find packages", func(t *testing.T) {
os.Setenv("MONDOO_PATH", "./testdata/fs")
defer os.Unsetenv("MONDOO_PATH")
// Note we are not passing the flag "--path ./testdata/fs"
r := test.NewCliTestRunner("./cnquery", "run", "fs", "-c", mqlPackagesQuery, "-j")
err := r.Run()
require.NoError(t, err)
assert.Equal(t, 0, r.ExitCode())
assert.NotNil(t, r.Stdout())
assert.NotNil(t, r.Stderr())

var c mqlPackages
err = r.Json(&c)
assert.NoError(t, err)

// Should have packages
if assert.NotEmpty(t, c) {
x := c[0]
assert.NotNil(t, x.Packages)
assert.True(t, len(x.Packages) > 0)
}
})

t.Run("command with flags set to not bind to config (ConfigEntry=\"-\")", func(t *testing.T) {
t.Run("should work via direct flag", func(t *testing.T) {
r := test.NewCliTestRunner("./cnquery", "run", "ssh", "localhost", "-c", "ls", "-p", "test", "-v")
err := r.Run()
require.NoError(t, err)
assert.Equal(t, 0, r.ExitCode())
assert.NotNil(t, r.Stdout())
if assert.NotNil(t, r.Stderr()) {
assert.Contains(t, string(r.Stderr()), "skipping config binding for password")
assert.Contains(t, string(r.Stderr()), "enabled ssh password authentication")
}
})
t.Run("should NOT work via config/env-vars", func(t *testing.T) {
os.Setenv("MONDOO_PASSWORD", "test")
defer os.Unsetenv("MONDOO_PASSWORD")
r := test.NewCliTestRunner("./cnquery", "run", "ssh", "localhost", "-c", "ls", "-v")
err := r.Run()
require.NoError(t, err)
assert.Equal(t, 0, r.ExitCode())
assert.NotNil(t, r.Stdout())
if assert.NotNil(t, r.Stderr()) {
assert.Contains(t, string(r.Stderr()), "skipping config binding for password")
assert.NotContains(t, string(r.Stderr()), "enabled ssh password authentication")
}
})
})
}

0 comments on commit 9cf3198

Please sign in to comment.