diff --git a/core/forkchoice.go b/core/forkchoice.go index 2fb3994dac..248dc36b07 100644 --- a/core/forkchoice.go +++ b/core/forkchoice.go @@ -17,6 +17,7 @@ package core import ( + "bytes" "errors" "math/big" @@ -114,7 +115,8 @@ func (f *ForkChoice) ReorgNeeded(current *types.Header, extern *types.Header) (b currentPreserve, externPreserve = f.preserve(current), f.preserve(extern) } - reorg = !currentPreserve && (externPreserve || f.rand.Float64() < 0.5) + // Compare hashes of block in case of tie breaker. Lexicographically larger hash wins. + reorg = !currentPreserve && (externPreserve || bytes.Compare(current.Hash().Bytes(), extern.Hash().Bytes()) < 0) } return reorg, nil diff --git a/core/forkchoice_test.go b/core/forkchoice_test.go index 857081e578..4b4a1343a1 100644 --- a/core/forkchoice_test.go +++ b/core/forkchoice_test.go @@ -10,6 +10,8 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/trie" + + "github.com/stretchr/testify/require" ) // chainValidatorFake is a mock for the chain validator service @@ -30,6 +32,60 @@ func newChainReaderFake(getTd func(hash common.Hash, number uint64) *big.Int) *c return &chainReaderFake{getTd: getTd} } +// nolint: tparallel +func TestForkChoice(t *testing.T) { + t.Parallel() + + // Create mocks for forker + getTd := func(hash common.Hash, number uint64) *big.Int { + if number <= 2 { + return big.NewInt(int64(number)) + } + + return big.NewInt(0) + } + mockChainReader := newChainReaderFake(getTd) + mockForker := NewForkChoice(mockChainReader, nil, nil) + + createHeader := func(number int64, extra []byte) *types.Header { + return &types.Header{ + Number: big.NewInt(number), + Extra: extra, + } + } + + // Create headers for different cases + headerA := createHeader(1, []byte("A")) + headerB := createHeader(2, []byte("B")) + headerC := createHeader(3, []byte("C")) + headerD := createHeader(4, []byte("D")) // 0x96b0f70c01f4d2b1ee2df5b0202c099776f24c9375ffc89d94b880007633961b (hash) + headerE := createHeader(4, []byte("E")) // 0xdc0acf54354ff86194baeaab983098a49a40218cffcc77a583726fc06c429685 (hash) + + testCases := []struct { + name string + current *types.Header + incoming *types.Header + want bool + }{ + {"tdd(incoming) > tdd(current)", headerA, headerB, true}, + {"tdd(current) > tdd(incoming)", headerB, headerA, false}, + {"tdd(current) = tdd(incoming), number(incoming) > number(current)", headerC, headerD, false}, + {"tdd(current) = tdd(incoming), number(current) > number(incoming)", headerD, headerC, true}, + {"tdd(current) = tdd(incoming), number(current) = number(incoming), hash(current) > hash(incoming)", headerE, headerD, false}, + {"tdd(current) = tdd(incoming), number(current) = number(incoming), hash(incoming) > hash(current)", headerD, headerE, true}, + } + + // nolint: paralleltest + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + res, err := mockForker.ReorgNeeded(tc.current, tc.incoming) + require.Equal(t, tc.want, res, tc.name) + require.NoError(t, err, tc.name) + }) + } +} + func TestPastChainInsert(t *testing.T) { t.Parallel()