From 5d94da15e6e341cd5e5eee5ed250f938195a9f59 Mon Sep 17 00:00:00 2001 From: Goran Rojovic Date: Sat, 10 Feb 2024 14:23:39 +0100 Subject: [PATCH] Journaling fix --- state/executor.go | 108 ++++++++++++++++++++++++------ state/executor_test.go | 3 +- state/runtime/evm/instructions.go | 29 +++----- state/runtime/journal.go | 28 +++++--- state/runtime/runtime.go | 26 +++---- state/transition_test.go | 7 +- 6 files changed, 128 insertions(+), 73 deletions(-) diff --git a/state/executor.go b/state/executor.go index e2e69cc7ca..9d396499cc 100644 --- a/state/executor.go +++ b/state/executor.go @@ -5,6 +5,7 @@ import ( "fmt" "math" "math/big" + "sort" "github.com/hashicorp/go-hclog" @@ -88,6 +89,7 @@ func (e *Executor) WriteGenesis( gasPool: uint64(env.GasLimit), config: config, precompiles: precompiled.NewPrecompiled(), + journal: &runtime.Journal{}, } for addr, account := range alloc { @@ -221,6 +223,7 @@ func (e *Executor) BeginTxn( evm: evm.NewEVM(), precompiles: precompiled.NewPrecompiled(), PostHook: e.PostHook, + journal: &runtime.Journal{}, } // enable contract deployment allow list (if any) @@ -283,6 +286,12 @@ type Transition struct { txnBlockList *addresslist.AddressList bridgeAllowList *addresslist.AddressList bridgeBlockList *addresslist.AddressList + + // journaling + journal *runtime.Journal + journalRevisions []runtime.JournalRevision + + accessList *runtime.AccessList } func NewTransition(config chain.ForksInTime, snap Snapshot, radix *Txn) *Transition { @@ -292,6 +301,7 @@ func NewTransition(config chain.ForksInTime, snap Snapshot, radix *Txn) *Transit snap: snap, evm: evm.NewEVM(), precompiles: precompiled.NewPrecompiled(), + journal: &runtime.Journal{}, } } @@ -429,11 +439,11 @@ func (t *Transition) Txn() *Txn { // Apply applies a new transaction func (t *Transition) Apply(msg *types.Transaction) (*runtime.ExecutionResult, error) { - s := t.state.Snapshot() + s := t.Snapshot() result, err := t.apply(msg) if err != nil { - if revertErr := t.state.RevertToSnapshot(s); revertErr != nil { + if revertErr := t.RevertToSnapshot(s); revertErr != nil { return nil, revertErr } } @@ -630,14 +640,16 @@ func (t *Transition) apply(msg *types.Transaction) (*runtime.ExecutionResult, er initialAccessList.PrepareAccessList(msg.From(), msg.To(), t.precompiles.Addrs, msg.AccessList()) } + t.accessList = initialAccessList + var result *runtime.ExecutionResult if msg.IsContractCreation() { - result = t.Create2(msg.From(), msg.Input(), value, gasLeft, initialAccessList) + result = t.Create2(msg.From(), msg.Input(), value, gasLeft) } else { if err := t.state.IncrNonce(msg.From()); err != nil { return nil, err } - result = t.Call2(msg.From(), *(msg.To()), msg.Input(), value, gasLeft, initialAccessList) + result = t.Call2(msg.From(), *(msg.To()), msg.Input(), value, gasLeft) } refundQuotient := LegacyRefundQuotient @@ -672,6 +684,7 @@ func (t *Transition) apply(msg *types.Transaction) (*runtime.ExecutionResult, er coinbaseFee := new(big.Int).Mul(new(big.Int).SetUint64(result.GasUsed), effectiveTip) t.state.AddBalance(t.ctx.Coinbase, coinbaseFee) + // TODO - burning of base fee should not be done in the EVM // Burn some amount if the london hardfork is applied. // Basically, burn amount is just transferred to the current burn contract. // if t.config.London && msg.Type() != types.StateTx { @@ -690,10 +703,9 @@ func (t *Transition) Create2( code []byte, value *big.Int, gas uint64, - initialAccessList *runtime.AccessList, ) *runtime.ExecutionResult { address := crypto.CreateAddress(caller, t.state.GetNonce(caller)) - contract := runtime.NewContractCreation(1, caller, caller, address, value, gas, code, initialAccessList) + contract := runtime.NewContractCreation(1, caller, caller, address, value, gas, code) return t.applyCreate(contract, t) } @@ -704,9 +716,8 @@ func (t *Transition) Call2( input []byte, value *big.Int, gas uint64, - initialAccessList *runtime.AccessList, ) *runtime.ExecutionResult { - c := runtime.NewContractCall(1, caller, caller, to, value, gas, t.state.GetCode(to), input, initialAccessList) + c := runtime.NewContractCall(1, caller, caller, to, value, gas, t.state.GetCode(to), input) return t.applyCall(c, runtime.Call, t) } @@ -795,7 +806,7 @@ func (t *Transition) applyCall( } } - snapshot := t.state.Snapshot() + snapshot := t.Snapshot() t.state.TouchAccount(c.Address) if callType == runtime.Call { @@ -814,9 +825,7 @@ func (t *Transition) applyCall( result = t.run(c, host) if result.Failed() { - c.RevertJournal() - - if err := t.state.RevertToSnapshot(snapshot); err != nil { + if err := t.RevertToSnapshot(snapshot); err != nil { return &runtime.ExecutionResult{ GasLeft: c.Gas, Err: err, @@ -861,8 +870,7 @@ func (t *Transition) applyCreate(c *runtime.Contract, host runtime.Host) *runtim // we add this to the access-list before taking a snapshot. Even if the creation fails, // the access-list change should not be rolled back according to EIP2929 specs if t.config.Berlin { - c.AddToJournal(&runtime.AccessListAddAccountChange{Address: c.Address}) - c.AccessList.AddAddress(c.Address) + t.AddAddressToAccessList(c.Address) } // Check if there is a collision and the address already exists @@ -874,7 +882,7 @@ func (t *Transition) applyCreate(c *runtime.Contract, host runtime.Host) *runtim } // Take snapshot of the current state - snapshot := t.state.Snapshot() + snapshot := t.Snapshot() if t.config.EIP158 { // Force the creation of the account @@ -937,9 +945,7 @@ func (t *Transition) applyCreate(c *runtime.Contract, host runtime.Host) *runtim result = t.run(c, host) if result.Failed() { - c.RevertJournal() - - if err := t.state.RevertToSnapshot(snapshot); err != nil { + if err := t.RevertToSnapshot(snapshot); err != nil { return &runtime.ExecutionResult{ Err: err, } @@ -950,7 +956,7 @@ func (t *Transition) applyCreate(c *runtime.Contract, host runtime.Host) *runtim if t.config.EIP158 && len(result.ReturnValue) > SpuriousDragonMaxCodeSize { // Contract size exceeds 'SpuriousDragon' size limit - if err := t.state.RevertToSnapshot(snapshot); err != nil { + if err := t.RevertToSnapshot(snapshot); err != nil { return &runtime.ExecutionResult{ Err: err, } @@ -964,7 +970,7 @@ func (t *Transition) applyCreate(c *runtime.Contract, host runtime.Host) *runtim // Reject code starting with 0xEF if EIP-3541 is enabled. if result.Err == nil && len(result.ReturnValue) >= 1 && result.ReturnValue[0] == 0xEF && t.config.London { - if err := t.state.RevertToSnapshot(snapshot); err != nil { + if err := t.RevertToSnapshot(snapshot); err != nil { return &runtime.ExecutionResult{ Err: err, } @@ -984,7 +990,7 @@ func (t *Transition) applyCreate(c *runtime.Contract, host runtime.Host) *runtim // Out of gas creating the contract if t.config.Homestead { - if err := t.state.RevertToSnapshot(snapshot); err != nil { + if err := t.RevertToSnapshot(snapshot); err != nil { return &runtime.ExecutionResult{ Err: err, } @@ -1298,3 +1304,63 @@ func (t *Transition) captureCallEnd(c *runtime.Contract, result *runtime.Executi result.Err, ) } + +func (t *Transition) AddToJournal(j runtime.JournalEntry) { + t.journal.Append(j) +} + +func (t *Transition) Snapshot() int { + snapshot := t.state.Snapshot() + t.journalRevisions = append(t.journalRevisions, runtime.JournalRevision{ID: snapshot, Index: t.journal.Len()}) + + return snapshot +} + +func (t *Transition) RevertToSnapshot(snapshot int) error { + if err := t.state.RevertToSnapshot(snapshot); err != nil { + return err + } + + // Find the snapshot in the stack of valid snapshots. + idx := sort.Search(len(t.journalRevisions), func(i int) bool { + return t.journalRevisions[i].ID >= snapshot + }) + + if idx == len(t.journalRevisions) || t.journalRevisions[idx].ID != snapshot { + panic(fmt.Errorf("journal revision id %v cannot be reverted", snapshot)) + } + + journalIndex := t.journalRevisions[idx].Index + + // Replay the journal to undo changes and remove invalidated snapshots + t.journal.Revert(t, journalIndex) + t.journalRevisions = t.journalRevisions[:idx] + + return nil +} + +func (t *Transition) AddSlotToAccessList(addr types.Address, slot types.Hash) { + t.journal.Append(&runtime.AccessListAddSlotChange{Address: addr, Slot: slot}) + t.accessList.AddSlot(addr, slot) +} + +func (t *Transition) AddAddressToAccessList(addr types.Address) { + t.journal.Append(&runtime.AccessListAddAccountChange{Address: addr}) + t.accessList.AddAddress(addr) +} + +func (t *Transition) ContainsAccessListAddress(addr types.Address) bool { + return t.accessList.ContainsAddress(addr) +} + +func (t *Transition) ContainsAccessListSlot(addr types.Address, slot types.Hash) (bool, bool) { + return t.accessList.Contains(addr, slot) +} + +func (t *Transition) DeleteAccessListAddress(addr types.Address) { + t.accessList.DeleteAddress(addr) +} + +func (t *Transition) DeleteAccessListSlot(addr types.Address, slot types.Hash) { + t.accessList.DeleteSlot(addr, slot) +} diff --git a/state/executor_test.go b/state/executor_test.go index 0ea86367dd..870afa636e 100644 --- a/state/executor_test.go +++ b/state/executor_test.go @@ -262,8 +262,9 @@ func Test_Transition_EIP2929(t *testing.T) { transition := NewTransition(enabledForks, state, txn) initialAccessList := runtime.NewAccessList() initialAccessList.PrepareAccessList(transition.ctx.Origin, &addr, transition.precompiles.Addrs, nil) + transition.accessList = initialAccessList - result := transition.Call2(transition.ctx.Origin, addr, nil, big.NewInt(0), uint64(1000000), initialAccessList) + result := transition.Call2(transition.ctx.Origin, addr, nil, big.NewInt(0), uint64(1000000)) assert.Equal(t, tt.gasConsumed, result.GasUsed, "Gas consumption for %s is inaccurate according to EIP 2929", tt.name) }) } diff --git a/state/runtime/evm/instructions.go b/state/runtime/evm/instructions.go index 0a66d60405..89f4c1a74d 100644 --- a/state/runtime/evm/instructions.go +++ b/state/runtime/evm/instructions.go @@ -30,12 +30,11 @@ var ( func (c *state) calculateGasForEIP2929(addr types.Address) uint64 { var gas uint64 - if c.msg.AccessList.ContainsAddress(addr) { + if c.host.ContainsAccessListAddress(addr) { gas = WarmStorageReadCostEIP2929 } else { gas = ColdAccountAccessCostEIP2929 - c.msg.AddToJournal(&runtime.AccessListAddAccountChange{Address: addr}) - c.msg.AccessList.AddAddress(addr) + c.host.AddAddressToAccessList(addr) } return gas @@ -486,10 +485,10 @@ func opSload(c *state) { var gas uint64 if c.config.Berlin { - if _, slotPresent := c.msg.AccessList.Contains(c.msg.Address, bigToHash(loc)); !slotPresent { + if _, slotPresent := c.host.ContainsAccessListSlot(c.msg.Address, bigToHash(loc)); !slotPresent { gas = ColdStorageReadCostEIP2929 - c.addAccessListSlot(c.msg.Address, bigToHash(loc)) + c.host.AddSlotToAccessList(c.msg.Address, bigToHash(loc)) } else { gas = WarmStorageReadCostEIP2929 } @@ -532,10 +531,10 @@ func opSStore(c *state) { cost := uint64(0) if c.config.Berlin { - if _, slotPresent := c.msg.AccessList.Contains(c.msg.Address, key); !slotPresent { + if _, slotPresent := c.host.ContainsAccessListSlot(c.msg.Address, key); !slotPresent { cost = ColdStorageReadCostEIP2929 - c.addAccessListSlot(c.msg.Address, key) + c.host.AddSlotToAccessList(c.msg.Address, key) } } @@ -1005,10 +1004,10 @@ func opSelfDestruct(c *state) { } // EIP 2929 gas - if c.config.Berlin && !c.msg.AccessList.ContainsAddress(address) { + if c.config.Berlin && !c.host.ContainsAccessListAddress(address) { gas += ColdAccountAccessCostEIP2929 - c.addAccessListAddress(address) + c.host.AddAddressToAccessList(address) } if !c.consumeGas(gas) { @@ -1379,7 +1378,6 @@ func (c *state) buildCallContract(op OpCode) (*runtime.Contract, uint64, uint64, gas, c.host.GetCode(addr), args, - c.msg.AccessList, ) if op == STATICCALL || parent.msg.Static { @@ -1479,22 +1477,11 @@ func (c *state) buildCreateContract(op OpCode) (*runtime.Contract, error) { value, gas, input, - c.msg.AccessList, ) return contract, nil } -func (c *state) addAccessListSlot(address types.Address, slot types.Hash) { - c.msg.AddToJournal(&runtime.AccessListAddSlotChange{Address: address, Slot: slot}) - c.msg.AccessList.AddSlot(address, slot) -} - -func (c *state) addAccessListAddress(address types.Address) { - c.msg.AddToJournal(&runtime.AccessListAddAccountChange{Address: address}) - c.msg.AccessList.AddAddress(address) -} - func opHalt(op OpCode) instruction { return func(c *state) { if op == REVERT && !c.config.Byzantium { diff --git a/state/runtime/journal.go b/state/runtime/journal.go index 2aac5e4f4e..86048eaed3 100644 --- a/state/runtime/journal.go +++ b/state/runtime/journal.go @@ -4,8 +4,13 @@ import ( "github.com/0xPolygon/polygon-edge/types" ) +type JournalRevision struct { + ID int + Index int +} + type JournalEntry interface { - Revert(c *Contract) + Revert(host Host) } type Journal struct { @@ -16,12 +21,17 @@ func (j *Journal) Append(entry JournalEntry) { j.entries = append(j.entries, entry) } -func (j *Journal) Revert(c *Contract) { - for i := len(j.entries) - 1; i >= 0; i-- { - j.entries[i].Revert(c) +func (j *Journal) Revert(host Host, snapshot int) { + for i := len(j.entries) - 1; i >= snapshot; i-- { + // Undo the changes made by the operation + j.entries[i].Revert(host) } - j.entries = j.entries[:0] + j.entries = j.entries[:snapshot] +} + +func (j *Journal) Len() int { + return len(j.entries) } type ( @@ -36,12 +46,12 @@ type ( var _ JournalEntry = (*AccessListAddAccountChange)(nil) -func (ch AccessListAddAccountChange) Revert(c *Contract) { - c.AccessList.DeleteAddress(ch.Address) +func (ch AccessListAddAccountChange) Revert(host Host) { + host.DeleteAccessListAddress(ch.Address) } var _ JournalEntry = (*AccessListAddSlotChange)(nil) -func (ch AccessListAddSlotChange) Revert(c *Contract) { - c.AccessList.DeleteSlot(ch.Address, ch.Slot) +func (ch AccessListAddSlotChange) Revert(host Host) { + host.DeleteAccessListSlot(ch.Address, ch.Slot) } diff --git a/state/runtime/runtime.go b/state/runtime/runtime.go index fa4ecddbe2..c180c7c777 100644 --- a/state/runtime/runtime.go +++ b/state/runtime/runtime.go @@ -80,6 +80,12 @@ type Host interface { Transfer(from types.Address, to types.Address, amount *big.Int) error GetTracer() VMTracer GetRefund() uint64 + AddSlotToAccessList(addr types.Address, slot types.Hash) + AddAddressToAccessList(addr types.Address) + ContainsAccessListAddress(addr types.Address) bool + ContainsAccessListSlot(addr types.Address, slot types.Hash) (bool, bool) + DeleteAccessListAddress(addr types.Address) + DeleteAccessListSlot(addr types.Address, slot types.Hash) } type VMTracer interface { @@ -199,9 +205,6 @@ type Contract struct { Input []byte Gas uint64 Static bool - AccessList *AccessList - - Journal *Journal } func NewContract( @@ -212,7 +215,6 @@ func NewContract( value *big.Int, gas uint64, code []byte, - accessList *AccessList, ) *Contract { f := &Contract{ Caller: from, @@ -223,8 +225,6 @@ func NewContract( Value: value, Code: code, Depth: depth, - AccessList: accessList, - Journal: &Journal{}, } return f @@ -238,9 +238,8 @@ func NewContractCreation( value *big.Int, gas uint64, code []byte, - accessList *AccessList, ) *Contract { - c := NewContract(depth, origin, from, to, value, gas, code, accessList) + c := NewContract(depth, origin, from, to, value, gas, code) return c } @@ -254,18 +253,9 @@ func NewContractCall( gas uint64, code []byte, input []byte, - accessList *AccessList, ) *Contract { - c := NewContract(depth, origin, from, to, value, gas, code, accessList) + c := NewContract(depth, origin, from, to, value, gas, code) c.Input = input return c } - -func (c *Contract) RevertJournal() { - c.Journal.Revert(c) -} - -func (c *Contract) AddToJournal(e JournalEntry) { - c.Journal.Append(e) -} diff --git a/state/transition_test.go b/state/transition_test.go index 5bfba6edff..c683538071 100644 --- a/state/transition_test.go +++ b/state/transition_test.go @@ -16,9 +16,10 @@ func newTestTransition(preState map[types.Address]*PreState) *Transition { } return &Transition{ - logger: hclog.NewNullLogger(), - state: newTestTxn(preState), - ctx: runtime.TxContext{BaseFee: big.NewInt(0)}, + logger: hclog.NewNullLogger(), + state: newTestTxn(preState), + ctx: runtime.TxContext{BaseFee: big.NewInt(0)}, + journal: &runtime.Journal{}, } }