Skip to content

Commit

Permalink
Restricted flags cannot be used outside testmode
Browse files Browse the repository at this point in the history
  • Loading branch information
otherview committed Nov 30, 2023
1 parent d1d3411 commit 2da41e0
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 3 deletions.
10 changes: 10 additions & 0 deletions go/common/flag/flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,16 @@ func (f TenFlag) Bool() bool {
return f.Value.(bool)
}

func (f TenFlag) IsSet() bool {
found := false
flag.Visit(func(fl *flag.Flag) {
if fl.Name == f.Name {
found = true
}
})
return found
}

func CreateCLIFlags(flags map[string]*TenFlag) error {
for _, tflag := range flags {
switch tflag.FlagType {
Expand Down
13 changes: 10 additions & 3 deletions go/config/enclave_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func DefaultEnclaveConfig() *EnclaveConfig {
}
}

func FromFlags(flagMap map[string]*flag.TenFlag) (*EnclaveConfig, error) {
func FromFlags(cliFlags map[string]*flag.TenFlag) (*EnclaveConfig, error) {
flagsTestMode := false

// check if it's in test mode or not
Expand All @@ -122,9 +122,16 @@ func FromFlags(flagMap map[string]*flag.TenFlag) (*EnclaveConfig, error) {
return nil, fmt.Errorf("unable to retrieve env flags - %w", err)
}

// fail if any restricted flag is set via the cli
for _, envflag := range envFlags {
if cliflag, ok := cliFlags[envflag.Name]; ok && cliflag.IsSet() {
return nil, fmt.Errorf("restricted flag was set: %s", cliflag.Name)
}
}

// create the final flag usage
parsedFlags := map[string]*flag.TenFlag{}
for flagName, cliflag := range flagMap {
for flagName, cliflag := range cliFlags {
parsedFlags[flagName] = cliflag
}
// env flags override CLI flags
Expand All @@ -134,7 +141,7 @@ func FromFlags(flagMap map[string]*flag.TenFlag) (*EnclaveConfig, error) {

return newConfig(parsedFlags)
}
return newConfig(flagMap)
return newConfig(cliFlags)
}

func retrieveEnvFlags() (map[string]*flag.TenFlag, error) {
Expand Down
35 changes: 35 additions & 0 deletions go/config/enclave_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,38 @@ func TestRestrictedMode(t *testing.T) {
require.Equal(t, true, enclaveConfig.ProfilerEnabled)
require.Equal(t, true, enclaveConfig.DebugNamespaceEnabled)
}

func TestRestrictedModeNoCLIDuplication(t *testing.T) {
// Backup the original CommandLine.
originalFlagSet := flag.CommandLine
// Create a new FlagSet for testing purposes.
flag.CommandLine = flag.NewFlagSet("", flag.ContinueOnError)

// Defer a function to reset CommandLine after the test.
defer func() {
flag.CommandLine = originalFlagSet
}()

t.Setenv("EDG_TESTMODE", "false")
t.Setenv("EDG_"+strings.ToUpper(L1ChainIDFlag), "4444")
t.Setenv("EDG_"+strings.ToUpper(ObscuroChainIDFlag), "1243")
t.Setenv("EDG_"+strings.ToUpper(ObscuroGenesisFlag), "{}")
t.Setenv("EDG_"+strings.ToUpper(UseInMemoryDBFlag), "true")
t.Setenv("EDG_"+strings.ToUpper(ProfilerEnabledFlag), "true")
t.Setenv("EDG_"+strings.ToUpper(DebugNamespaceEnabledFlag), "true")

flags := EnclaveFlags()
err := tenflag.CreateCLIFlags(flags)
require.NoError(t, err)

err = flag.CommandLine.Set(NodeTypeFlag, "sequencer")
require.NoError(t, err)

err = flag.CommandLine.Set(L1ChainIDFlag, "5555")
require.NoError(t, err)

flag.Parse()

_, err = FromFlags(flags)
require.Errorf(t, err, "restricted flag was set: l1ChainID")
}

0 comments on commit 2da41e0

Please sign in to comment.