diff --git a/turbo/stages/headerdownload/header_algo_test.go b/turbo/stages/headerdownload/header_algo_test.go index 23a17fedf15..3e6d76d47ac 100644 --- a/turbo/stages/headerdownload/header_algo_test.go +++ b/turbo/stages/headerdownload/header_algo_test.go @@ -1,10 +1,12 @@ package headerdownload_test import ( + "bytes" "context" "math/big" "testing" + "github.com/ledgerwatch/erigon-lib/common" "github.com/ledgerwatch/erigon-lib/kv" "github.com/ledgerwatch/erigon/core" @@ -16,7 +18,7 @@ import ( "github.com/ledgerwatch/erigon/turbo/stages/mock" ) -func TestInserter1(t *testing.T) { +func TestSideChainInsert(t *testing.T) { funds := big.NewInt(1000000000) key, _ := crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") address := crypto.PubkeyToAddress(key.PublicKey) @@ -40,24 +42,83 @@ func TestInserter1(t *testing.T) { defer tx.Rollback() br := m.BlockReader hi := headerdownload.NewHeaderInserter("headers", big.NewInt(0), 0, br) - h1 := types.Header{ - Number: big.NewInt(1), - Difficulty: big.NewInt(10), - ParentHash: genesis.Hash(), + + // Chain with higher initial difficulty + chain1 := createTestChain(3, genesis.Hash(), 2, []byte("")) + + // Smaller side chain (non-canonical) + chain2 := createTestChain(5, genesis.Hash(), 1, []byte("side1")) + + // Bigger side chain (canonical) + chain3 := createTestChain(7, genesis.Hash(), 1, []byte("side2")) + + // Again smaller side chain but with high difficulty (canonical) + chain4 := createTestChain(5, genesis.Hash(), 2, []byte("side3")) + + // More smaller side chain with same difficulty (canonical) + chain5 := createTestChain(2, genesis.Hash(), 5, []byte("side5")) + + // Bigger side chain with same difficulty (non-canonical) + chain6 := createTestChain(10, genesis.Hash(), 1, []byte("side6")) + + // Same side chain (in terms of number and difficulty) but different hash + chain7 := createTestChain(2, genesis.Hash(), 5, []byte("side7")) + + finalExpectedHash := chain5[len(chain5)-1].Hash() + if bytes.Compare(chain5[len(chain5)-1].Hash().Bytes(), chain7[len(chain7)-1].Hash().Bytes()) < 0 { + finalExpectedHash = chain7[len(chain7)-1].Hash() } - h1Hash := h1.Hash() - h2 := types.Header{ - Number: big.NewInt(2), - Difficulty: big.NewInt(1010), - ParentHash: h1Hash, + + testCases := []struct { + name string + chain []types.Header + expectedHash common.Hash + expectedDiff int64 + }{ + {"normal initial insert", chain1, chain1[len(chain1)-1].Hash(), 6}, + {"td(current) > td(incoming)", chain2, chain1[len(chain1)-1].Hash(), 6}, + {"td(incoming) > td(current), number(incoming) > number(current)", chain3, chain3[len(chain3)-1].Hash(), 7}, + {"td(incoming) > td(current), number(current) > number(incoming)", chain4, chain4[len(chain4)-1].Hash(), 10}, + {"td(incoming) = td(current), number(current) > number(current)", chain5, chain5[len(chain5)-1].Hash(), 10}, + {"td(incoming) = td(current), number(incoming) > number(current)", chain6, chain5[len(chain5)-1].Hash(), 10}, + {"td(incoming) = td(current), number(incoming) = number(current), hash different", chain7, finalExpectedHash, 10}, } - h2Hash := h2.Hash() - data1, _ := rlp.EncodeToBytes(&h1) - if _, err = hi.FeedHeaderPoW(tx, br, &h1, data1, h1Hash, 1); err != nil { - t.Errorf("feed empty header 1: %v", err) + + for _, tc := range testCases { + tc := tc + for i, h := range tc.chain { + h := h + data, _ := rlp.EncodeToBytes(&h) + if _, err = hi.FeedHeaderPoW(tx, br, &h, data, h.Hash(), uint64(i+1)); err != nil { + t.Errorf("feed empty header for %s, err: %v", tc.name, err) + } + } + + if hi.GetHighestHash() != tc.expectedHash { + t.Errorf("incorrect highest hash for %s, expected %s, got %s", tc.name, tc.expectedHash, hi.GetHighestHash()) + } + if hi.GetLocalTd().Int64() != tc.expectedDiff { + t.Errorf("incorrect difficulty for %s, expected %d, got %d", tc.name, tc.expectedDiff, hi.GetLocalTd().Int64()) + } } - data2, _ := rlp.EncodeToBytes(&h2) - if _, err = hi.FeedHeaderPoW(tx, br, &h2, data2, h2Hash, 2); err != nil { - t.Errorf("feed empty header 2: %v", err) +} + +func createTestChain(length int64, parent common.Hash, diff int64, extra []byte) []types.Header { + var ( + i int64 + headers []types.Header + ) + + for i = 0; i < length; i++ { + h := types.Header{ + Number: big.NewInt(i + 1), + Difficulty: big.NewInt(diff), + ParentHash: parent, + Extra: extra, + } + headers = append(headers, h) + parent = h.Hash() } + + return headers } diff --git a/turbo/stages/headerdownload/header_algos.go b/turbo/stages/headerdownload/header_algos.go index c98b0a969e4..f9b64b074ea 100644 --- a/turbo/stages/headerdownload/header_algos.go +++ b/turbo/stages/headerdownload/header_algos.go @@ -896,24 +896,40 @@ func (hi *HeaderInserter) FeedHeaderPoW(db kv.StatelessRwTx, headerReader servic } // Calculate total difficulty of this header using parent's total difficulty td = new(big.Int).Add(parentTd, header.Difficulty) + // Now we can decide wether this header will create a change in the canonical head - if td.Cmp(hi.localTd) > 0 { - hi.newCanonical = true - forkingPoint, err := hi.ForkingPoint(db, header, parent) - if err != nil { - return nil, err + if td.Cmp(hi.localTd) >= 0 { + reorg := true + + // TODO: Add bor check here if required + // Borrowed from https://github.com/maticnetwork/bor/blob/master/core/forkchoice.go#L81 + if td.Cmp(hi.localTd) == 0 { + if blockHeight > hi.highest { + reorg = false + } else if blockHeight == hi.highest { + // Compare hashes of block in case of tie breaker. Lexicographically larger hash wins. + reorg = bytes.Compare(hi.highestHash.Bytes(), hash.Bytes()) < 0 + } } - hi.highest = blockHeight - hi.highestHash = hash - hi.highestTimestamp = header.Time - hi.canonicalCache.Add(blockHeight, hash) - // See if the forking point affects the unwindPoint (the block number to which other stages will need to unwind before the new canonical chain is applied) - if forkingPoint < hi.unwindPoint { - hi.unwindPoint = forkingPoint - hi.unwind = true + + if reorg { + hi.newCanonical = true + forkingPoint, err := hi.ForkingPoint(db, header, parent) + if err != nil { + return nil, err + } + hi.highest = blockHeight + hi.highestHash = hash + hi.highestTimestamp = header.Time + hi.canonicalCache.Add(blockHeight, hash) + // See if the forking point affects the unwindPoint (the block number to which other stages will need to unwind before the new canonical chain is applied) + if forkingPoint < hi.unwindPoint { + hi.unwindPoint = forkingPoint + hi.unwind = true + } + // This makes sure we end up choosing the chain with the max total difficulty + hi.localTd.Set(td) } - // This makes sure we end up choosing the chain with the max total difficulty - hi.localTd.Set(td) } if err = rawdb.WriteTd(db, hash, blockHeight, td); err != nil { return nil, fmt.Errorf("[%s] failed to WriteTd: %w", hi.logPrefix, err) @@ -950,6 +966,10 @@ func (hi *HeaderInserter) FeedHeaderPoS(db kv.RwTx, header *types.Header, hash l return nil } +func (hi *HeaderInserter) GetLocalTd() *big.Int { + return hi.localTd +} + func (hi *HeaderInserter) GetHighest() uint64 { return hi.highest }