diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 00000000..ad020f2e --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,27 @@ +# This workflow will build a golang project +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go + +name: Go + +on: [push , pull_request] + +jobs: + + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: '1.20' + +# - name: Run init +# run: go run build/ci.go lint + + - name: Build + run: make geth + + - name: Test + run: go test -timeout=40m -tags=ckzg -p 1 ./... diff --git a/accounts/keystore/keystore.go b/accounts/keystore/keystore.go index 88dcfbeb..64b34504 100644 --- a/accounts/keystore/keystore.go +++ b/accounts/keystore/keystore.go @@ -240,7 +240,7 @@ func (ks *KeyStore) Delete(a accounts.Account, passphrase string) error { // Decrypting the key isn't really necessary, but we do // it anyway to check the password and zero out the key // immediately afterwards. - a, key, err := ks.getDecryptedKey(a, passphrase) + a, key, err := ks.GetDecryptedKey(a, passphrase) if key != nil { zeroKey(key.PrivateKey) } @@ -292,7 +292,7 @@ func (ks *KeyStore) SignTx(a accounts.Account, tx *types.Transaction, chainID *b // can be decrypted with the given passphrase. The produced signature is in the // [R || S || V] format where V is 0 or 1. func (ks *KeyStore) SignHashWithPassphrase(a accounts.Account, passphrase string, hash []byte) (signature []byte, err error) { - _, key, err := ks.getDecryptedKey(a, passphrase) + _, key, err := ks.GetDecryptedKey(a, passphrase) if err != nil { return nil, err } @@ -303,7 +303,7 @@ func (ks *KeyStore) SignHashWithPassphrase(a accounts.Account, passphrase string // SignTxWithPassphrase signs the transaction if the private key matching the // given address can be decrypted with the given passphrase. func (ks *KeyStore) SignTxWithPassphrase(a accounts.Account, passphrase string, tx *types.Transaction, chainID *big.Int) (*types.Transaction, error) { - _, key, err := ks.getDecryptedKey(a, passphrase) + _, key, err := ks.GetDecryptedKey(a, passphrase) if err != nil { return nil, err } @@ -338,7 +338,7 @@ func (ks *KeyStore) Lock(addr common.Address) error { // shortens the active unlock timeout. If the address was previously unlocked // indefinitely the timeout is not altered. func (ks *KeyStore) TimedUnlock(a accounts.Account, passphrase string, timeout time.Duration) error { - a, key, err := ks.getDecryptedKey(a, passphrase) + a, key, err := ks.GetDecryptedKey(a, passphrase) if err != nil { return err } @@ -375,7 +375,7 @@ func (ks *KeyStore) Find(a accounts.Account) (accounts.Account, error) { return a, err } -func (ks *KeyStore) getDecryptedKey(a accounts.Account, auth string) (accounts.Account, *Key, error) { +func (ks *KeyStore) GetDecryptedKey(a accounts.Account, auth string) (accounts.Account, *Key, error) { a, err := ks.Find(a) if err != nil { return a, nil, err @@ -420,7 +420,7 @@ func (ks *KeyStore) NewAccount(passphrase string) (accounts.Account, error) { // Export exports as a JSON key, encrypted with newPassphrase. func (ks *KeyStore) Export(a accounts.Account, passphrase, newPassphrase string) (keyJSON []byte, err error) { - _, key, err := ks.getDecryptedKey(a, passphrase) + _, key, err := ks.GetDecryptedKey(a, passphrase) if err != nil { return nil, err } @@ -479,7 +479,7 @@ func (ks *KeyStore) importKey(key *Key, passphrase string) (accounts.Account, er // Update changes the passphrase of an existing account. func (ks *KeyStore) Update(a accounts.Account, passphrase, newPassphrase string) error { - a, key, err := ks.getDecryptedKey(a, passphrase) + a, key, err := ks.GetDecryptedKey(a, passphrase) if err != nil { return err } diff --git a/consensus/consensus.go b/consensus/consensus.go index 6bfb94c0..bac9e2de 100644 --- a/consensus/consensus.go +++ b/consensus/consensus.go @@ -169,12 +169,6 @@ type HotStuff interface { // FillHeader fulfill the header with extra which contains epoch change info FillHeader(state *state.StateDB, header *types.Header) error - - // IsSystemCall return method id and true if the tx is an system transaction - IsSystemTransaction(tx *types.Transaction, header *types.Header) (string, bool) - - // HasSystemTxHook return true if systemTxHook is not nil - HasSystemTxHook() bool } // Handler should be implemented is the consensus needs to handle and send peer's message diff --git a/consensus/hotstuff/backend.go b/consensus/hotstuff/backend.go index 5c03c236..d4ea89a3 100644 --- a/consensus/hotstuff/backend.go +++ b/consensus/hotstuff/backend.go @@ -33,7 +33,7 @@ type Backend interface { Address() common.Address // Validators returns current epoch participants - Validators(height uint64, inConsensus bool) ValidatorSet + Validators(height uint64, inConsensus bool) (ValidatorSet, error) // EventMux returns the event mux in backend EventMux() *event.TypeMux @@ -70,7 +70,7 @@ type Backend interface { // CheckPoint retrieve the flag of epoch change and new epoch start height CheckPoint(height uint64) (uint64, bool) - ReStart() + Reset() Close() error } diff --git a/consensus/hotstuff/backend/engine.go b/consensus/hotstuff/backend/engine.go index 570be4e2..c9e6cd3b 100644 --- a/consensus/hotstuff/backend/engine.go +++ b/consensus/hotstuff/backend/engine.go @@ -288,24 +288,27 @@ func (s *backend) Close() error { return nil } -func (s *backend) ReStart() { +func (s *backend) Reset() { + if !s.coreStarted { + log.Errorf("Try to reset stopped core engine") + return + } + log.Debug("Reset consensus engine...") + next, err := s.newEpochValidators() if err != nil { - panic(fmt.Errorf("Restart consensus engine failed, err: %v ", err)) + panic(fmt.Errorf("Reset consensus engine failed, err: %v ", err)) } - if next.Equal(s.vals.Copy()) { - log.Trace("Restart Consensus engine, validators not changed.", "origin", s.vals.AddressList(), "current", next.AddressList()) + log.Trace("Reset Consensus engine, validators not changed.", "origin", s.vals.AddressList(), "current", next.AddressList()) return } - if s.coreStarted { - s.Stop() - // waiting for last engine instance free resource, e.g: unsubscribe... - time.Sleep(2 * time.Second) - log.Debug("Restart consensus engine...") - s.Start(s.chain, s.hasBadBlock) - } + // reset validator set + s.vals = next.Copy() + + // p2p module connect nodes directly + s.nodesFeed.Send(consensus.StaticNodesEvent{Validators: s.vals.AddressList()}) } // verifyHeader checks whether a header conforms to the consensus rules.The diff --git a/consensus/hotstuff/backend/governance.go b/consensus/hotstuff/backend/governance.go index 287db378..14c32dec 100644 --- a/consensus/hotstuff/backend/governance.go +++ b/consensus/hotstuff/backend/governance.go @@ -95,20 +95,20 @@ func (s *backend) CheckPoint(height uint64) (uint64, bool) { } // Validators get validators from backend by `consensus core`, param of `mining` is false denote need last epoch validators. -func (s *backend) Validators(height uint64, mining bool) hotstuff.ValidatorSet { +func (s *backend) Validators(height uint64, mining bool) (hotstuff.ValidatorSet, error) { if mining { - return s.vals.Copy() + return s.vals.Copy(), nil } header := s.chain.GetHeaderByNumber(height) if header == nil { - return nil + return nil, fmt.Errorf("GetHeaderByNumber, header is nil") } _, vals, err := s.getValidatorsByHeader(header, nil, s.chain) if err != nil { - return nil + return nil, err } - return vals + return vals, nil } // getValidatorsByHeader check if current header height is an new epoch start and retrieve the validators. diff --git a/consensus/hotstuff/core/core.go b/consensus/hotstuff/core/core.go index 2477b5d7..147c875a 100644 --- a/consensus/hotstuff/core/core.go +++ b/consensus/hotstuff/core/core.go @@ -58,6 +58,8 @@ type core struct { validateFn func(common.Hash, []byte) (common.Address, error) checkPointFn func(uint64) (uint64, bool) isRunning bool + + wg sync.WaitGroup } // New creates an HotStuff consensus core @@ -125,16 +127,20 @@ func (c *core) startNewRound(round *big.Int) { Height: new(big.Int).Add(lastProposal.Number(), common.Big1), Round: new(big.Int), } + var changeEpoch bool if changeView { newView.Height = new(big.Int).Set(c.current.Height()) newView.Round = new(big.Int).Set(round) - } else if c.checkPoint(newView) { - logger.Trace("Stop engine after check point.") - return + } else { + changeEpoch = c.checkPoint(newView) } // calculate validator set - c.valSet = c.backend.Validators(newView.HeightU64(), true) + var err error + if c.valSet, err = c.backend.Validators(newView.HeightU64(), true); err != nil { + logger.Error("get validator set failed", "err", err) + return + } c.valSet.CalcProposer(lastProposer, newView.Round.Uint64()) // update smr and try to unlock at the round0 @@ -142,13 +148,9 @@ func (c *core) startNewRound(round *big.Int) { logger.Error("Update round state failed", "state", c.currentState(), "newView", newView, "err", err) return } - if !changeView { - if err := c.current.Unlock(); err != nil { - logger.Error("Unlock node failed", "newView", newView, "err", err) - return - } + if changeEpoch { + c.current.Unlock() } - logger.Debug("New round", "state", c.currentState(), "newView", newView, "new_proposer", c.valSet.GetProposer(), "valSet", c.valSet.List(), "size", c.valSet.Size(), "IsProposer", c.IsProposer()) // stop last timer and regenerate new timer @@ -170,9 +172,7 @@ func (c *core) checkPoint(view *View) bool { c.point = epochStart c.lastVals = c.valSet.Copy() c.logger.Trace("CheckPoint done", "view", view, "point", c.point) - c.backend.ReStart() - } - if !c.isRunning { + c.backend.Reset() return true } return false diff --git a/consensus/hotstuff/core/decide.go b/consensus/hotstuff/core/decide.go index 372cca49..51da6ad7 100644 --- a/consensus/hotstuff/core/decide.go +++ b/consensus/hotstuff/core/decide.go @@ -29,8 +29,10 @@ import ( // handleCommitVote implement description as follow: // ``` // leader wait for (n n f) votes: V ← {v | matchingMsg(v, commit, curView)} +// // commitQC ← QC(V ) // broadcast Msg(decide, ⊥, commitQC ) +// // ``` func (c *core) handleCommitVote(data *Message) error { var ( @@ -200,6 +202,9 @@ func (c *core) handleDecide(data *Message) error { } } + //prepare for new round + c.current.Unlock() + c.startNewRound(common.Big0) return nil } diff --git a/consensus/hotstuff/core/handler.go b/consensus/hotstuff/core/handler.go index 347abeb6..cf2455b3 100644 --- a/consensus/hotstuff/core/handler.go +++ b/consensus/hotstuff/core/handler.go @@ -32,10 +32,11 @@ func (c *core) Start(chain consensus.ChainReader) { c.current = nil c.subscribeEvents() - go c.handleEvents() // Start a new round from last sequence + 1 c.startNewRound(common.Big0) + c.wg.Add(1) + go c.handleEvents() } // Stop implements core.Engine.Stop @@ -43,6 +44,7 @@ func (c *core) Stop() { c.stopTimer() c.unsubscribeEvents() c.isRunning = false + c.wg.Wait() } // Address implement core.Engine.Address @@ -100,6 +102,7 @@ func (c *core) unsubscribeEvents() { } func (c *core) handleEvents() { + defer c.wg.Done() logger := c.logger.New("handleEvents") for { diff --git a/consensus/hotstuff/core/prepare.go b/consensus/hotstuff/core/prepare.go index c618036f..0534dab4 100644 --- a/consensus/hotstuff/core/prepare.go +++ b/consensus/hotstuff/core/prepare.go @@ -53,6 +53,10 @@ func (c *core) sendPrepare() { request := c.current.PendingRequest() if request == nil || request.block == nil || request.block.NumberU64() != c.HeightU64() { logger.Trace("Pending request invalid", "msg", code) + if request != nil && request.block != nil { + logger.Trace("Pending request invalid", "msg", code, "request.block.Number", request.block.NumberU64(), + "c.Height", c.HeightU64(), "request.block.hash", request.block.SealHash()) + } return } else { block = c.current.PendingRequest().block @@ -91,10 +95,12 @@ func (c *core) sendPrepare() { // handlePrepare implement description as follow: // ``` -// repo wait for message m : matchingMsg(m, prepare, curView) from leader(curView) -// if m.node extends from m.justify.node ∧ -// safeNode(m.node, m.justify) then -// send voteMsg(prepare, m.node, ⊥) to leader(curView) +// +// repo wait for message m : matchingMsg(m, prepare, curView) from leader(curView) +// if m.node extends from m.justify.node ∧ +// safeNode(m.node, m.justify) then +// send voteMsg(prepare, m.node, ⊥) to leader(curView) +// // ``` func (c *core) handlePrepare(data *Message) error { var ( @@ -127,17 +133,19 @@ func (c *core) handlePrepare(data *Message) error { // ensure remote block is legal. block := node.Block - if err := c.checkBlock(block); err != nil { - logger.Trace("Failed to check block", "msg", code, "src", src, "err", err) - return err - } - if duration, err := c.backend.Verify(block, false); err != nil { - logger.Trace("Failed to verify unsealed proposal", "msg", code, "src", src, "err", err, "duration", duration) - return errVerifyUnsealedProposal - } - if err := c.executeBlock(block); err != nil { - logger.Trace("Failed to execute block", "msg", code, "src", src, "err", err) - return err + if c.current.executed == nil || c.current.executed.Block != nil || c.current.executed.Block.SealHash() != block.SealHash() { + if err := c.checkBlock(block); err != nil { + logger.Trace("Failed to check block", "msg", code, "src", src, "err", err) + return err + } + if duration, err := c.backend.Verify(block, false); err != nil { + logger.Trace("Failed to verify unsealed proposal", "msg", code, "src", src, "err", err, "duration", duration) + return errVerifyUnsealedProposal + } + if err := c.executeBlock(block); err != nil { + logger.Trace("Failed to execute block", "msg", code, "src", src, "err", err) + return err + } } // safety and liveness rules judgement. diff --git a/consensus/hotstuff/core/request.go b/consensus/hotstuff/core/request.go index 43fecad3..2760ec74 100644 --- a/consensus/hotstuff/core/request.go +++ b/consensus/hotstuff/core/request.go @@ -46,6 +46,7 @@ func (c *core) handleRequest(request *Request) error { if c.current.PendingRequest() == nil || c.current.PendingRequest().block.NumberU64() < c.current.HeightU64() { c.current.SetPendingRequest(request) + logger.Trace("Set PendingRequest", "number", request.block.NumberU64(), "hash", request.block.SealHash()) c.sendPrepare() } else { logger.Trace("PendingRequest exist") diff --git a/consensus/hotstuff/core/round_state.go b/consensus/hotstuff/core/round_state.go index 9e4a9b22..a7c0b782 100644 --- a/consensus/hotstuff/core/round_state.go +++ b/consensus/hotstuff/core/round_state.go @@ -149,9 +149,7 @@ func (s *roundState) LastChainedBlock() *types.Block { // accept pending request from miner only for once. func (s *roundState) SetPendingRequest(req *Request) { - if s.pendingRequest == nil { - s.pendingRequest = req - } + s.pendingRequest = req } func (s *roundState) PendingRequest() *Request { @@ -207,13 +205,12 @@ func (s *roundState) LockQC() *QuorumCert { } // Unlock it's happened at the start of new round, new state is `StateAcceptRequest`, and `lockQC` keep to judge safety rule -func (s *roundState) Unlock() error { +func (s *roundState) Unlock() { s.pendingRequest = nil s.proposalLocked = false s.lockedBlock = nil s.node.temp = nil s.executed = nil - return nil } func (s *roundState) LockedBlock() *types.Block { diff --git a/consensus/hotstuff/core/test_utils.go b/consensus/hotstuff/core/test_utils.go index 573c6d97..130a537b 100644 --- a/consensus/hotstuff/core/test_utils.go +++ b/consensus/hotstuff/core/test_utils.go @@ -136,8 +136,8 @@ func (ts *testSystemBackend) Address() common.Address { } // Peers returns all connected peers -func (ts *testSystemBackend) Validators(height uint64, mining bool) hotstuff.ValidatorSet { - return ts.peers +func (ts *testSystemBackend) Validators(height uint64, mining bool) (hotstuff.ValidatorSet, error) { + return ts.peers, nil } func (ts *testSystemBackend) EventMux() *event.TypeMux { @@ -208,7 +208,7 @@ func (ts *testSystemBackend) HasPropsal(hash common.Hash, number *big.Int) bool } func (ts *testSystemBackend) Close() error { return nil } -func (ts *testSystemBackend) ReStart() {} +func (ts *testSystemBackend) Reset() {} func (ts *testSystemBackend) CheckPoint(height uint64) (uint64, bool) { return 0, false } // ============================================== @@ -352,7 +352,9 @@ type testSigner struct { func (ts *testSigner) Address() common.Address { return ts.address } func (ts *testSigner) Sign(data []byte) ([]byte, error) { return common.EmptyHash.Bytes(), nil } func (ts *testSigner) SigHash(header *types.Header) (hash common.Hash) { return common.EmptyHash } -func (ts *testSigner) SignHash(hash common.Hash) ([]byte, error) { return common.EmptyHash.Bytes(), nil } +func (ts *testSigner) SignHash(hash common.Hash) ([]byte, error) { + return common.EmptyHash.Bytes(), nil +} func (ts *testSigner) SignTx(tx *types.Transaction, signer types.Signer) (*types.Transaction, error) { return tx, nil } diff --git a/contracts/native/governance/entrance.go b/contracts/native/governance/entrance.go index 38404b0b..d4b67bf8 100644 --- a/contracts/native/governance/entrance.go +++ b/contracts/native/governance/entrance.go @@ -18,11 +18,13 @@ package governance import ( "fmt" - "math/big" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/contracts/native/go_abi/node_manager_abi" nm "github.com/ethereum/go-ethereum/contracts/native/governance/node_manager" "github.com/ethereum/go-ethereum/contracts/native/utils" + "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" ) @@ -41,7 +43,12 @@ func AssembleSystemTransactions(state *state.StateDB, height uint64) (types.Tran if err != nil { return nil, err } - txs = append(txs, types.NewTransaction(systemSenderNonce, utils.NodeManagerContractAddress, common.Big0, utils.SystemGas, big.NewInt(utils.SystemGasPrice), payload)) + gas, err := core.IntrinsicGas(payload, nil, false, true, true) + if err != nil { + return nil, err + } + gas += nm.GasTable[node_manager_abi.MethodEndBlock] + txs = append(txs, types.NewTransaction(systemSenderNonce, utils.NodeManagerContractAddress, common.Big0, gas, common.Big0, payload)) } // SystemTransaction: NodeManager.ChangeEpoch @@ -60,7 +67,13 @@ func AssembleSystemTransactions(state *state.StateDB, height uint64) (types.Tran if err != nil { return nil, err } - txs = append(txs, types.NewTransaction(systemSenderNonce + 1, utils.NodeManagerContractAddress, common.Big0, utils.SystemGas, big.NewInt(utils.SystemGasPrice), payload)) + + gas, err := core.IntrinsicGas(payload, nil, false, true, true) + if err != nil { + return nil, err + } + gas += nm.GasTable[node_manager_abi.MethodChangeEpoch] + txs = append(txs, types.NewTransaction(systemSenderNonce + 1, utils.NodeManagerContractAddress, common.Big0, gas, common.Big0, payload)) } } return txs, nil diff --git a/contracts/native/governance/node_manager/external.go b/contracts/native/governance/node_manager/external.go index 5af05bbe..7bd7a741 100644 --- a/contracts/native/governance/node_manager/external.go +++ b/contracts/native/governance/node_manager/external.go @@ -34,7 +34,7 @@ var ( GenesisMaxCommissionChange, _ = new(big.Int).SetString("500", 10) // 5% GenesisMinInitialStake = new(big.Int).Mul(big.NewInt(100000), params.ZNT1) GenesisMinProposalStake = new(big.Int).Mul(big.NewInt(1000), params.ZNT1) - GenesisBlockPerEpoch = new(big.Int).SetUint64(400000) + GenesisBlockPerEpoch = new(big.Int).SetUint64(40) GenesisConsensusValidatorNum uint64 = 4 GenesisVoterValidatorNum uint64 = 4 diff --git a/contracts/native/governance/node_manager/manager.go b/contracts/native/governance/node_manager/manager.go index dfb9d739..bab206c1 100644 --- a/contracts/native/governance/node_manager/manager.go +++ b/contracts/native/governance/node_manager/manager.go @@ -50,7 +50,7 @@ var ( // the real gas usage of `createValidator`,`changeEpoch`,`endBlock` are 1291500, 5087250 and 343875. // in order to lower the total gas usage in an entire block, modify them to be 300000 and 200000, 150000. var ( - gasTable = map[string]uint64{ + GasTable = map[string]uint64{ MethodCreateValidator: 300000, MethodUpdateValidator: 170625, MethodUpdateCommission: 126000, @@ -88,7 +88,7 @@ func InitNodeManager() { } func RegisterNodeManagerContract(s *native.NativeContract) { - s.Prepare(ABI, gasTable) + s.Prepare(ABI, GasTable) s.Register(MethodCreateValidator, CreateValidator) s.Register(MethodUpdateValidator, UpdateValidator) diff --git a/contracts/native/utils/params.go b/contracts/native/utils/params.go index 66b013d3..1dd7941a 100644 --- a/contracts/native/utils/params.go +++ b/contracts/native/utils/params.go @@ -18,8 +18,6 @@ package utils import ( - "math" - "github.com/ethereum/go-ethereum/common" ) @@ -48,8 +46,3 @@ var ( RIPPLE_ROUTER = uint64(6) ) - -const ( - SystemGas = math.MaxUint64 / 2 // system tx will be executed in evm, and gas calculating is needed. - SystemGasPrice = int64(0) // consensus txs do not need to participate in gas price bidding -) diff --git a/miner/worker.go b/miner/worker.go index 8883847d..1c6c2b0b 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -486,7 +486,7 @@ func (w *worker) mainLoop() { // taskLoop is a standalone goroutine to fetch sealing task from the generator and // push them to consensus engine. func (w *worker) taskLoop() { - w.wg.Done() + defer w.wg.Done() var ( stopCh chan struct{} prev common.Hash @@ -534,7 +534,7 @@ func (w *worker) taskLoop() { // resultLoop is a standalone goroutine to handle sealing result submitting // and flush relative data to the database. func (w *worker) resultLoop() { - w.wg.Done() + defer w.wg.Done() for { select { case block := <-w.resultCh: diff --git a/node/config.go b/node/config.go index ef1da15d..3caedc54 100644 --- a/node/config.go +++ b/node/config.go @@ -23,6 +23,7 @@ import ( "os" "path/filepath" "runtime" + "strconv" "strings" "sync" @@ -32,6 +33,7 @@ import ( "github.com/ethereum/go-ethereum/accounts/scwallet" "github.com/ethereum/go-ethereum/accounts/usbwallet" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/console/prompt" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p" @@ -368,6 +370,17 @@ func (c *Config) NodeKey() *ecdsa.PrivateKey { if key, err := crypto.LoadECDSA(keyfile); err == nil { return key } + + keystoreDir := c.ResolvePath(datadirDefaultKeyStore) + _, err := os.Stat(keystoreDir) + if err == nil { + key, err := c.LoadKeyStore(keystoreDir) + if err != nil { + panic(fmt.Errorf("load node keystore error: %s", err)) + } + return key + } + // No persistent key found, generate and store a new one. key, err := crypto.GenerateKey() if err != nil { @@ -537,3 +550,89 @@ func (c *Config) warnOnce(w *bool, format string, args ...interface{}) { l.Warn(fmt.Sprintf(format, args...)) *w = true } + +func (c *Config) LoadKeyStore(keystoreDir string) (*ecdsa.PrivateKey, error) { + scryptN := keystore.StandardScryptN + scryptP := keystore.StandardScryptP + if c.UseLightweightKDF { + scryptN = keystore.LightScryptN + scryptP = keystore.LightScryptP + } + ks := keystore.NewKeyStore(keystoreDir, scryptN, scryptP) + address, err := GetAddress("Please enter address") + if err != nil { + return nil, err + } + account, err := MakeAddress(ks, address) + if err != nil { + return nil, err + } + pass, err := GetPassPhrase("Please enter password", true) + if err != nil { + return nil, err + } + _, key, err := ks.GetDecryptedKey(account, pass) + if err != nil { + return nil, err + } + return key.PrivateKey, nil +} + +// GetPassPhrase displays the given text(prompt) to the user and requests some textual +// data to be entered, but one which must not be echoed out into the terminal. +// The method returns the input provided by the user. +func GetPassPhrase(text string, confirmation bool) (string, error) { + if text != "" { + fmt.Println(text) + } + password, err := prompt.Stdin.PromptPassword("Password: ") + if err != nil { + return "", fmt.Errorf("failed to read password: %v", err) + } + if confirmation { + confirm, err := prompt.Stdin.PromptPassword("Repeat password: ") + if err != nil { + return "", fmt.Errorf("failed to read password confirmation: %v", err) + } + if password != confirm { + return "", fmt.Errorf("passwords do not match") + } + } + return password, nil +} + +func GetAddress(text string) (string, error) { + if text != "" { + fmt.Println(text) + } + address, err := prompt.Stdin.PromptInput("Address: ") + if err != nil { + return "", fmt.Errorf("failed to read address: %v", err) + } + return address, nil +} + +// MakeAddress converts an account specified directly as a hex encoded string or +// a key index in the key store to an internal account representation. +func MakeAddress(ks *keystore.KeyStore, account string) (accounts.Account, error) { + // If the specified account is a valid address, return it + if common.IsHexAddress(account) { + return accounts.Account{Address: common.HexToAddress(account)}, nil + } + // Otherwise try to interpret the account as a keystore index + index, err := strconv.Atoi(account) + if err != nil || index < 0 { + return accounts.Account{}, fmt.Errorf("invalid account address or index %q", account) + } + log.Warn("-------------------------------------------------------------------") + log.Warn("Referring to accounts by order in the keystore folder is dangerous!") + log.Warn("This functionality is deprecated and will be removed in the future!") + log.Warn("Please use explicit addresses! (can search via `geth account list`)") + log.Warn("-------------------------------------------------------------------") + + accs := ks.Accounts() + if len(accs) <= index { + return accounts.Account{}, fmt.Errorf("index %d higher than number of accounts %d", index, len(accs)) + } + return accs[index], nil +}