diff --git a/go/common/flag/flag.go b/go/common/flag/flag.go index db362151c2..a783e546c9 100644 --- a/go/common/flag/flag.go +++ b/go/common/flag/flag.go @@ -98,6 +98,16 @@ func (f TenFlag) Bool() bool { return f.Value.(bool) } +func (f TenFlag) IsSet() bool { + found := false + flag.Visit(func(fl *flag.Flag) { + if fl.Name == f.Name { + found = true + } + }) + return found +} + func CreateCLIFlags(flags map[string]*TenFlag) error { for _, tflag := range flags { switch tflag.FlagType { diff --git a/go/config/enclave_config.go b/go/config/enclave_config.go index 8eaefcb719..4feb2244f5 100644 --- a/go/config/enclave_config.go +++ b/go/config/enclave_config.go @@ -105,7 +105,7 @@ func DefaultEnclaveConfig() *EnclaveConfig { } } -func FromFlags(flagMap map[string]*flag.TenFlag) (*EnclaveConfig, error) { +func FromFlags(cliFlags map[string]*flag.TenFlag) (*EnclaveConfig, error) { flagsTestMode := false // check if it's in test mode or not @@ -122,9 +122,16 @@ func FromFlags(flagMap map[string]*flag.TenFlag) (*EnclaveConfig, error) { return nil, fmt.Errorf("unable to retrieve env flags - %w", err) } + // fail if any restricted flag is set via the cli + for _, envflag := range envFlags { + if cliflag, ok := cliFlags[envflag.Name]; ok && cliflag.IsSet() { + return nil, fmt.Errorf("restricted flag was set: %s", cliflag.Name) + } + } + // create the final flag usage parsedFlags := map[string]*flag.TenFlag{} - for flagName, cliflag := range flagMap { + for flagName, cliflag := range cliFlags { parsedFlags[flagName] = cliflag } // env flags override CLI flags @@ -134,7 +141,7 @@ func FromFlags(flagMap map[string]*flag.TenFlag) (*EnclaveConfig, error) { return newConfig(parsedFlags) } - return newConfig(flagMap) + return newConfig(cliFlags) } func retrieveEnvFlags() (map[string]*flag.TenFlag, error) { diff --git a/go/config/enclave_config_test.go b/go/config/enclave_config_test.go index 3d9e8609db..4798efe536 100644 --- a/go/config/enclave_config_test.go +++ b/go/config/enclave_config_test.go @@ -100,3 +100,38 @@ func TestRestrictedMode(t *testing.T) { require.Equal(t, true, enclaveConfig.ProfilerEnabled) require.Equal(t, true, enclaveConfig.DebugNamespaceEnabled) } + +func TestRestrictedModeNoCLIDuplication(t *testing.T) { + // Backup the original CommandLine. + originalFlagSet := flag.CommandLine + // Create a new FlagSet for testing purposes. + flag.CommandLine = flag.NewFlagSet("", flag.ContinueOnError) + + // Defer a function to reset CommandLine after the test. + defer func() { + flag.CommandLine = originalFlagSet + }() + + t.Setenv("EDG_TESTMODE", "false") + t.Setenv("EDG_"+strings.ToUpper(L1ChainIDFlag), "4444") + t.Setenv("EDG_"+strings.ToUpper(ObscuroChainIDFlag), "1243") + t.Setenv("EDG_"+strings.ToUpper(ObscuroGenesisFlag), "{}") + t.Setenv("EDG_"+strings.ToUpper(UseInMemoryDBFlag), "true") + t.Setenv("EDG_"+strings.ToUpper(ProfilerEnabledFlag), "true") + t.Setenv("EDG_"+strings.ToUpper(DebugNamespaceEnabledFlag), "true") + + flags := EnclaveFlags() + err := tenflag.CreateCLIFlags(flags) + require.NoError(t, err) + + err = flag.CommandLine.Set(NodeTypeFlag, "sequencer") + require.NoError(t, err) + + err = flag.CommandLine.Set(L1ChainIDFlag, "5555") + require.NoError(t, err) + + flag.Parse() + + _, err = FromFlags(flags) + require.Errorf(t, err, "restricted flag was set: l1ChainID") +}