Skip to content

Commit

Permalink
Journaling fix
Browse files Browse the repository at this point in the history
  • Loading branch information
goran-ethernal committed Feb 10, 2024
1 parent 2bbeaa7 commit 5d94da1
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 73 deletions.
108 changes: 87 additions & 21 deletions state/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"math"
"math/big"
"sort"

"github.com/hashicorp/go-hclog"

Expand Down Expand Up @@ -88,6 +89,7 @@ func (e *Executor) WriteGenesis(
gasPool: uint64(env.GasLimit),
config: config,
precompiles: precompiled.NewPrecompiled(),
journal: &runtime.Journal{},
}

for addr, account := range alloc {
Expand Down Expand Up @@ -221,6 +223,7 @@ func (e *Executor) BeginTxn(
evm: evm.NewEVM(),
precompiles: precompiled.NewPrecompiled(),
PostHook: e.PostHook,
journal: &runtime.Journal{},
}

// enable contract deployment allow list (if any)
Expand Down Expand Up @@ -283,6 +286,12 @@ type Transition struct {
txnBlockList *addresslist.AddressList
bridgeAllowList *addresslist.AddressList
bridgeBlockList *addresslist.AddressList

// journaling
journal *runtime.Journal
journalRevisions []runtime.JournalRevision

accessList *runtime.AccessList
}

func NewTransition(config chain.ForksInTime, snap Snapshot, radix *Txn) *Transition {
Expand All @@ -292,6 +301,7 @@ func NewTransition(config chain.ForksInTime, snap Snapshot, radix *Txn) *Transit
snap: snap,
evm: evm.NewEVM(),
precompiles: precompiled.NewPrecompiled(),
journal: &runtime.Journal{},
}
}

