From 47dada2df0163f007dbfaba986761a9a14ba0159 Mon Sep 17 00:00:00 2001 From: Matthew Slipper Date: Tue, 14 Mar 2023 21:43:08 -0600 Subject: [PATCH] op-chain-ops: Add parallel migration After some research, I discovered that we can iterate over Geth's storage in parallel as long as we don't share the state database. This PR updates the OVM ETH check script to parallelize state iteration across 64 workers. The parallelization works by partitioning storage keyspace based on the number of workers. To optimize further, I also moved the state balance checking into the main iteration loop to avoid unnecessary iteration. Note that state mutation must be done serially. Overall, this should safe about 40 minutes during the migration. This PR was tested by running it against a mainnet data directory. The entire process took approximately one hour. As part of this testing I discovered an invalid check in `check.go`, which was verifying the wrong storage slot in the withdrawals check function. This has been fixed.I confirmed with Mark that the updated `MessageSender` value is the correct one to be checking for. The filtering code was correct from the beginning. --- op-chain-ops/cmd/check-migration/main.go | 210 +++++++++++++++ op-chain-ops/cmd/op-migrate/main.go | 5 + op-chain-ops/ether/cli.go | 4 +- op-chain-ops/ether/migrate.go | 23 +- op-chain-ops/ether/precheck.go | 322 +++++++++++++++++------ op-chain-ops/ether/precheck_test.go | 295 +++++++++++++++++++++ op-chain-ops/genesis/check.go | 8 +- op-chain-ops/genesis/db_migration.go | 29 +- 8 files changed, 786 insertions(+), 110 deletions(-) create mode 100644 op-chain-ops/cmd/check-migration/main.go create mode 100644 op-chain-ops/ether/precheck_test.go diff --git a/op-chain-ops/cmd/check-migration/main.go b/op-chain-ops/cmd/check-migration/main.go new file mode 100644 index 000000000..eb08db43d --- /dev/null +++ b/op-chain-ops/cmd/check-migration/main.go @@ -0,0 +1,210 @@ +package main + +import ( + "context" + "fmt" + "math/big" + "os" + "strings" + + "github.com/ethereum-optimism/optimism/op-chain-ops/crossdomain" + + "github.com/ethereum-optimism/optimism/op-chain-ops/db" + "github.com/mattn/go-isatty" + + "github.com/ethereum-optimism/optimism/op-node/eth" + "github.com/ethereum-optimism/optimism/op-node/rollup/derive" + "github.com/ethereum/go-ethereum/common" + + "github.com/ethereum/go-ethereum/core/types" + + "github.com/ethereum/go-ethereum/log" + + "github.com/ethereum-optimism/optimism/op-bindings/hardhat" + + "github.com/ethereum-optimism/optimism/op-chain-ops/genesis" + "github.com/ethereum/go-ethereum/ethclient" + + "github.com/urfave/cli" +) + +func main() { + log.Root().SetHandler(log.StreamHandler(os.Stderr, log.TerminalFormat(isatty.IsTerminal(os.Stderr.Fd())))) + + app := &cli.App{ + Name: "check-migration", + Usage: "Run sanity checks on a migrated database", + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "l1-rpc-url", + Value: "http://127.0.0.1:8545", + Usage: "RPC URL for an L1 Node", + Required: true, + }, + &cli.StringFlag{ + Name: "ovm-addresses", + Usage: "Path to ovm-addresses.json", + Required: true, + }, + &cli.StringFlag{ + Name: "ovm-allowances", + Usage: "Path to ovm-allowances.json", + Required: true, + }, + &cli.StringFlag{ + Name: "ovm-messages", + Usage: "Path to ovm-messages.json", + Required: true, + }, + &cli.StringFlag{ + Name: "witness-file", + Usage: "Path to witness file", + Required: true, + }, + &cli.StringFlag{ + Name: "db-path", + Usage: "Path to database", + Required: true, + }, + cli.StringFlag{ + Name: "deploy-config", + Usage: "Path to hardhat deploy config file", + Required: true, + }, + cli.StringFlag{ + Name: "network", + Usage: "Name of hardhat deploy network", + Required: true, + }, + cli.StringFlag{ + Name: "hardhat-deployments", + Usage: "Comma separated list of hardhat deployment directories", + Required: true, + }, + cli.IntFlag{ + Name: "db-cache", + Usage: "LevelDB cache size in mb", + Value: 1024, + }, + cli.IntFlag{ + Name: "db-handles", + Usage: "LevelDB number of handles", + Value: 60, + }, + }, + Action: func(ctx *cli.Context) error { + deployConfig := ctx.String("deploy-config") + config, err := genesis.NewDeployConfig(deployConfig) + if err != nil { + return err + } + + ovmAddresses, err := crossdomain.NewAddresses(ctx.String("ovm-addresses")) + if err != nil { + return err + } + ovmAllowances, err := crossdomain.NewAllowances(ctx.String("ovm-allowances")) + if err != nil { + return err + } + ovmMessages, err := crossdomain.NewSentMessageFromJSON(ctx.String("ovm-messages")) + if err != nil { + return err + } + evmMessages, evmAddresses, err := crossdomain.ReadWitnessData(ctx.String("witness-file")) + if err != nil { + return err + } + + log.Info( + "Loaded witness data", + "ovmAddresses", len(ovmAddresses), + "evmAddresses", len(evmAddresses), + "ovmAllowances", len(ovmAllowances), + "ovmMessages", len(ovmMessages), + "evmMessages", len(evmMessages), + ) + + migrationData := crossdomain.MigrationData{ + OvmAddresses: ovmAddresses, + EvmAddresses: evmAddresses, + OvmAllowances: ovmAllowances, + OvmMessages: ovmMessages, + EvmMessages: evmMessages, + } + + network := ctx.String("network") + deployments := strings.Split(ctx.String("hardhat-deployments"), ",") + hh, err := hardhat.New(network, []string{}, deployments) + if err != nil { + return err + } + + l1RpcURL := ctx.String("l1-rpc-url") + l1Client, err := ethclient.Dial(l1RpcURL) + if err != nil { + return err + } + + var block *types.Block + tag := config.L1StartingBlockTag + if tag.BlockNumber != nil { + block, err = l1Client.BlockByNumber(context.Background(), big.NewInt(tag.BlockNumber.Int64())) + } else if tag.BlockHash != nil { + block, err = l1Client.BlockByHash(context.Background(), *tag.BlockHash) + } else { + return fmt.Errorf("invalid l1StartingBlockTag in deploy config: %v", tag) + } + if err != nil { + return err + } + + dbCache := ctx.Int("db-cache") + dbHandles := ctx.Int("db-handles") + + // Read the required deployment addresses from disk if required + if err := config.GetDeployedAddresses(hh); err != nil { + return err + } + + if err := config.Check(); err != nil { + return err + } + + postLDB, err := db.Open(ctx.String("db-path"), dbCache, dbHandles) + if err != nil { + return err + } + + if err := genesis.PostCheckMigratedDB( + postLDB, + migrationData, + &config.L1CrossDomainMessengerProxy, + config.L1ChainID, + config.FinalSystemOwner, + config.ProxyAdminOwner, + &derive.L1BlockInfo{ + Number: block.NumberU64(), + Time: block.Time(), + BaseFee: block.BaseFee(), + BlockHash: block.Hash(), + BatcherAddr: config.BatchSenderAddress, + L1FeeOverhead: eth.Bytes32(common.BigToHash(new(big.Int).SetUint64(config.GasPriceOracleOverhead))), + L1FeeScalar: eth.Bytes32(common.BigToHash(new(big.Int).SetUint64(config.GasPriceOracleScalar))), + }, + ); err != nil { + return err + } + + if err := postLDB.Close(); err != nil { + return err + } + + return nil + }, + } + + if err := app.Run(os.Args); err != nil { + log.Crit("error in migration", "err", err) + } +} diff --git a/op-chain-ops/cmd/op-migrate/main.go b/op-chain-ops/cmd/op-migrate/main.go index 634b70fdf..2d0db398e 100644 --- a/op-chain-ops/cmd/op-migrate/main.go +++ b/op-chain-ops/cmd/op-migrate/main.go @@ -106,6 +106,11 @@ func main() { Value: "rollup.json", Required: true, }, + cli.BoolFlag{ + Name: "post-check-only", + Usage: "Only perform sanity checks", + Required: false, + }, }, Action: func(ctx *cli.Context) error { deployConfig := ctx.String("deploy-config") diff --git a/op-chain-ops/ether/cli.go b/op-chain-ops/ether/cli.go index 522bbb1d8..b345ad950 100644 --- a/op-chain-ops/ether/cli.go +++ b/op-chain-ops/ether/cli.go @@ -25,8 +25,8 @@ func GetOVMETHTotalSupplySlot() common.Hash { return getOVMETHTotalSupplySlot() } -// getOVMETHBalance gets a user's OVM ETH balance from state by querying the +// GetOVMETHBalance gets a user's OVM ETH balance from state by querying the // appropriate storage slot directly. -func getOVMETHBalance(db *state.StateDB, addr common.Address) *big.Int { +func GetOVMETHBalance(db *state.StateDB, addr common.Address) *big.Int { return db.GetState(OVMETHAddress, CalcOVMETHStorageKey(addr)).Big() } diff --git a/op-chain-ops/ether/migrate.go b/op-chain-ops/ether/migrate.go index 6ab25cbe8..6c9b4999d 100644 --- a/op-chain-ops/ether/migrate.go +++ b/op-chain-ops/ether/migrate.go @@ -29,7 +29,9 @@ var ( } ) -func MigrateLegacyETH(db *state.StateDB, addresses []common.Address, chainID int, noCheck bool) error { +type FilteredOVMETHAddresses []common.Address + +func MigrateLegacyETH(db *state.StateDB, addresses FilteredOVMETHAddresses, chainID int, noCheck bool) error { // Chain params to use for integrity checking. params := crossdomain.ParamsByChainID[chainID] if params == nil { @@ -39,28 +41,15 @@ func MigrateLegacyETH(db *state.StateDB, addresses []common.Address, chainID int // Log the chain params for debugging purposes. log.Info("Chain params", "chain-id", chainID, "supply-delta", params.ExpectedSupplyDelta) - // Deduplicate the list of addresses by converting to a map. - deduped := make(map[common.Address]bool) - for _, addr := range addresses { - deduped[addr] = true - } - // Migrate the legacy ETH to ETH. log.Info("Migrating legacy ETH to ETH", "num-accounts", len(addresses)) totalMigrated := new(big.Int) logAccountProgress := util.ProgressLogger(1000, "imported accounts") - for addr := range deduped { - // No accounts should have a balance in state. If they do, bail. - if db.GetBalance(addr).Sign() > 0 { - if noCheck { - log.Error("account has non-zero balance in state - should never happen", "addr", addr) - } else { - log.Crit("account has non-zero balance in state - should never happen", "addr", addr) - } - } + for _, addr := range addresses { + // Balances are pre-checked not have any balances in state. // Pull out the OVM ETH balance. - ovmBalance := getOVMETHBalance(db, addr) + ovmBalance := GetOVMETHBalance(db, addr) // Actually perform the migration by setting the appropriate values in state. db.SetBalance(addr, ovmBalance) diff --git a/op-chain-ops/ether/precheck.go b/op-chain-ops/ether/precheck.go index be9ff9dae..21c45a9ab 100644 --- a/op-chain-ops/ether/precheck.go +++ b/op-chain-ops/ether/precheck.go @@ -1,9 +1,12 @@ package ether import ( - "errors" "fmt" "math/big" + "sync" + + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" "github.com/ethereum-optimism/optimism/op-chain-ops/crossdomain" "github.com/ethereum-optimism/optimism/op-chain-ops/util" @@ -11,132 +14,295 @@ import ( "github.com/ethereum-optimism/optimism/op-bindings/predeploys" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/state" - "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" ) +const ( + // checkJobs is the number of parallel workers to spawn + // when iterating the storage trie. + checkJobs = 64 +) + +// maxSlot is the maximum possible storage slot. +var maxSlot = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + +// accountData is a wrapper struct that contains the balance and address of an account. +// It gets passed via channel to the collector process. +type accountData struct { + balance *big.Int + address common.Address +} + +type DBFactory func() (*state.StateDB, error) + // PreCheckBalances checks that the given list of addresses and allowances represents all storage // slots in the LegacyERC20ETH contract. We don't have to filter out extra addresses like we do for // withdrawals because we'll simply carry the balance of a given address to the new system, if the // account is extra then it won't have any balance and nothing will happen. -func PreCheckBalances(ldb ethdb.Database, db *state.StateDB, addresses []common.Address, allowances []*crossdomain.Allowance, chainID int, noCheck bool) ([]common.Address, error) { +func PreCheckBalances(dbFactory DBFactory, addresses []common.Address, allowances []*crossdomain.Allowance, chainID int, noCheck bool) (FilteredOVMETHAddresses, error) { // Chain params to use for integrity checking. params := crossdomain.ParamsByChainID[chainID] if params == nil { return nil, fmt.Errorf("no chain params for %d", chainID) } + return doMigration(dbFactory, addresses, allowances, params.ExpectedSupplyDelta, noCheck) +} + +func doMigration(dbFactory DBFactory, addresses []common.Address, allowances []*crossdomain.Allowance, expDiff *big.Int, noCheck bool) (FilteredOVMETHAddresses, error) { // We'll need to maintain a list of all addresses that we've seen along with all of the storage // slots based on the witness data. addrs := make([]common.Address, 0) + slotsAddrs := make(map[common.Hash]common.Address) slotsInp := make(map[common.Hash]int) // For each known address, compute its balance key and add it to the list of addresses. // Mint events are instrumented as regular ETH events in the witness data, so we no longer // need to iterate over mint events during the migration. for _, addr := range addresses { - addrs = append(addrs, addr) - slotsInp[CalcOVMETHStorageKey(addr)] = 1 + sk := CalcOVMETHStorageKey(addr) + slotsAddrs[sk] = addr + slotsInp[sk] = 1 } // For each known allowance, compute its storage key and add it to the list of addresses. for _, allowance := range allowances { - addrs = append(addrs, allowance.From) - slotsInp[CalcAllowanceStorageKey(allowance.From, allowance.To)] = 2 + sk := CalcAllowanceStorageKey(allowance.From, allowance.To) + slotsAddrs[sk] = allowance.From + slotsInp[sk] = 2 } // Add the old SequencerEntrypoint because someone sent it ETH a long time ago and it has a // balance but none of our instrumentation could easily find it. Special case. sequencerEntrypointAddr := common.HexToAddress("0x4200000000000000000000000000000000000005") - addrs = append(addrs, sequencerEntrypointAddr) - slotsInp[CalcOVMETHStorageKey(sequencerEntrypointAddr)] = 1 + entrySK := CalcOVMETHStorageKey(sequencerEntrypointAddr) + slotsAddrs[entrySK] = sequencerEntrypointAddr + slotsInp[entrySK] = 1 - // Build a mapping of every storage slot in the LegacyERC20ETH contract, except the list of - // slots that we know we can ignore (totalSupply, name, symbol). - var count int - slotsAct := make(map[common.Hash]common.Hash) - progress := util.ProgressLogger(1000, "Read OVM_ETH storage slot") - err := db.ForEachStorage(predeploys.LegacyERC20ETHAddr, func(key, value common.Hash) bool { - progress() - - // We can safely ignore specific slots (totalSupply, name, symbol). - if ignoredSlots[key] { - return true + // WaitGroup to wait on each iteration job to finish. + var wg sync.WaitGroup + // Channel to receive storage slot keys and values from each iteration job. + outCh := make(chan accountData) + // Channel to receive errors from each iteration job. + errCh := make(chan error, checkJobs) + // Channel to cancel all iteration jobs as well as the collector. + cancelCh := make(chan struct{}) + + // Keep track of the total migrated supply. + totalFound := new(big.Int) + + // Divide the key space into partitions by dividing the key space by the number + // of jobs. This will leave some slots left over, which we handle below. + partSize := new(big.Int).Div(maxSlot.Big(), big.NewInt(checkJobs)) + + // Define a worker function to iterate over each partition. + worker := func(start, end common.Hash) { + // Decrement the WaitGroup when the function returns. + defer wg.Done() + + db, err := dbFactory() + if err != nil { + log.Crit("cannot get database", "err", err) } - // Slot exists, so add it to the map. - slotsAct[key] = value - count++ - return true - }) - if err != nil { - return nil, fmt.Errorf("cannot iterate over LegacyERC20ETHAddr: %w", err) - } + // Create a new storage trie. Each trie returned by db.StorageTrie + // is a copy, so this is safe for concurrent use. + st, err := db.StorageTrie(predeploys.LegacyERC20ETHAddr) + if err != nil { + // Should never happen, so explode if it does. + log.Crit("cannot get storage trie for LegacyERC20ETHAddr", "err", err) + } + if st == nil { + // Should never happen, so explode if it does. + log.Crit("nil storage trie for LegacyERC20ETHAddr") + } - // Log how many slots were iterated over. - log.Info("Iterated legacy balances", "count", count) + it := trie.NewIterator(st.NodeIterator(start.Bytes())) - // Iterate over the list of known slots and check that we have a slot for each one. We'll also - // keep track of the total balance to be migrated and throw if the total supply exceeds the - // expected supply delta. - totalFound := new(big.Int) - var unknown bool - for slot := range slotsAct { - slotType, ok := slotsInp[slot] - if !ok { - if noCheck { - log.Error("ignoring unknown storage slot in state", "slot", slot.String()) - } else { - unknown = true - log.Error("unknown storage slot in state", "slot", slot.String()) + // Below code is largely based on db.ForEachStorage. We can't use that + // because it doesn't allow us to specify a start and end key. + for it.Next() { + select { + case <-cancelCh: + // If one of the workers encounters an error, cancel all of them. + return + default: + break + } + + // Use the raw (i.e., secure hashed) key to check if we've reached + // the end of the partition. + if new(big.Int).SetBytes(it.Key).Cmp(end.Big()) >= 0 { + return + } + + // Skip if the value is empty. + rawValue := it.Value + if len(rawValue) == 0 { continue } - } - // Add balances to the total found. - switch slotType { - case 1: - // Balance slot. - totalFound.Add(totalFound, slotsAct[slot].Big()) - case 2: - // Allowance slot. - continue - default: - // Should never happen. - if noCheck { - log.Error("unknown slot type", "slot", slot, "type", slotType) - } else { - log.Crit("unknown slot type: %d", slotType) + // Get the preimage. + key := common.BytesToHash(st.GetKey(it.Key)) + + // Parse the raw value. + _, content, _, err := rlp.Split(rawValue) + if err != nil { + // Should never happen, so explode if it does. + log.Crit("mal-formed data in state: %v", err) + } + + // We can safely ignore specific slots (totalSupply, name, symbol). + if ignoredSlots[key] { + continue + } + + slotType, ok := slotsInp[key] + if !ok { + if noCheck { + log.Error("ignoring unknown storage slot in state", "slot", key.String()) + } else { + errCh <- fmt.Errorf("unknown storage slot in state: %s", key.String()) + return + } + } + + // No accounts should have a balance in state. If they do, bail. + addr, ok := slotsAddrs[key] + if !ok { + log.Crit("could not find address in map - should never happen") + } + bal := db.GetBalance(addr) + if bal.Sign() != 0 { + log.Error( + "account has non-zero balance in state - should never happen", + "addr", addr, + "balance", bal.String(), + ) + if !noCheck { + errCh <- fmt.Errorf("account has non-zero balance in state - should never happen: %s", addr.String()) + return + } + } + + // Add balances to the total found. + switch slotType { + case 1: + // Convert the value to a common.Hash, then send to the channel. + value := common.BytesToHash(content) + outCh <- accountData{ + balance: value.Big(), + address: addr, + } + case 2: + // Allowance slot. + continue + default: + // Should never happen. + if noCheck { + log.Error("unknown slot type", "slot", key, "type", slotType) + } else { + log.Crit("unknown slot type %d, should never happen", slotType) + } } } } - if unknown { - return nil, errors.New("unknown storage slots in state (see logs for details)") + + for i := 0; i < checkJobs; i++ { + wg.Add(1) + + // Compute the start and end keys for this partition. + start := common.BigToHash(new(big.Int).Mul(big.NewInt(int64(i)), partSize)) + var end common.Hash + if i < checkJobs-1 { + // If this is not the last partition, use the next partition's start key as the end. + end = common.BigToHash(new(big.Int).Mul(big.NewInt(int64(i+1)), partSize)) + } else { + // If this is the last partition, use the max slot as the end. + end = maxSlot + } + + // Kick off our worker. + go worker(start, end) + } + + // Make a channel to make sure that the collector process completes. + collectorCloseCh := make(chan struct{}) + + // Keep track of the last error seen. + var lastErr error + + // There are multiple ways that the cancel channel can be closed: + // - if we receive an error from the errCh + // - if the collector process completes + // To prevent panics, we wrap the close in a sync.Once. + var cancelOnce sync.Once + + // Kick off another background process to collect + // values from the channel and add them to the map. + var count int + progress := util.ProgressLogger(1000, "Collected OVM_ETH storage slot") + go func() { + defer func() { + collectorCloseCh <- struct{}{} + }() + for { + select { + case account := <-outCh: + progress() + // Accumulate addresses and total supply. + addrs = append(addrs, account.address) + totalFound = new(big.Int).Add(totalFound, account.balance) + case err := <-errCh: + lastErr = err + cancelOnce.Do(func() { + close(cancelCh) + }) + case <-cancelCh: + return + } + } + }() + + // Wait for the workers to finish. + wg.Wait() + // Close the cancel channel to signal the collector process to stop. + cancelOnce.Do(func() { + close(cancelCh) + }) + + // Wait for the collector process to finish. + <-collectorCloseCh + + // If we saw an error, return it. + if lastErr != nil { + return nil, lastErr } + // Log how many slots were iterated over. + log.Info("Iterated legacy balances", "count", count) + // Verify the supply delta. Recorded total supply in the LegacyERC20ETH contract may be higher // than the actual migrated amount because self-destructs will remove ETH supply in a way that // cannot be reflected in the contract. This is fine because self-destructs just mean the L2 is // actually *overcollateralized* by some tiny amount. + db, err := dbFactory() + if err != nil { + log.Crit("cannot get database", "err", err) + } + totalSupply := getOVMETHTotalSupply(db) delta := new(big.Int).Sub(totalSupply, totalFound) - if delta.Cmp(params.ExpectedSupplyDelta) != 0 { - if noCheck { - log.Error( - "supply mismatch", - "migrated", totalFound.String(), - "supply", totalSupply.String(), - "delta", delta.String(), - "exp_delta", params.ExpectedSupplyDelta.String(), - ) - } else { - log.Crit( - "supply mismatch", - "migrated", totalFound.String(), - "supply", totalSupply.String(), - "delta", delta.String(), - "exp_delta", params.ExpectedSupplyDelta.String(), - ) + if delta.Cmp(expDiff) != 0 { + log.Error( + "supply mismatch", + "migrated", totalFound.String(), + "supply", totalSupply.String(), + "delta", delta.String(), + "exp_delta", expDiff.String(), + ) + if !noCheck { + return nil, fmt.Errorf("supply mismatch: %s", delta.String()) } } @@ -146,7 +312,7 @@ func PreCheckBalances(ldb ethdb.Database, db *state.StateDB, addresses []common. "migrated", totalFound.String(), "supply", totalSupply.String(), "delta", delta.String(), - "exp_delta", params.ExpectedSupplyDelta.String(), + "exp_delta", expDiff.String(), ) // We know we have at least a superset of all addresses here since we know that we have every diff --git a/op-chain-ops/ether/precheck_test.go b/op-chain-ops/ether/precheck_test.go new file mode 100644 index 000000000..3d234664b --- /dev/null +++ b/op-chain-ops/ether/precheck_test.go @@ -0,0 +1,295 @@ +package ether + +import ( + "bytes" + "math/big" + "math/rand" + "os" + "sort" + "testing" + + "github.com/ethereum/go-ethereum/log" + + "github.com/ethereum-optimism/optimism/op-chain-ops/crossdomain" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/trie" + "github.com/stretchr/testify/require" +) + +func TestPreCheckBalances(t *testing.T) { + log.Root().SetHandler(log.StreamHandler(os.Stderr, log.TerminalFormat(true))) + + tests := []struct { + name string + totalSupply *big.Int + expDiff *big.Int + stateBalances map[common.Address]*big.Int + stateAllowances map[common.Address]common.Address + inputAddresses []common.Address + inputAllowances []*crossdomain.Allowance + check func(t *testing.T, addrs FilteredOVMETHAddresses, err error) + }{ + { + name: "everything matches", + totalSupply: big.NewInt(3), + expDiff: big.NewInt(0), + stateBalances: map[common.Address]*big.Int{ + common.HexToAddress("0x123"): big.NewInt(1), + common.HexToAddress("0x456"): big.NewInt(2), + }, + stateAllowances: map[common.Address]common.Address{ + common.HexToAddress("0x123"): common.HexToAddress("0x456"), + }, + inputAddresses: []common.Address{ + common.HexToAddress("0x123"), + common.HexToAddress("0x456"), + }, + inputAllowances: []*crossdomain.Allowance{ + { + From: common.HexToAddress("0x123"), + To: common.HexToAddress("0x456"), + }, + }, + check: func(t *testing.T, addrs FilteredOVMETHAddresses, err error) { + require.NoError(t, err) + require.EqualValues(t, FilteredOVMETHAddresses{ + common.HexToAddress("0x123"), + common.HexToAddress("0x456"), + }, addrs) + }, + }, + { + name: "extra input addresses", + totalSupply: big.NewInt(1), + expDiff: big.NewInt(0), + stateBalances: map[common.Address]*big.Int{ + common.HexToAddress("0x123"): big.NewInt(1), + }, + inputAddresses: []common.Address{ + common.HexToAddress("0x123"), + common.HexToAddress("0x456"), + }, + check: func(t *testing.T, addrs FilteredOVMETHAddresses, err error) { + require.NoError(t, err) + require.EqualValues(t, FilteredOVMETHAddresses{ + common.HexToAddress("0x123"), + }, addrs) + }, + }, + { + name: "extra input allowances", + totalSupply: big.NewInt(1), + expDiff: big.NewInt(0), + stateBalances: map[common.Address]*big.Int{ + common.HexToAddress("0x123"): big.NewInt(1), + }, + stateAllowances: map[common.Address]common.Address{ + common.HexToAddress("0x123"): common.HexToAddress("0x456"), + }, + inputAddresses: []common.Address{ + common.HexToAddress("0x123"), + common.HexToAddress("0x456"), + }, + inputAllowances: []*crossdomain.Allowance{ + { + From: common.HexToAddress("0x123"), + To: common.HexToAddress("0x456"), + }, + { + From: common.HexToAddress("0x123"), + To: common.HexToAddress("0x789"), + }, + }, + check: func(t *testing.T, addrs FilteredOVMETHAddresses, err error) { + require.NoError(t, err) + require.EqualValues(t, FilteredOVMETHAddresses{ + common.HexToAddress("0x123"), + }, addrs) + }, + }, + { + name: "missing input addresses", + totalSupply: big.NewInt(2), + expDiff: big.NewInt(0), + stateBalances: map[common.Address]*big.Int{ + common.HexToAddress("0x123"): big.NewInt(1), + common.HexToAddress("0x456"): big.NewInt(1), + }, + inputAddresses: []common.Address{ + common.HexToAddress("0x123"), + }, + check: func(t *testing.T, addrs FilteredOVMETHAddresses, err error) { + require.Error(t, err) + require.ErrorContains(t, err, "unknown storage slot") + }, + }, + { + name: "missing input allowances", + totalSupply: big.NewInt(2), + expDiff: big.NewInt(0), + stateBalances: map[common.Address]*big.Int{ + common.HexToAddress("0x123"): big.NewInt(1), + }, + stateAllowances: map[common.Address]common.Address{ + common.HexToAddress("0x123"): common.HexToAddress("0x456"), + common.HexToAddress("0x123"): common.HexToAddress("0x789"), + }, + inputAddresses: []common.Address{ + common.HexToAddress("0x123"), + }, + inputAllowances: []*crossdomain.Allowance{ + { + From: common.HexToAddress("0x123"), + To: common.HexToAddress("0x456"), + }, + }, + check: func(t *testing.T, addrs FilteredOVMETHAddresses, err error) { + require.Error(t, err) + require.ErrorContains(t, err, "unknown storage slot") + }, + }, + { + name: "bad supply diff", + totalSupply: big.NewInt(4), + expDiff: big.NewInt(0), + stateBalances: map[common.Address]*big.Int{ + common.HexToAddress("0x123"): big.NewInt(1), + common.HexToAddress("0x456"): big.NewInt(2), + }, + inputAddresses: []common.Address{ + common.HexToAddress("0x123"), + common.HexToAddress("0x456"), + }, + check: func(t *testing.T, addrs FilteredOVMETHAddresses, err error) { + require.Error(t, err) + require.ErrorContains(t, err, "supply mismatch") + }, + }, + { + name: "good supply diff", + totalSupply: big.NewInt(4), + expDiff: big.NewInt(1), + stateBalances: map[common.Address]*big.Int{ + common.HexToAddress("0x123"): big.NewInt(1), + common.HexToAddress("0x456"): big.NewInt(2), + }, + inputAddresses: []common.Address{ + common.HexToAddress("0x123"), + common.HexToAddress("0x456"), + }, + check: func(t *testing.T, addrs FilteredOVMETHAddresses, err error) { + require.NoError(t, err) + require.EqualValues(t, FilteredOVMETHAddresses{ + common.HexToAddress("0x123"), + common.HexToAddress("0x456"), + }, addrs) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := makeLegacyETH(t, tt.totalSupply, tt.stateBalances, tt.stateAllowances) + factory := func() (*state.StateDB, error) { + return db, nil + } + addrs, err := doMigration(factory, tt.inputAddresses, tt.inputAllowances, tt.expDiff, false) + + // Sort the addresses since they come in in a random order. + sort.Slice(addrs, func(i, j int) bool { + return bytes.Compare(addrs[i][:], addrs[j][:]) < 0 + }) + + tt.check(t, addrs, err) + }) + } +} + +func makeLegacyETH(t *testing.T, totalSupply *big.Int, balances map[common.Address]*big.Int, allowances map[common.Address]common.Address) *state.StateDB { + db, err := state.New(common.Hash{}, state.NewDatabaseWithConfig(rawdb.NewMemoryDatabase(), &trie.Config{ + Preimages: true, + Cache: 1024, + }), nil) + require.NoError(t, err) + + db.CreateAccount(OVMETHAddress) + db.SetState(OVMETHAddress, getOVMETHTotalSupplySlot(), common.BigToHash(totalSupply)) + + for slot := range ignoredSlots { + if slot == getOVMETHTotalSupplySlot() { + continue + } + db.SetState(OVMETHAddress, slot, common.Hash{31: 0xff}) + } + for addr, balance := range balances { + db.SetState(OVMETHAddress, CalcOVMETHStorageKey(addr), common.BigToHash(balance)) + } + for from, to := range allowances { + db.SetState(OVMETHAddress, CalcAllowanceStorageKey(from, to), common.BigToHash(big.NewInt(1))) + } + + root, err := db.Commit(false) + require.NoError(t, err) + + err = db.Database().TrieDB().Commit(root, true) + require.NoError(t, err) + + return db +} + +// TestPreCheckBalancesRandom tests that the pre-check balances function works +// with random addresses. This test makes sure that the partition logic doesn't +// miss anything. +func TestPreCheckBalancesRandom(t *testing.T) { + addresses := make([]common.Address, 0) + stateBalances := make(map[common.Address]*big.Int) + + allowances := make([]*crossdomain.Allowance, 0) + stateAllowances := make(map[common.Address]common.Address) + + totalSupply := big.NewInt(0) + + for i := 0; i < 100; i++ { + for i := 0; i < rand.Intn(1000); i++ { + addr := randAddr(t) + addresses = append(addresses, addr) + stateBalances[addr] = big.NewInt(int64(rand.Intn(1_000_000))) + totalSupply = new(big.Int).Add(totalSupply, stateBalances[addr]) + } + + sort.Slice(addresses, func(i, j int) bool { + return bytes.Compare(addresses[i][:], addresses[j][:]) < 0 + }) + + for i := 0; i < rand.Intn(1000); i++ { + addr := randAddr(t) + to := randAddr(t) + allowances = append(allowances, &crossdomain.Allowance{ + From: addr, + To: to, + }) + stateAllowances[addr] = to + } + + db := makeLegacyETH(t, totalSupply, stateBalances, stateAllowances) + factory := func() (*state.StateDB, error) { + return db, nil + } + + outAddrs, err := doMigration(factory, addresses, allowances, big.NewInt(0), false) + require.NoError(t, err) + + sort.Slice(outAddrs, func(i, j int) bool { + return bytes.Compare(outAddrs[i][:], outAddrs[j][:]) < 0 + }) + require.EqualValues(t, addresses, outAddrs) + } +} + +func randAddr(t *testing.T) common.Address { + var addr common.Address + _, err := rand.Read(addr[:]) + require.NoError(t, err) + return addr +} diff --git a/op-chain-ops/genesis/check.go b/op-chain-ops/genesis/check.go index 3af17b605..abe569970 100644 --- a/op-chain-ops/genesis/check.go +++ b/op-chain-ops/genesis/check.go @@ -7,6 +7,8 @@ import ( "fmt" "math/big" + "github.com/ethereum-optimism/optimism/op-chain-ops/ether" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/state" @@ -251,7 +253,7 @@ func PostCheckPredeploys(prevDB, currDB *state.StateDB) error { // Balances and nonces should match legacy oldNonce := prevDB.GetNonce(addr) - oldBalance := prevDB.GetBalance(addr) + oldBalance := ether.GetOVMETHBalance(prevDB, addr) newNonce := currDB.GetNonce(addr) newBalance := currDB.GetBalance(addr) if oldNonce != newNonce { @@ -543,7 +545,7 @@ func CheckWithdrawalsAfter(db vm.StateDB, data crossdomain.MigrationData, l1Cros // If the sender is _not_ the L2XDM, the value should not be migrated. wd := wdsByOldSlot[key] - if wd.XDomainSender == predeploys.L2CrossDomainMessengerAddr { + if wd.MessageSender == predeploys.L2CrossDomainMessengerAddr { // Make sure the value is abiTrue if this withdrawal should be migrated. if migratedValue != abiTrue { innerErr = fmt.Errorf("expected migrated value to be true, but got %s", migratedValue) @@ -552,7 +554,7 @@ func CheckWithdrawalsAfter(db vm.StateDB, data crossdomain.MigrationData, l1Cros } else { // Otherwise, ensure that withdrawals from senders other than the L2XDM are _not_ migrated. if migratedValue != abiFalse { - innerErr = fmt.Errorf("a migration from a sender other than the L2XDM was migrated") + innerErr = fmt.Errorf("a migration from a sender other than the L2XDM was migrated. sender: %s, migrated value: %s", wd.MessageSender, migratedValue) return false } } diff --git a/op-chain-ops/genesis/db_migration.go b/op-chain-ops/genesis/db_migration.go index 2f6b802e5..1eafd1598 100644 --- a/op-chain-ops/genesis/db_migration.go +++ b/op-chain-ops/genesis/db_migration.go @@ -82,16 +82,25 @@ func MigrateDB(ldb ethdb.Database, config *DeployConfig, l1Block *types.Block, m ) } - // Set up the backing store. - underlyingDB := state.NewDatabaseWithConfig(ldb, &trie.Config{ - Preimages: true, - Cache: 1024, - }) - - // Open up the state database. - db, err := state.New(header.Root, underlyingDB, nil) + dbFactory := func() (*state.StateDB, error) { + // Set up the backing store. + underlyingDB := state.NewDatabaseWithConfig(ldb, &trie.Config{ + Preimages: true, + Cache: 1024, + }) + + // Open up the state database. + db, err := state.New(header.Root, underlyingDB, nil) + if err != nil { + return nil, fmt.Errorf("cannot open StateDB: %w", err) + } + + return db, nil + } + + db, err := dbFactory() if err != nil { - return nil, fmt.Errorf("cannot open StateDB: %w", err) + return nil, fmt.Errorf("cannot create StateDB: %w", err) } // Before we do anything else, we need to ensure that all of the input configuration is correct @@ -139,7 +148,7 @@ func MigrateDB(ldb ethdb.Database, config *DeployConfig, l1Block *types.Block, m // Unlike with withdrawals, we do not need to filter out extra addresses because their balances // would necessarily be zero and therefore not affect the migration. log.Info("Checking addresses...", "no-check", noCheck) - addrs, err := ether.PreCheckBalances(ldb, db, migrationData.Addresses(), migrationData.OvmAllowances, int(config.L1ChainID), noCheck) + addrs, err := ether.PreCheckBalances(dbFactory, migrationData.Addresses(), migrationData.OvmAllowances, int(config.L1ChainID), noCheck) if err != nil { return nil, fmt.Errorf("addresses mismatch: %w", err) }