diff --git a/dockerfiles/enclave.Dockerfile b/dockerfiles/enclave.Dockerfile index 1a9b81aa86..17e4ac3609 100644 --- a/dockerfiles/enclave.Dockerfile +++ b/dockerfiles/enclave.Dockerfile @@ -45,7 +45,7 @@ FROM build-enclave as build-enclave-testmode-true RUN ego sign enclave-test.json # Tag the restricted mode as the current build -FROM build-enclave-restrictedmode-${RESTRICTEDMODE} as build-enclave +FROM build-enclave-testmode-${TESTMODE} as build-enclave # Trigger a new build stage and use the smaller ego version: FROM ghcr.io/edgelesssys/ego-deploy:v1.3.0 diff --git a/go/common/flag/flag.go b/go/common/flag/flag.go index 6511130b3a..db362151c2 100644 --- a/go/common/flag/flag.go +++ b/go/common/flag/flag.go @@ -3,166 +3,121 @@ package flag import ( "flag" "fmt" - "os" - "strconv" - "strings" ) -// WrappedFlag is a construct that allows to have go flags while obeying to a new set of restrictiveness rules -type WrappedFlag struct { - flagType string - ptr any - defaultValue any - description string +type TenFlag struct { + Name string + Value any + FlagType string + Description string + DefaultValue any } -// GetString returns the flag current value cast to string -func (f WrappedFlag) GetString() string { - return *f.ptr.(*string) -} - -// GetInt64 returns the flag current value cast to int64 -func (f WrappedFlag) GetInt64() int64 { - return *f.ptr.(*int64) -} - -// GetBool returns the flag current value cast to bool -func (f WrappedFlag) GetBool() bool { - return *f.ptr.(*bool) -} - -// singletonFlagger is the singleton instance of the loaded flags -var singletonFlagger = map[string]*WrappedFlag{} - -// String directly uses the go flag package -func String(flagName, defaultValue, description string) *string { - return flag.String(flagName, defaultValue, description) -} - -// RestrictedString wraps the go flag package depending on the restriction mode -func RestrictedString(flagName, defaultValue, description string) *WrappedFlag { - prtVal := new(string) - singletonFlagger[flagName] = &WrappedFlag{ - flagType: "string", - ptr: prtVal, - defaultValue: defaultValue, - description: description, +func NewStringFlag(name, defaultValue, description string) *TenFlag { + return &TenFlag{ + Name: name, + Value: "", + FlagType: "string", + Description: description, + DefaultValue: defaultValue, } - return singletonFlagger[flagName] -} - -// Int64 directly uses the go flag package -func Int64(flagName string, defaultValue int64, description string) *int64 { - return flag.Int64(flagName, defaultValue, description) } -// RestrictedInt64 wraps the go flag package depending on the restriction mode -func RestrictedInt64(flagName string, defaultValue int64, description string) *WrappedFlag { - prtVal := new(int64) - singletonFlagger[flagName] = &WrappedFlag{ - flagType: "int64", - ptr: prtVal, - defaultValue: defaultValue, - description: description, +func NewIntFlag(name string, defaultValue int, description string) *TenFlag { + return &TenFlag{ + Name: name, + Value: 0, + FlagType: "int", + Description: description, + DefaultValue: defaultValue, } - return singletonFlagger[flagName] } -// Bool directly uses the go flag package -func Bool(flagName string, defaultValue bool, description string) *bool { - return flag.Bool(flagName, defaultValue, description) +func NewBoolFlag(name string, defaultValue bool, description string) *TenFlag { + return &TenFlag{ + Name: name, + Value: false, + FlagType: "bool", + Description: description, + DefaultValue: defaultValue, + } } -// RestrictedBool wraps the go flag package depending on the restriction mode -func RestrictedBool(flagName string, defaultValue bool, description string) *WrappedFlag { - prtVal := new(bool) - singletonFlagger[flagName] = &WrappedFlag{ - flagType: "bool", - ptr: prtVal, - defaultValue: defaultValue, - description: description, +func NewInt64Flag(name string, defaultValue int64, description string) *TenFlag { + return &TenFlag{ + Name: name, + Value: false, + FlagType: "int64", + Description: description, + DefaultValue: defaultValue, } - return singletonFlagger[flagName] } -// Int directly uses the go flag package -func Int(flagName string, defaultValue int, description string) *int { - return flag.Int(flagName, defaultValue, description) +func NewUint64Flag(name string, defaultValue uint64, description string) *TenFlag { + return &TenFlag{ + Name: name, + Value: false, + FlagType: "uint64", + Description: description, + DefaultValue: defaultValue, + } } -// Uint64 directly uses the go flag package -func Uint64(flagName string, defaultValue uint64, description string) *uint64 { - return flag.Uint64(flagName, defaultValue, description) +func (f TenFlag) String() string { + if ptrVal, ok := f.Value.(*string); ok { + return *ptrVal + } + return f.Value.(string) } -// Parse ensures the restricted mode is applied only to restricted flags -// Restricted Mode - Flags can only be inputted via ENV Vars via the enclave.json -// Non-Restricted Mode - Flags can only be inputted via normal CLI command line -func Parse() error { - mandatoryEnvFlags := false - val := os.Getenv("EDG_RESTRICTED") - if val == "true" { - fmt.Println("Using mandatory signed configurations.") - mandatoryEnvFlags = true +func (f TenFlag) Int() int { + if ptrVal, ok := f.Value.(*int); ok { + return *ptrVal } + return f.Value.(int) +} - for flagName, wflag := range singletonFlagger { - // parse restricted flags if in restricted mode - if mandatoryEnvFlags { - err := parseMandatoryFlags(flagName, wflag) - if err != nil { - return fmt.Errorf("unable to parse mandatory flag: %s - %w", flagName, err) - } - } else { - err := parseNonMandatoryFlag(flagName, wflag) - if err != nil { - return fmt.Errorf("unable to parse flag: %s - %w", flagName, err) - } - } +func (f TenFlag) Int64() int64 { + if ptrVal, ok := f.Value.(*int64); ok { + return *ptrVal } - - // parse all remaining flags - flag.Parse() - return nil + return f.Value.(int64) } -func parseNonMandatoryFlag(flagName string, wflag *WrappedFlag) error { - switch wflag.flagType { - case "string": - wflag.ptr = flag.String(flagName, wflag.defaultValue.(string), wflag.description) - case "int64": - wflag.ptr = flag.Int64(flagName, wflag.defaultValue.(int64), wflag.description) - case "bool": - wflag.ptr = flag.Bool(flagName, wflag.defaultValue.(bool), wflag.description) - default: - return fmt.Errorf("unexpected type: %s", wflag.flagType) +func (f TenFlag) Uint64() uint64 { + if ptrVal, ok := f.Value.(*uint64); ok { + return *ptrVal } - return nil + return f.Value.(uint64) } -func parseMandatoryFlags(flagName string, wflag *WrappedFlag) error { - val := os.Getenv("EDG_" + strings.ToUpper(flagName)) - if val == "" { - return fmt.Errorf("mandatory restricted flag not available - %s", flagName) +func (f TenFlag) Bool() bool { + if ptrVal, ok := f.Value.(*bool); ok { + return *ptrVal } + return f.Value.(bool) +} - switch wflag.flagType { - case "string": - wflag.ptr = &val - case "int64": - i, err := strconv.ParseInt(val, 10, 64) - if err != nil { - return fmt.Errorf("unable to parse flag %s - %w", flagName, err) - } - wflag.ptr = &i - case "bool": - b, err := strconv.ParseBool(val) - if err != nil { - return fmt.Errorf("unable to parse flag %s - %w", flagName, err) +func CreateCLIFlags(flags map[string]*TenFlag) error { + for _, tflag := range flags { + switch tflag.FlagType { + case "string": + tflag.Value = flag.String(tflag.Name, tflag.DefaultValue.(string), tflag.Description) + case "bool": + tflag.Value = flag.Bool(tflag.Name, tflag.DefaultValue.(bool), tflag.Description) + case "int": + tflag.Value = flag.Int(tflag.Name, tflag.DefaultValue.(int), tflag.Description) + case "int64": + tflag.Value = flag.Int64(tflag.Name, tflag.DefaultValue.(int64), tflag.Description) + case "uint64": + tflag.Value = flag.Uint64(tflag.Name, tflag.DefaultValue.(uint64), tflag.Description) + default: + return fmt.Errorf("unexpected flag type %s", tflag.FlagType) } - wflag.ptr = &b - default: - return fmt.Errorf("unexpected mandatory type: %s", wflag.flagType) } return nil } + +func Parse() { + flag.Parse() +} diff --git a/go/common/flag/flag_test.go b/go/common/flag/flag_test.go deleted file mode 100644 index bae24294cb..0000000000 --- a/go/common/flag/flag_test.go +++ /dev/null @@ -1,68 +0,0 @@ -package flag - -import ( - "fmt" - "os" - "strings" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestStringFlagCreation(t *testing.T) { - expected := "Not the Default" - flagName := "testString" - - // Create the flag - flagOutput := String(flagName, "default", "Test string flag") - - // Parse the flag - os.Args = []string{"cmd", fmt.Sprintf("-%s=%s", flagName, expected)} - - // Parse the flags - require.NoError(t, Parse()) - - // Verify the flag value - require.Equal(t, expected, *flagOutput) -} - -func TestParseInRestrictedMode(t *testing.T) { - // Set up restricted mode - t.Setenv("EDG_RESTRICTED", "true") - defer os.Unsetenv("EDG_RESTRICTED") - - // Create a restricted flag - flagName := "testFlag" - flagOutput := RestrictedString(flagName, "default", "Test restricted flag") - - // Mimic setting the environment variable for the restricted flag - expectedValue := "restrictedValue" - t.Setenv("EDG_"+strings.ToUpper(flagName), expectedValue) - - defer os.Unsetenv("EDG_" + strings.ToUpper(flagName)) - - // Parse the flags - require.NoError(t, Parse()) - - // Verify the flag value - require.Equal(t, expectedValue, flagOutput.GetString()) -} - -func TestParseInUnrestrictedMode(t *testing.T) { - // Ensure unrestricted mode - os.Unsetenv("EDG_RESTRICTED") - - // Create a regular flag - flagName := "testUnrestrictedFlag" - expected := int64(12345) - // Parse the flag - os.Args = []string{"cmd", fmt.Sprintf("-%s=%d", flagName, expected)} - - flagOutput := RestrictedInt64(flagName, int64(1), "Test flag") - - // Parse the flags - require.NoError(t, Parse()) - - // Verify the flag value - require.Equal(t, expected, flagOutput.GetInt64()) -} diff --git a/go/config/enclave_cli_flags.go b/go/config/enclave_cli_flags.go new file mode 100644 index 0000000000..175d2e8e57 --- /dev/null +++ b/go/config/enclave_cli_flags.go @@ -0,0 +1,98 @@ +package config + +import "github.com/ten-protocol/go-ten/go/common/flag" + +// EnclaveFlags are the flags that the enclave can receive +func EnclaveFlags() map[string]*flag.TenFlag { + return map[string]*flag.TenFlag{ + HostIDFlag: flag.NewStringFlag(HostIDFlag, "", FlagDescriptionMap[HostIDFlag]), + HostAddressFlag: flag.NewStringFlag(HostAddressFlag, "", FlagDescriptionMap[HostAddressFlag]), + AddressFlag: flag.NewStringFlag(AddressFlag, "", FlagDescriptionMap[HostAddressFlag]), + NodeTypeFlag: flag.NewStringFlag(NodeTypeFlag, "", FlagDescriptionMap[NodeTypeFlag]), + WillAttestFlag: flag.NewBoolFlag(WillAttestFlag, false, FlagDescriptionMap[WillAttestFlag]), + ValidateL1BlocksFlag: flag.NewBoolFlag(ValidateL1BlocksFlag, false, FlagDescriptionMap[ValidateL1BlocksFlag]), + ManagementContractAddressFlag: flag.NewStringFlag(ManagementContractAddressFlag, "", FlagDescriptionMap[ManagementContractAddressFlag]), + LogLevelFlag: flag.NewIntFlag(LogLevelFlag, 0, FlagDescriptionMap[LogLevelFlag]), + LogPathFlag: flag.NewStringFlag(LogPathFlag, "", FlagDescriptionMap[LogPathFlag]), + EdgelessDBHostFlag: flag.NewStringFlag(EdgelessDBHostFlag, "", FlagDescriptionMap[EdgelessDBHostFlag]), + SQLiteDBPathFlag: flag.NewStringFlag(SQLiteDBPathFlag, "", FlagDescriptionMap[SQLiteDBPathFlag]), + MinGasPriceFlag: flag.NewInt64Flag(MinGasPriceFlag, 0, FlagDescriptionMap[MinGasPriceFlag]), + MessageBusAddressFlag: flag.NewStringFlag(MessageBusAddressFlag, "", FlagDescriptionMap[MessageBusAddressFlag]), + SequencerIDFlag: flag.NewStringFlag(SequencerIDFlag, "", FlagDescriptionMap[SequencerIDFlag]), + MaxBatchSizeFlag: flag.NewUint64Flag(MaxBatchSizeFlag, 0, FlagDescriptionMap[MaxBatchSizeFlag]), + MaxRollupSizeFlag: flag.NewUint64Flag(MaxRollupSizeFlag, 0, FlagDescriptionMap[MaxRollupSizeFlag]), + L2BaseFeeFlag: flag.NewUint64Flag(L2BaseFeeFlag, 0, ""), + L2CoinbaseFlag: flag.NewStringFlag(L2CoinbaseFlag, "", ""), + L2GasLimitFlag: flag.NewUint64Flag(L2GasLimitFlag, 0, ""), + ObscuroGenesisFlag: flag.NewStringFlag(ObscuroGenesisFlag, "", FlagDescriptionMap[ObscuroGenesisFlag]), + L1ChainIDFlag: flag.NewInt64Flag(L1ChainIDFlag, 0, FlagDescriptionMap[L1ChainIDFlag]), + ObscuroChainIDFlag: flag.NewInt64Flag(ObscuroChainIDFlag, 0, FlagDescriptionMap[ObscuroChainIDFlag]), + UseInMemoryDBFlag: flag.NewBoolFlag(UseInMemoryDBFlag, false, FlagDescriptionMap[UseInMemoryDBFlag]), + ProfilerEnabledFlag: flag.NewBoolFlag(ProfilerEnabledFlag, false, FlagDescriptionMap[ProfilerEnabledFlag]), + DebugNamespaceEnabledFlag: flag.NewBoolFlag(DebugNamespaceEnabledFlag, false, FlagDescriptionMap[DebugNamespaceEnabledFlag]), + } +} + +// enclaveRestrictedFlags are the flags that the enclave can receive ONLY over the Ego signed enclave.json +var enclaveRestrictedFlags = map[string]*flag.TenFlag{ + L1ChainIDFlag: flag.NewInt64Flag(L1ChainIDFlag, 0, FlagDescriptionMap[L1ChainIDFlag]), + ObscuroChainIDFlag: flag.NewInt64Flag(ObscuroChainIDFlag, 0, FlagDescriptionMap[ObscuroChainIDFlag]), + ObscuroGenesisFlag: flag.NewStringFlag(ObscuroGenesisFlag, "", FlagDescriptionMap[ObscuroGenesisFlag]), + UseInMemoryDBFlag: flag.NewBoolFlag(UseInMemoryDBFlag, false, FlagDescriptionMap[UseInMemoryDBFlag]), + ProfilerEnabledFlag: flag.NewBoolFlag(ProfilerEnabledFlag, false, FlagDescriptionMap[ProfilerEnabledFlag]), + DebugNamespaceEnabledFlag: flag.NewBoolFlag(DebugNamespaceEnabledFlag, false, FlagDescriptionMap[DebugNamespaceEnabledFlag]), +} + +// Flag names. +const ( + HostIDFlag = "hostID" + HostAddressFlag = "hostAddress" + AddressFlag = "address" + NodeTypeFlag = "nodeType" + L1ChainIDFlag = "l1ChainID" + ObscuroChainIDFlag = "obscuroChainID" + WillAttestFlag = "willAttest" + ValidateL1BlocksFlag = "validateL1Blocks" + ManagementContractAddressFlag = "managementContractAddress" + LogLevelFlag = "logLevel" + LogPathFlag = "logPath" + UseInMemoryDBFlag = "useInMemoryDB" + EdgelessDBHostFlag = "edgelessDBHost" + SQLiteDBPathFlag = "sqliteDBPath" + ProfilerEnabledFlag = "profilerEnabled" + MinGasPriceFlag = "minGasPrice" + MessageBusAddressFlag = "messageBusAddress" + SequencerIDFlag = "sequencerID" + ObscuroGenesisFlag = "obscuroGenesis" + DebugNamespaceEnabledFlag = "debugNamespaceEnabled" + MaxBatchSizeFlag = "maxBatchSize" + MaxRollupSizeFlag = "maxRollupSize" + L2BaseFeeFlag = "l2BaseFee" + L2CoinbaseFlag = "l2Coinbase" + L2GasLimitFlag = "l2GasLimit" +) + +var FlagDescriptionMap = map[string]string{ + HostIDFlag: "The 20 bytes of the address of the Obscuro host this enclave serves", + HostAddressFlag: "The peer-to-peer IP address of the Obscuro host this enclave serves", + AddressFlag: "The address on which to serve the Obscuro enclave service", + NodeTypeFlag: "The node's type (e.g. sequencer, validator)", + L1ChainIDFlag: "An integer representing the unique chain id of the Ethereum chain used as an L1 (default 1337)", + ObscuroChainIDFlag: "An integer representing the unique chain id of the Obscuro chain (default 443)", + WillAttestFlag: "Whether the enclave will produce a verified attestation report", + ValidateL1BlocksFlag: "Whether to validate incoming blocks using the hardcoded L1 genesis.json config", + ManagementContractAddressFlag: "The management contract address on the L1", + LogLevelFlag: "The verbosity level of logs. (Defaults to Info)", + LogPathFlag: "The path to use for the enclave service's log file", + UseInMemoryDBFlag: "Whether the enclave will use an in-memory DB rather than persist data", + EdgelessDBHostFlag: "Host address for the edgeless DB instance (can be empty if useInMemoryDB is true or if not using attestation", + SQLiteDBPathFlag: "Filepath for the sqlite DB persistence file (can be empty if a throwaway file in /tmp/ is acceptable or if using InMemory DB or if using attestation/EdgelessDB)", + ProfilerEnabledFlag: "Runs a profiler instance (Defaults to false)", + MinGasPriceFlag: "The minimum gas price for mining a transaction", + MessageBusAddressFlag: "The address of the L1 message bus contract owned by the management contract.", + SequencerIDFlag: "The 20 bytes of the address of the sequencer for this network", + ObscuroGenesisFlag: "The json string with the obscuro genesis", + DebugNamespaceEnabledFlag: "Whether the debug namespace is enabled", + MaxBatchSizeFlag: "The maximum size a batch is allowed to reach uncompressed", + MaxRollupSizeFlag: "The maximum size a rollup is allowed to reach", +} diff --git a/go/config/enclave_config.go b/go/config/enclave_config.go index 7471c3b391..6453375013 100644 --- a/go/config/enclave_config.go +++ b/go/config/enclave_config.go @@ -1,14 +1,18 @@ package config import ( + "fmt" "math/big" + "os" + "strconv" + "strings" - "github.com/ten-protocol/go-ten/go/common" - - gethcommon "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/params" + "github.com/ten-protocol/go-ten/go/common" + "github.com/ten-protocol/go-ten/go/common/flag" "github.com/ten-protocol/go-ten/go/common/log" + gethcommon "github.com/ethereum/go-ethereum/common" gethlog "github.com/ethereum/go-ethereum/log" ) @@ -100,3 +104,108 @@ func DefaultEnclaveConfig() *EnclaveConfig { GasLimit: new(big.Int).SetUint64(params.MaxGasLimit / 6), } } + +func FromFlags(flagMap map[string]*flag.TenFlag) (*EnclaveConfig, error) { + flagsTestMode := false + + // check if it's in test mode or not + val := os.Getenv("EDG_TESTMODE") + if val == "true" { + flagsTestMode = true + } else { + fmt.Println("Using mandatory signed configurations.") + } + + if !flagsTestMode { + envFlags, err := retrieveEnvFlags() + if err != nil { + return nil, fmt.Errorf("unable to retrieve env flags - %w", err) + } + + // create the final flag usage + parsedFlags := map[string]*flag.TenFlag{} + for flagName, cliflag := range flagMap { + parsedFlags[flagName] = cliflag + } + // env flags override CLI flags + for flagName, envflag := range envFlags { + parsedFlags[flagName] = envflag + } + + return newConfig(parsedFlags) + } + return newConfig(flagMap) +} + +func retrieveEnvFlags() (map[string]*flag.TenFlag, error) { + parsedFlags := map[string]*flag.TenFlag{} + + for _, eflag := range enclaveRestrictedFlags { + val := os.Getenv("EDG_" + strings.ToUpper(eflag.Name)) + switch eflag.FlagType { + case "string": + parsedFlag := flag.NewStringFlag(eflag.Name, "", "") + parsedFlag.Value = val + + parsedFlags[eflag.Name] = parsedFlag + case "int64": + i, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return nil, fmt.Errorf("unable to parse flag %s - %w", eflag.Name, err) + } + + parsedFlag := flag.NewIntFlag(eflag.Name, 0, "") + parsedFlag.Value = i + parsedFlags[eflag.Name] = parsedFlag + case "bool": + b, err := strconv.ParseBool(val) + if err != nil { + return nil, fmt.Errorf("unable to parse flag %s - %w", eflag.Name, err) + } + + parsedFlag := flag.NewBoolFlag(eflag.Name, false, "") + parsedFlag.Value = b + parsedFlags[eflag.Name] = parsedFlag + default: + return nil, fmt.Errorf("unexpected type: %s", eflag.FlagType) + } + } + return parsedFlags, nil +} + +func newConfig(flags map[string]*flag.TenFlag) (*EnclaveConfig, error) { + cfg := DefaultEnclaveConfig() + + nodeType, err := common.ToNodeType(flags[NodeTypeFlag].String()) + if err != nil { + return nil, fmt.Errorf("unrecognised node type '%s'", flags[NodeTypeFlag].String()) + } + + cfg.HostID = gethcommon.HexToAddress(flags[HostIDFlag].String()) + cfg.HostAddress = flags[HostAddressFlag].String() + cfg.Address = flags[AddressFlag].String() + cfg.NodeType = nodeType + cfg.L1ChainID = flags[L1ChainIDFlag].Int64() + cfg.ObscuroChainID = flags[ObscuroChainIDFlag].Int64() + cfg.WillAttest = flags[WillAttestFlag].Bool() + cfg.ValidateL1Blocks = flags[ValidateL1BlocksFlag].Bool() + cfg.ManagementContractAddress = gethcommon.HexToAddress(flags[ManagementContractAddressFlag].String()) + cfg.LogLevel = flags[LogLevelFlag].Int() + cfg.LogPath = flags[LogPathFlag].String() + cfg.UseInMemoryDB = flags[UseInMemoryDBFlag].Bool() + cfg.EdgelessDBHost = flags[EdgelessDBHostFlag].String() + cfg.SqliteDBPath = flags[SQLiteDBPathFlag].String() + cfg.ProfilerEnabled = flags[ProfilerEnabledFlag].Bool() + cfg.MinGasPrice = big.NewInt(flags[MinGasPriceFlag].Int64()) + cfg.MessageBusAddress = gethcommon.HexToAddress(flags[MessageBusAddressFlag].String()) + cfg.SequencerID = gethcommon.HexToAddress(flags[SequencerIDFlag].String()) + cfg.ObscuroGenesis = flags[ObscuroGenesisFlag].String() + cfg.DebugNamespaceEnabled = flags[DebugNamespaceEnabledFlag].Bool() + cfg.MaxBatchSize = flags[MaxBatchSizeFlag].Uint64() + cfg.MaxRollupSize = flags[MaxRollupSizeFlag].Uint64() + cfg.BaseFee = big.NewInt(0).SetUint64(flags[L2BaseFeeFlag].Uint64()) + cfg.GasPaymentAddress = gethcommon.HexToAddress(flags[L2CoinbaseFlag].String()) + cfg.GasLimit = big.NewInt(0).SetUint64(flags[L2GasLimitFlag].Uint64()) + + return cfg, nil +} diff --git a/go/config/enclave_config_test.go b/go/config/enclave_config_test.go new file mode 100644 index 0000000000..01d5c8da9f --- /dev/null +++ b/go/config/enclave_config_test.go @@ -0,0 +1,55 @@ +package config + +import ( + "flag" + "testing" + + "github.com/stretchr/testify/require" + tenflag "github.com/ten-protocol/go-ten/go/common/flag" +) + +func TestFromFlags(t *testing.T) { + // Backup the original CommandLine. + originalFlagSet := flag.CommandLine + // Create a new FlagSet for testing purposes. + flag.CommandLine = flag.NewFlagSet("", flag.ContinueOnError) + + // Defer a function to reset CommandLine after the test. + defer func() { + flag.CommandLine = originalFlagSet + }() + + flags := EnclaveFlags() + err := tenflag.CreateCLIFlags(flags) + require.NoError(t, err) + + // Set the flags as needed for the test. + err = flag.CommandLine.Set(HostIDFlag, "string-value") + require.NoError(t, err) + + err = flag.CommandLine.Set(NodeTypeFlag, "sequencer") + require.NoError(t, err) + + err = flag.CommandLine.Set(WillAttestFlag, "true") + require.NoError(t, err) + + err = flag.CommandLine.Set(LogLevelFlag, "123") + require.NoError(t, err) + + err = flag.CommandLine.Set(MinGasPriceFlag, "3333") + require.NoError(t, err) + + err = flag.CommandLine.Set(L2GasLimitFlag, "222222") + require.NoError(t, err) + + flag.Parse() + + require.Equal(t, "string-value", flags[HostIDFlag].String()) + require.Equal(t, true, flags[WillAttestFlag].Bool()) + require.Equal(t, 123, flags[LogLevelFlag].Int()) + require.Equal(t, int64(3333), flags[MinGasPriceFlag].Int64()) + require.Equal(t, uint64(222222), flags[L2GasLimitFlag].Uint64()) + + _, err = newConfig(flags) + require.NoError(t, err) +} diff --git a/go/enclave/container/cli.go b/go/enclave/container/cli.go deleted file mode 100644 index c440b95de2..0000000000 --- a/go/enclave/container/cli.go +++ /dev/null @@ -1,115 +0,0 @@ -package container - -import ( - "fmt" - "math/big" - - "github.com/ten-protocol/go-ten/go/common" - "github.com/ten-protocol/go-ten/go/common/flag" - "github.com/ten-protocol/go-ten/go/config" - - gethcommon "github.com/ethereum/go-ethereum/common" -) - -// EnclaveConfigToml is the structure that an enclave's .toml config is parsed into. -type EnclaveConfigToml struct { - HostID string - HostAddress string - Address string - NodeType string - L1ChainID int64 - ObscuroChainID int64 - WillAttest bool - ValidateL1Blocks bool - ManagementContractAddress string - LogLevel int - LogPath string - UseInMemoryDB bool - GenesisJSON string - EdgelessDBHost string - SqliteDBPath string - ProfilerEnabled bool - MinGasPrice int64 - MessageBusAddress string - SequencerID string - ObscuroGenesis string - DebugNamespaceEnabled bool - MaxBatchSize uint64 - MaxRollupSize uint64 - GasPaymentAddress string - BaseFee uint64 - GasLimit uint64 -} - -// ParseConfig returns a config.EnclaveConfig based on either the file identified by the `config` flag, or the flags -// with specific defaults (if the `config` flag isn't specified). -func ParseConfig() (*config.EnclaveConfig, error) { - cfg := config.DefaultEnclaveConfig() - flagUsageMap := getFlagUsageMap() - - hostID := flag.String(hostIDName, cfg.HostID.Hex(), flagUsageMap[hostIDName]) - hostAddress := flag.String(hostAddressName, cfg.HostAddress, flagUsageMap[hostAddressName]) - address := flag.String(addressName, cfg.Address, flagUsageMap[addressName]) - nodeTypeStr := flag.String(nodeTypeName, cfg.NodeType.String(), flagUsageMap[nodeTypeName]) - willAttest := flag.Bool(willAttestName, cfg.WillAttest, flagUsageMap[willAttestName]) - validateL1Blocks := flag.Bool(validateL1BlocksName, cfg.ValidateL1Blocks, flagUsageMap[validateL1BlocksName]) - managementContractAddress := flag.String(ManagementContractAddressName, cfg.ManagementContractAddress.Hex(), flagUsageMap[ManagementContractAddressName]) - loglevel := flag.Int(logLevelName, cfg.LogLevel, flagUsageMap[logLevelName]) - logPath := flag.String(logPathName, cfg.LogPath, flagUsageMap[logPathName]) - edgelessDBHost := flag.String(edgelessDBHostName, cfg.EdgelessDBHost, flagUsageMap[edgelessDBHostName]) - sqliteDBPath := flag.String(sqliteDBPathName, cfg.SqliteDBPath, flagUsageMap[sqliteDBPathName]) - minGasPrice := flag.Int64(minGasPriceName, cfg.MinGasPrice.Int64(), flagUsageMap[minGasPriceName]) - messageBusAddress := flag.String(messageBusAddressName, cfg.MessageBusAddress.Hex(), flagUsageMap[messageBusAddressName]) - sequencerID := flag.String(sequencerIDName, cfg.SequencerID.Hex(), flagUsageMap[sequencerIDName]) - maxBatchSize := flag.Uint64(maxBatchSizeName, cfg.MaxBatchSize, flagUsageMap[maxBatchSizeName]) - maxRollupSize := flag.Uint64(maxRollupSizeName, cfg.MaxRollupSize, flagUsageMap[maxRollupSizeName]) - baseFee := flag.Uint64("l2BaseFee", cfg.BaseFee.Uint64(), "") - coinbaseAddress := flag.String("l2Coinbase", cfg.GasPaymentAddress.Hex(), "") - gasLimit := flag.Uint64("l2GasLimit", cfg.GasLimit.Uint64(), "") - - // set of restricted flags that can only be set in the signed enclave.json - obscuroGenesis := flag.RestrictedString(obscuroGenesisName, cfg.ObscuroGenesis, flagUsageMap[obscuroGenesisName]) - l1ChainID := flag.RestrictedInt64(l1ChainIDName, cfg.L1ChainID, flagUsageMap[l1ChainIDName]) - obscuroChainID := flag.RestrictedInt64(obscuroChainIDName, cfg.ObscuroChainID, flagUsageMap[obscuroChainIDName]) - useInMemoryDB := flag.RestrictedBool(useInMemoryDBName, cfg.UseInMemoryDB, flagUsageMap[useInMemoryDBName]) - profilerEnabled := flag.RestrictedBool(profilerEnabledName, cfg.ProfilerEnabled, flagUsageMap[profilerEnabledName]) - debugNamespaceEnabled := flag.RestrictedBool(debugNamespaceEnabledName, cfg.DebugNamespaceEnabled, flagUsageMap[debugNamespaceEnabledName]) - - err := flag.Parse() - if err != nil { - return nil, err - } - - nodeType, err := common.ToNodeType(*nodeTypeStr) - if err != nil { - return nil, fmt.Errorf("unrecognised node type '%s'", *nodeTypeStr) - } - - cfg.HostID = gethcommon.HexToAddress(*hostID) - cfg.HostAddress = *hostAddress - cfg.Address = *address - cfg.NodeType = nodeType - cfg.L1ChainID = l1ChainID.GetInt64() - cfg.ObscuroChainID = obscuroChainID.GetInt64() - cfg.WillAttest = *willAttest - cfg.ValidateL1Blocks = *validateL1Blocks - cfg.ManagementContractAddress = gethcommon.HexToAddress(*managementContractAddress) - cfg.LogLevel = *loglevel - cfg.LogPath = *logPath - cfg.UseInMemoryDB = useInMemoryDB.GetBool() - cfg.EdgelessDBHost = *edgelessDBHost - cfg.SqliteDBPath = *sqliteDBPath - cfg.ProfilerEnabled = profilerEnabled.GetBool() - cfg.MinGasPrice = big.NewInt(*minGasPrice) - cfg.MessageBusAddress = gethcommon.HexToAddress(*messageBusAddress) - cfg.SequencerID = gethcommon.HexToAddress(*sequencerID) - cfg.ObscuroGenesis = obscuroGenesis.GetString() - cfg.DebugNamespaceEnabled = debugNamespaceEnabled.GetBool() - cfg.MaxBatchSize = *maxBatchSize - cfg.MaxRollupSize = *maxRollupSize - cfg.BaseFee = big.NewInt(0).SetUint64(*baseFee) - cfg.GasPaymentAddress = gethcommon.HexToAddress(*coinbaseAddress) - cfg.GasLimit = big.NewInt(0).SetUint64(*gasLimit) - - return cfg, nil -} diff --git a/go/enclave/container/cli_flags.go b/go/enclave/container/cli_flags.go deleted file mode 100644 index 7bf4eb6bbf..0000000000 --- a/go/enclave/container/cli_flags.go +++ /dev/null @@ -1,58 +0,0 @@ -package container - -// Flag names. -const ( - configName = "config" - hostIDName = "hostID" - hostAddressName = "hostAddress" - addressName = "address" - nodeTypeName = "nodeType" - l1ChainIDName = "l1ChainID" - obscuroChainIDName = "obscuroChainID" - willAttestName = "willAttest" - validateL1BlocksName = "validateL1Blocks" - ManagementContractAddressName = "managementContractAddress" - logLevelName = "logLevel" - logPathName = "logPath" - useInMemoryDBName = "useInMemoryDB" - edgelessDBHostName = "edgelessDBHost" - sqliteDBPathName = "sqliteDBPath" - profilerEnabledName = "profilerEnabled" - minGasPriceName = "minGasPrice" - messageBusAddressName = "messageBusAddress" - sequencerIDName = "sequencerID" - obscuroGenesisName = "obscuroGenesis" - debugNamespaceEnabledName = "debugNamespaceEnabled" - maxBatchSizeName = "maxBatchSize" - maxRollupSizeName = "maxRollupSize" -) - -// Returns a map of the flag usages. -// While we could just use constants instead of a map, this approach allows us to test that all the expected flags are defined. -func getFlagUsageMap() map[string]string { - return map[string]string{ - configName: "The path to the node's config file. Overrides all other flags", - hostIDName: "The 20 bytes of the address of the Obscuro host this enclave serves", - hostAddressName: "The peer-to-peer IP address of the Obscuro host this enclave serves", - addressName: "The address on which to serve the Obscuro enclave service", - nodeTypeName: "The node's type (e.g. sequencer, validator)", - l1ChainIDName: "An integer representing the unique chain id of the Ethereum chain used as an L1 (default 1337)", - obscuroChainIDName: "An integer representing the unique chain id of the Obscuro chain (default 443)", - willAttestName: "Whether the enclave will produce a verified attestation report", - validateL1BlocksName: "Whether to validate incoming blocks using the hardcoded L1 genesis.json config", - ManagementContractAddressName: "The management contract address on the L1", - logLevelName: "The verbosity level of logs. (Defaults to Info)", - logPathName: "The path to use for the enclave service's log file", - useInMemoryDBName: "Whether the enclave will use an in-memory DB rather than persist data", - edgelessDBHostName: "Host address for the edgeless DB instance (can be empty if useInMemoryDB is true or if not using attestation", - sqliteDBPathName: "Filepath for the sqlite DB persistence file (can be empty if a throwaway file in /tmp/ is acceptable or if using InMemory DB or if using attestation/EdgelessDB)", - profilerEnabledName: "Runs a profiler instance (Defaults to false)", - minGasPriceName: "The minimum gas price for mining a transaction", - messageBusAddressName: "The address of the L1 message bus contract owned by the management contract.", - sequencerIDName: "The 20 bytes of the address of the sequencer for this network", - obscuroGenesisName: "The json string with the obscuro genesis", - debugNamespaceEnabledName: "Whether the debug namespace is enabled", - maxBatchSizeName: "The maximum size a batch is allowed to reach uncompressed", - maxRollupSizeName: "The maximum size a rollup is allowed to reach", - } -} diff --git a/go/enclave/main/enclave-test.json b/go/enclave/main/enclave-test.json index 3e2c3bbcf7..260daa7724 100644 --- a/go/enclave/main/enclave-test.json +++ b/go/enclave/main/enclave-test.json @@ -16,8 +16,8 @@ ], "env": [ { - "name": "RESTRICTED", - "value": "false" + "name": "EDG_TESTMODE", + "value": "true" } ] } \ No newline at end of file diff --git a/go/enclave/main/enclave.json b/go/enclave/main/enclave.json index 3db18c70d4..7919f528d9 100644 --- a/go/enclave/main/enclave.json +++ b/go/enclave/main/enclave.json @@ -16,8 +16,8 @@ ], "env": [ { - "name": "RESTRICTED", - "value": "true" + "name": "TESTMODE", + "value": "false" }, { "name": "L1CHAINID", diff --git a/go/enclave/main/main.go b/go/enclave/main/main.go index d03967a258..7d062a1b06 100644 --- a/go/enclave/main/main.go +++ b/go/enclave/main/main.go @@ -4,16 +4,27 @@ import ( "fmt" "github.com/ten-protocol/go-ten/go/common/container" + tenflag "github.com/ten-protocol/go-ten/go/common/flag" + "github.com/ten-protocol/go-ten/go/config" enclavecontainer "github.com/ten-protocol/go-ten/go/enclave/container" ) // Runs an Obscuro enclave as a standalone process. func main() { - config, err := enclavecontainer.ParseConfig() + // fetch and parse flags + flags := config.EnclaveFlags() + err := tenflag.CreateCLIFlags(flags) if err != nil { panic(fmt.Errorf("could not parse config. Cause: %w", err)) } - enclaveContainer := enclavecontainer.NewEnclaveContainerFromConfig(config) + tenflag.Parse() + + enclaveConfig, err := config.FromFlags(flags) + if err != nil { + panic(fmt.Errorf("unable to create config from flags - %w", err)) + } + + enclaveContainer := enclavecontainer.NewEnclaveContainerFromConfig(enclaveConfig) container.Serve(enclaveContainer) }