diff --git a/signer/Cosigner.go b/signer/Cosigner.go index 15ad514f..ea937cbe 100644 --- a/signer/Cosigner.go +++ b/signer/Cosigner.go @@ -75,8 +75,4 @@ type Cosigner interface { // Sign the requested bytes SetEphemeralSecretPartsAndSign(req CosignerSetEphemeralSecretPartsAndSignRequest) (*CosignerSignResponse, error) - - // Request that the cosigner manage the threshold signing process for this block - // Will throw error if cosigner is not the leader - SignBlock(req CosignerSignBlockRequest) (CosignerSignBlockResponse, error) } diff --git a/signer/local_cosigner.go b/signer/local_cosigner.go index ea3343ec..8e8078c5 100644 --- a/signer/local_cosigner.go +++ b/signer/local_cosigner.go @@ -474,10 +474,6 @@ func (cosigner *LocalCosigner) setEphemeralSecretPart(req CosignerSetEphemeralSe return nil } -func (cosigner *LocalCosigner) SignBlock(req CosignerSignBlockRequest) (CosignerSignBlockResponse, error) { - return CosignerSignBlockResponse{}, errors.New("not implemented") -} - func (cosigner *LocalCosigner) SetEphemeralSecretPartsAndSign( req CosignerSetEphemeralSecretPartsAndSignRequest) (*CosignerSignResponse, error) { for _, secretPart := range req.EncryptedSecrets { diff --git a/signer/raft_events.go b/signer/raft_events.go index 4d69cac6..29d54280 100644 --- a/signer/raft_events.go +++ b/signer/raft_events.go @@ -2,7 +2,9 @@ package signer import ( "encoding/json" + "errors" "fmt" + "strings" ) const ( @@ -31,28 +33,39 @@ func (f *fsm) handleLSSEvent(value string) { _ = f.cosigner.SaveLastSignedState(*lss) } -func (s *RaftStore) GetLeaderCosigner() (Cosigner, error) { +func (s *RaftStore) getLeaderRPCAddress() (string, error) { leader := string(s.GetLeader()) + if leader == "" { + return "", errors.New("no current raft leader") + } + // If the same RPC port is used for all peers, we can just use the leader address on that port + if s.commonRPCPort != "" { + leaderSplit := strings.Split(leader, ":") + if len(leaderSplit) == 2 { + return fmt.Sprintf("tcp://%s:%s", leaderSplit[0], s.commonRPCPort), nil + } + } for _, peer := range s.Peers { if peer.GetRaftAddress() == leader { - return peer, nil + return peer.GetAddress(), nil } tcpAddress, err := GetTCPAddressForRaftAddress(peer.GetRaftAddress()) if err != nil { continue } if fmt.Sprint(tcpAddress) == leader { - return peer, nil + return peer.GetAddress(), nil } } - return nil, fmt.Errorf("unable to find leader cosigner from address %s", leader) + + return "", fmt.Errorf("unable to find leader cosigner from address %s", leader) } -func (s *RaftStore) LeaderSignBlock(req CosignerSignBlockRequest) (*CosignerSignBlockResponse, error) { - leaderCosigner, err := s.GetLeaderCosigner() +func (s *RaftStore) LeaderSignBlock(req CosignerSignBlockRequest) (res *CosignerSignBlockResponse, err error) { + leaderCosigner, err := s.getLeaderRPCAddress() if err != nil { return nil, err } - res, err := leaderCosigner.SignBlock(req) - return &res, err + + return res, CallRPC(leaderCosigner, "SignBlock", req, &res) } diff --git a/signer/raft_store.go b/signer/raft_store.go index 018a220c..96f862ff 100644 --- a/signer/raft_store.go +++ b/signer/raft_store.go @@ -13,6 +13,7 @@ import ( "io" "net" "os" + "strings" "sync" "time" @@ -49,6 +50,22 @@ type RaftStore struct { logger log.Logger cosigner *LocalCosigner thresholdValidator *ThresholdValidator + commonRPCPort string +} + +// OnStart starts the raft server +func getCommonRPCPort(peers []Cosigner) string { + var rpcPort string + for i, peer := range peers { + if i == 0 { + rpcPort = strings.Split(peer.GetAddress(), ":")[2] + continue + } + if strings.Split(peer.GetAddress(), ":")[2] != rpcPort { + return "" + } + } + return rpcPort } // New returns a new Store. @@ -56,14 +73,15 @@ func NewRaftStore( nodeID string, directory string, bindAddress string, timeout time.Duration, logger log.Logger, cosigner *LocalCosigner, raftPeers []Cosigner) *RaftStore { cosignerRaftStore := &RaftStore{ - NodeID: nodeID, - RaftDir: directory, - RaftBind: bindAddress, - RaftTimeout: timeout, - m: make(map[string]string), - logger: logger, - cosigner: cosigner, - Peers: raftPeers, + NodeID: nodeID, + RaftDir: directory, + RaftBind: bindAddress, + RaftTimeout: timeout, + m: make(map[string]string), + logger: logger, + cosigner: cosigner, + Peers: raftPeers, + commonRPCPort: getCommonRPCPort(raftPeers), } cosignerRaftStore.BaseService = *service.NewBaseService(logger, "CosignerRaftStore", cosignerRaftStore) diff --git a/signer/remote_cosigner.go b/signer/remote_cosigner.go index 0d9a6f69..aad8474f 100644 --- a/signer/remote_cosigner.go +++ b/signer/remote_cosigner.go @@ -41,11 +41,6 @@ func (cosigner *RemoteCosigner) GetEphemeralSecretParts( return res, CallRPC(cosigner.address, "GetEphemeralSecretParts", req, &res) } -// Implements the cosigner interface -func (cosigner *RemoteCosigner) SignBlock(req CosignerSignBlockRequest) (res CosignerSignBlockResponse, err error) { - return res, CallRPC(cosigner.address, "SignBlock", req, &res) -} - // Implements the cosigner interface func (cosigner *RemoteCosigner) SetEphemeralSecretPartsAndSign( req CosignerSetEphemeralSecretPartsAndSignRequest) (res *CosignerSignResponse, err error) { diff --git a/signer/sign_state.go b/signer/sign_state.go index 376da108..06595d8e 100644 --- a/signer/sign_state.go +++ b/signer/sign_state.go @@ -20,6 +20,7 @@ const ( stepPropose int8 = 1 stepPrevote int8 = 2 stepPrecommit int8 = 3 + blocksToCache = 3 ) func CanonicalVoteToStep(vote *tmProto.CanonicalVote) int8 { @@ -56,6 +57,7 @@ type SignState struct { EphemeralPublic []byte `json:"ephemeral_public"` Signature []byte `json:"signature,omitempty"` SignBytes tmBytes.HexBytes `json:"signbytes,omitempty"` + cache map[HRSKey]SignStateConsensus filePath string } @@ -76,6 +78,22 @@ func NewSignStateConsensus(height int64, round int64, step int8) SignStateConsen } } +func (signState *SignState) GetFromCache(hrs HRSKey, lock *sync.Mutex) (HRSKey, *SignStateConsensus) { + if lock != nil { + lock.Lock() + defer lock.Unlock() + } + latestBlock := HRSKey{ + Height: signState.Height, + Round: signState.Round, + Step: signState.Step, + } + if ssc, ok := signState.cache[hrs]; ok { + return latestBlock, &ssc + } + return latestBlock, nil +} + func (signState *SignState) Save(ssc SignStateConsensus, lock *sync.Mutex) error { // One lock/unlock for less/equal check and mutation. // Setting nil for lock for getErrorIfLessOrEqual to avoid recursive lock @@ -90,6 +108,13 @@ func (signState *SignState) Save(ssc SignStateConsensus, lock *sync.Mutex) error } // HRS is greater than existing state, allow + signState.cache[HRSKey{Height: ssc.Height, Round: ssc.Round, Step: ssc.Step}] = ssc + for hrs := range signState.cache { + if hrs.Height < ssc.Height-blocksToCache { + delete(signState.cache, hrs) + } + } + signState.Height = ssc.Height signState.Round = ssc.Round signState.Step = ssc.Step @@ -156,6 +181,18 @@ func (signState *SignState) CheckHRS(hrs HRSKey) (bool, error) { return false, nil } +type SameHRSError struct { + msg string +} + +func (e *SameHRSError) Error() string { return e.msg } + +func newSameHRSError(hrs HRSKey) *SameHRSError { + return &SameHRSError{ + msg: fmt.Sprintf("HRS is the same as current: %d:%d:%d", hrs.Height, hrs.Round, hrs.Step), + } +} + func (signState *SignState) GetErrorIfLessOrEqual(height int64, round int64, step int8, lock *sync.Mutex) error { if lock != nil { lock.Lock() @@ -184,8 +221,8 @@ func (signState *SignState) GetErrorIfLessOrEqual(height int64, round int64, ste return errors.New("step regression not allowed") } if step == signState.Step { - // same HRS as current! - return errors.New("not allowing double sign of current latest HRS") + // same HRS as current + return newSameHRSError(HRSKey{Height: height, Round: round, Step: step}) } // Step is greater, so all good return nil @@ -203,6 +240,14 @@ func LoadSignState(filepath string) (SignState, error) { if err != nil { return state, err } + state.cache = make(map[HRSKey]SignStateConsensus) + state.cache[HRSKey{Height: state.Height, Round: state.Round, Step: state.Step}] = SignStateConsensus{ + Height: state.Height, + Round: state.Round, + Step: state.Step, + Signature: state.Signature, + SignBytes: state.SignBytes, + } state.filePath = filepath return state, nil } @@ -220,6 +265,7 @@ func LoadOrCreateSignState(filepath string) (SignState, error) { // Make an empty sign state and save it state := SignState{} state.filePath = filepath + state.cache = make(map[HRSKey]SignStateConsensus) state.save() return state, nil } @@ -236,6 +282,16 @@ func (signState *SignState) OnlyDifferByTimestamp(signBytes []byte) (time.Time, return time.Time{}, false } +func (signState *SignStateConsensus) OnlyDifferByTimestamp(signBytes []byte) (time.Time, bool) { + if signState.Step == stepPropose { + return checkProposalOnlyDifferByTimestamp(signState.SignBytes, signBytes) + } else if signState.Step == stepPrevote || signState.Step == stepPrecommit { + return checkVoteOnlyDifferByTimestamp(signState.SignBytes, signBytes) + } + + return time.Time{}, false +} + func checkVoteOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (time.Time, bool) { var lastVote, newVote tmProto.CanonicalVote if err := protoio.UnmarshalDelimited(lastSignBytes, &lastVote); err != nil { diff --git a/signer/threshold_validator.go b/signer/threshold_validator.go index 7c3df0d1..9b304d23 100644 --- a/signer/threshold_validator.go +++ b/signer/threshold_validator.go @@ -65,6 +65,7 @@ func NewThresholdValidator(opt *ThresholdValidatorOpt) *ThresholdValidator { Round: opt.SignState.Round, Step: opt.SignState.Step, filePath: "none", + cache: make(map[HRSKey]SignStateConsensus), } validator.lastSignStateInitiatedMutex = sync.Mutex{} validator.raftStore = opt.RaftStore @@ -263,36 +264,39 @@ func (pv *ThresholdValidator) SignBlock(chainID string, block *block) ([]byte, t pv.logger.Debug("I am the raft leader. Managing the sign process for this block") - // the block sign state for caching full block signatures - lss := pv.lastSignState - hrs := HRSKey{ Height: height, Round: round, Step: step, } - // check watermark - sameHRS, err := lss.CheckHRS(hrs) - if err != nil { - return nil, stamp, err - } + signBytes := block.SignBytes // Keep track of the last block that we began the signing process for. Only allow one attempt per block if err := pv.SaveLastSignedStateInitiated(NewSignStateConsensus(height, round, step)); err != nil { - return nil, stamp, pv.newBeyondBlockError(hrs) - } - - signBytes := block.SignBytes - - if sameHRS { - if bytes.Equal(signBytes, lss.SignBytes) { - return lss.Signature, block.Timestamp, nil - } else if timestamp, ok := lss.OnlyDifferByTimestamp(signBytes); ok { - return lss.Signature, timestamp, nil + switch err.(type) { + case *SameHRSError: + // Wait for last sign state signature to be the same block + for i := 0; i < 100; i++ { + time.Sleep(10 * time.Millisecond) + latestBlock, existingSignature := pv.lastSignState.GetFromCache(hrs, &pv.lastSignStateMutex) + if existingSignature != nil { + if bytes.Equal(signBytes, existingSignature.SignBytes) { + return existingSignature.Signature, block.Timestamp, nil + } else if timestamp, ok := existingSignature.OnlyDifferByTimestamp(signBytes); ok { + return existingSignature.Signature, timestamp, nil + } + return nil, stamp, errors.New("conflicting data") + } else if latestBlock.Height > height || + (latestBlock.Height == height && latestBlock.Round > round) || + (latestBlock.Height == height && latestBlock.Round == round && latestBlock.Step > step) { + return nil, stamp, pv.newBeyondBlockError(hrs) + } + } + return nil, stamp, errors.New("timed out waiting for block signature from cluster") + default: + return nil, stamp, pv.newBeyondBlockError(hrs) } - - return nil, stamp, errors.New("conflicting data") } numPeers := len(pv.peers)