Expand Down Expand Up @@ -429,11 +439,11 @@ func (t *Transition) Txn() *Txn {

// Apply applies a new transaction
func (t *Transition) Apply(msg *types.Transaction) (*runtime.ExecutionResult, error) {
s := t.state.Snapshot()
s := t.Snapshot()

result, err := t.apply(msg)
if err != nil {
if revertErr := t.state.RevertToSnapshot(s); revertErr != nil {
if revertErr := t.RevertToSnapshot(s); revertErr != nil {
return nil, revertErr
}
}
Expand Down Expand Up @@ -630,14 +640,16 @@ func (t *Transition) apply(msg *types.Transaction) (*runtime.ExecutionResult, er
initialAccessList.PrepareAccessList(msg.From(), msg.To(), t.precompiles.Addrs, msg.AccessList())
}

t.accessList = initialAccessList

var result *runtime.ExecutionResult
if msg.IsContractCreation() {
result = t.Create2(msg.From(), msg.Input(), value, gasLeft, initialAccessList)
result = t.Create2(msg.From(), msg.Input(), value, gasLeft)
} else {
if err := t.state.IncrNonce(msg.From()); err != nil {
return nil, err
}
result = t.Call2(msg.From(), *(msg.To()), msg.Input(), value, gasLeft, initialAccessList)
result = t.Call2(msg.From(), *(msg.To()), msg.Input(), value, gasLeft)
}

refundQuotient := LegacyRefundQuotient
Expand Down Expand Up @@ -672,6 +684,7 @@ func (t *Transition) apply(msg *types.Transaction) (*runtime.ExecutionResult, er
coinbaseFee := new(big.Int).Mul(new(big.Int).SetUint64(result.GasUsed), effectiveTip)
t.state.AddBalance(t.ctx.Coinbase, coinbaseFee)

// TODO - burning of base fee should not be done in the EVM
// Burn some amount if the london hardfork is applied.
// Basically, burn amount is just transferred to the current burn contract.
// if t.config.London && msg.Type() != types.StateTx {
Expand All @@ -690,10 +703,9 @@ func (t *Transition) Create2(
code []byte,
value *big.Int,
gas uint64,
initialAccessList *runtime.AccessList,
) *runtime.ExecutionResult {
address := crypto.CreateAddress(caller, t.state.GetNonce(caller))
contract := runtime.NewContractCreation(1, caller, caller, address, value, gas, code, initialAccessList)
contract := runtime.NewContractCreation(1, caller, caller, address, value, gas, code)

return t.applyCreate(contract, t)
}
Expand All @@ -704,9 +716,8 @@ func (t *Transition) Call2(
input []byte,
value *big.Int,
gas uint64,
initialAccessList *runtime.AccessList,
) *runtime.ExecutionResult {
c := runtime.NewContractCall(1, caller, caller, to, value, gas, t.state.GetCode(to), input, initialAccessList)
c := runtime.NewContractCall(1, caller, caller, to, value, gas, t.state.GetCode(to), input)

return t.applyCall(c, runtime.Call, t)
}
Expand Down Expand Up @@ -795,7 +806,7 @@ func (t *Transition) applyCall(
}
}

snapshot := t.state.Snapshot()
snapshot := t.Snapshot()
t.state.TouchAccount(c.Address)

if callType == runtime.Call {
Expand All @@ -814,9 +825,7 @@ func (t *Transition) applyCall(

result = t.run(c, host)
if result.Failed() {
c.RevertJournal()

if err := t.state.RevertToSnapshot(snapshot); err != nil {
if err := t.RevertToSnapshot(snapshot); err != nil {
return &runtime.ExecutionResult{
GasLeft: c.Gas,
Err: err,
Expand Down Expand Up @@ -861,8 +870,7 @@ func (t *Transition) applyCreate(c *runtime.Contract, host runtime.Host) *runtim
// we add this to the access-list before taking a snapshot. Even if the creation fails,
// the access-list change should not be rolled back according to EIP2929 specs
if t.config.Berlin {
c.AddToJournal(&runtime.AccessListAddAccountChange{Address: c.Address})
c.AccessList.AddAddress(c.Address)
t.AddAddressToAccessList(c.Address)
}

// Check if there is a collision and the address already exists
Expand All @@ -874,7 +882,7 @@ func (t *Transition) applyCreate(c *runtime.Contract, host runtime.Host) *runtim
}

// Take snapshot of the current state
snapshot := t.state.Snapshot()
snapshot := t.Snapshot()

if t.config.EIP158 {
// Force the creation of the account
Expand Down Expand Up @@ -937,9 +945,7 @@ func (t *Transition) applyCreate(c *runtime.Contract, host runtime.Host) *runtim

result = t.run(c, host)
if result.Failed() {
c.RevertJournal()

if err := t.state.RevertToSnapshot(snapshot); err != nil {
if err := t.RevertToSnapshot(snapshot); err != nil {
return &runtime.ExecutionResult{
Err: err,
}
Expand All @@ -950,7 +956,7 @@ func (t *Transition) applyCreate(c *runtime.Contract, host runtime.Host) *runtim

if t.config.EIP158 && len(result.ReturnValue) > SpuriousDragonMaxCodeSize {
// Contract size exceeds 'SpuriousDragon' size limit
if err := t.state.RevertToSnapshot(snapshot); err != nil {
if err := t.RevertToSnapshot(snapshot); err != nil {
return &runtime.ExecutionResult{
Err: err,
}
Expand All @@ -964,7 +970,7 @@ func (t *Transition) applyCreate(c *runtime.Contract, host runtime.Host) *runtim

// Reject code starting with 0xEF if EIP-3541 is enabled.
if result.Err == nil && len(result.ReturnValue) >= 1 && result.ReturnValue[0] == 0xEF && t.config.London {
if err := t.state.RevertToSnapshot(snapshot); err != nil {
if err := t.RevertToSnapshot(snapshot); err != nil {
return &runtime.ExecutionResult{
Err: err,
}
Expand All @@ -984,7 +990,7 @@ func (t *Transition) applyCreate(c *runtime.Contract, host runtime.Host) *runtim

// Out of gas creating the contract
if t.config.Homestead {
if err := t.state.RevertToSnapshot(snapshot); err != nil {
if err := t.RevertToSnapshot(snapshot); err != nil {
return &runtime.ExecutionResult{
Err: err,
}
Expand Down Expand Up @@ -1298,3 +1304,63 @@ func (t *Transition) captureCallEnd(c *runtime.Contract, result *runtime.Executi
result.Err,
)
}

func (t *Transition) AddToJournal(j runtime.JournalEntry) {
t.journal.Append(j)
}

func (t *Transition) Snapshot() int {
snapshot := t.state.Snapshot()
t.journalRevisions = append(t.journalRevisions, runtime.JournalRevision{ID: snapshot, Index: t.journal.Len()})

return snapshot
}

func (t *Transition) RevertToSnapshot(snapshot int) error {
if err := t.state.RevertToSnapshot(snapshot); err != nil {
return err
}

// Find the snapshot in the stack of valid snapshots.
idx := sort.Search(len(t.journalRevisions), func(i int) bool {
return t.journalRevisions[i].ID >= snapshot
})

if idx == len(t.journalRevisions) || t.journalRevisions[idx].ID != snapshot {
panic(fmt.Errorf("journal revision id %v cannot be reverted", snapshot))
}

journalIndex := t.journalRevisions[idx].Index

// Replay the journal to undo changes and remove invalidated snapshots
t.journal.Revert(t, journalIndex)
t.journalRevisions = t.journalRevisions[:idx]

return nil
}

func (t *Transition) AddSlotToAccessList(addr types.Address, slot types.Hash) {
t.journal.Append(&runtime.AccessListAddSlotChange{Address: addr, Slot: slot})
t.accessList.AddSlot(addr, slot)
}

func (t *Transition) AddAddressToAccessList(addr types.Address) {
t.journal.Append(&runtime.AccessListAddAccountChange{Address: addr})
t.accessList.AddAddress(addr)
}

func (t *Transition) ContainsAccessListAddress(addr types.Address) bool {
return t.accessList.ContainsAddress(addr)
}

func (t *Transition) ContainsAccessListSlot(addr types.Address, slot types.Hash) (bool, bool) {
return t.accessList.Contains(addr, slot)
}

func (t *Transition) DeleteAccessListAddress(addr types.Address) {
t.accessList.DeleteAddress(addr)
}

func (t *Transition) DeleteAccessListSlot(addr types.Address, slot types.Hash) {
t.accessList.DeleteSlot(addr, slot)
}
3 changes: 2 additions & 1 deletion state/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,9 @@ func Test_Transition_EIP2929(t *testing.T) {
transition := NewTransition(enabledForks, state, txn)
initialAccessList := runtime.NewAccessList()
initialAccessList.PrepareAccessList(transition.ctx.Origin, &addr, transition.precompiles.Addrs, nil)
transition.accessList = initialAccessList

result := transition.Call2(transition.ctx.Origin, addr, nil, big.NewInt(0), uint64(1000000), initialAccessList)
result := transition.Call2(transition.ctx.Origin, addr, nil, big.NewInt(0), uint64(1000000))
assert.Equal(t, tt.gasConsumed, result.GasUsed, "Gas consumption for %s is inaccurate according to EIP 2929", tt.name)
})
}
Expand Down
29 changes: 8 additions & 21 deletions state/runtime/evm/instructions.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,11 @@ var (

func (c *state) calculateGasForEIP2929(addr types.Address) uint64 {
var gas uint64
if c.msg.AccessList.ContainsAddress(addr) {
if c.host.ContainsAccessListAddress(addr) {
gas = WarmStorageReadCostEIP2929
} else {
gas = ColdAccountAccessCostEIP2929
c.msg.AddToJournal(&runtime.AccessListAddAccountChange{Address: addr})
c.msg.AccessList.AddAddress(addr)
c.host.AddAddressToAccessList(addr)
}

return gas
Expand Down Expand Up @@ -486,10 +485,10 @@ func opSload(c *state) {
var gas uint64

if c.config.Berlin {
if _, slotPresent := c.msg.AccessList.Contains(c.msg.Address, bigToHash(loc)); !slotPresent {
if _, slotPresent := c.host.ContainsAccessListSlot(c.msg.Address, bigToHash(loc)); !slotPresent {
gas = ColdStorageReadCostEIP2929

c.addAccessListSlot(c.msg.Address, bigToHash(loc))
c.host.AddSlotToAccessList(c.msg.Address, bigToHash(loc))
} else {
gas = WarmStorageReadCostEIP2929
}
Expand Down Expand Up @@ -532,10 +531,10 @@ func opSStore(c *state) {
cost := uint64(0)

if c.config.Berlin {
if _, slotPresent := c.msg.AccessList.Contains(c.msg.Address, key); !slotPresent {
if _, slotPresent := c.host.ContainsAccessListSlot(c.msg.Address, key); !slotPresent {
cost = ColdStorageReadCostEIP2929

c.addAccessListSlot(c.msg.Address, key)
c.host.AddSlotToAccessList(c.msg.Address, key)
}
}

Expand Down Expand Up @@ -1005,10 +1004,10 @@ func opSelfDestruct(c *state) {
}

// EIP 2929 gas
if c.config.Berlin && !c.msg.AccessList.ContainsAddress(address) {
if c.config.Berlin && !c.host.ContainsAccessListAddress(address) {
gas += ColdAccountAccessCostEIP2929

c.addAccessListAddress(address)
c.host.AddAddressToAccessList(address)
}

if !c.consumeGas(gas) {
Expand Down Expand Up @@ -1379,7 +1378,6 @@ func (c *state) buildCallContract(op OpCode) (*runtime.Contract, uint64, uint64,
gas,
c.host.GetCode(addr),
args,
c.msg.AccessList,
)

if op == STATICCALL || parent.msg.Static {
Expand Down Expand Up @@ -1479,22 +1477,11 @@ func (c *state) buildCreateContract(op OpCode) (*runtime.Contract, error) {
value,
gas,
input,
c.msg.AccessList,
)

return contract, nil
}

func (c *state) addAccessListSlot(address types.Address, slot types.Hash) {
c.msg.AddToJournal(&runtime.AccessListAddSlotChange{Address: address, Slot: slot})
c.msg.AccessList.AddSlot(address, slot)
}

func (c *state) addAccessListAddress(address types.Address) {
c.msg.AddToJournal(&runtime.AccessListAddAccountChange{Address: address})
c.msg.AccessList.AddAddress(address)
}

func opHalt(op OpCode) instruction {
return func(c *state) {
if op == REVERT && !c.config.Byzantium {
Expand Down
28 changes: 19 additions & 9 deletions state/runtime/journal.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@ import (
"github.com/0xPolygon/polygon-edge/types"
)

type JournalRevision struct {
ID int
Index int
}

type JournalEntry interface {
Revert(c *Contract)
Revert(host Host)
}

type Journal struct {
Expand All @@ -16,12 +21,17 @@ func (j *Journal) Append(entry JournalEntry) {
j.entries = append(j.entries, entry)
}

func (j *Journal) Revert(c *Contract) {
for i := len(j.entries) - 1; i >= 0; i-- {
j.entries[i].Revert(c)
func (j *Journal) Revert(host Host, snapshot int) {
for i := len(j.entries) - 1; i >= snapshot; i-- {
// Undo the changes made by the operation
j.entries[i].Revert(host)
}

j.entries = j.entries[:0]
j.entries = j.entries[:snapshot]
}

func (j *Journal) Len() int {
return len(j.entries)
}

type (
Expand All @@ -36,12 +46,12 @@ type (

var _ JournalEntry = (*AccessListAddAccountChange)(nil)

func (ch AccessListAddAccountChange) Revert(c *Contract) {
c.AccessList.DeleteAddress(ch.Address)
func (ch AccessListAddAccountChange) Revert(host Host) {
host.DeleteAccessListAddress(ch.Address)
}

var _ JournalEntry = (*AccessListAddSlotChange)(nil)

func (ch AccessListAddSlotChange) Revert(c *Contract) {
c.AccessList.DeleteSlot(ch.Address, ch.Slot)
func (ch AccessListAddSlotChange) Revert(host Host) {
host.DeleteAccessListSlot(ch.Address, ch.Slot)
}
Loading

0 comments on commit 5d94da1

Please sign in to comment.