Skip to content

Commit

Permalink
use random execID
Browse files Browse the repository at this point in the history
  • Loading branch information
beer-1 committed Sep 5, 2024
1 parent 82eae61 commit 5389cc9
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 34 deletions.
2 changes: 1 addition & 1 deletion x/evm/keeper/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (k Keeper) NewStateDB(ctx context.Context, evm callableEVM, feeContrect com
sdk.UnwrapSDKContext(ctx).WithGasMeter(storetypes.NewInfiniteGasMeter()), k.Logger(ctx),
k.accountKeeper, k.VMStore, k.TransientVMStore, k.TransientCreated,
k.TransientSelfDestruct, k.TransientLogs, k.TransientLogSize,
k.TransientAccessList, k.TransientRefund, k.TransientExecIndex,
k.TransientAccessList, k.TransientRefund,
evm, k.ERC20Keeper().GetERC20ABI(), feeContrect,
)
}
Expand Down
2 changes: 0 additions & 2 deletions x/evm/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ type Keeper struct {
TransientCreated collections.KeySet[collections.Pair[uint64, []byte]]
TransientSelfDestruct collections.KeySet[collections.Pair[uint64, []byte]]
TransientAccessList collections.KeySet[collections.Pair[uint64, []byte]]
TransientExecIndex collections.Sequence

// erc20 stores of users
ERC20FactoryAddr collections.Item[[]byte]
Expand Down Expand Up @@ -111,7 +110,6 @@ func NewKeeper(
TransientVMStore: collections.NewMap(tsb, types.TransientVMStorePrefix, "transient_vm_store", collections.PairKeyCodec(collections.Uint64Key, collections.BytesKey), collections.BytesValue),
TransientCreated: collections.NewKeySet(tsb, types.TransientCreatedPrefix, "transient_created", collections.PairKeyCodec(collections.Uint64Key, collections.BytesKey)),
TransientSelfDestruct: collections.NewKeySet(tsb, types.TransientSelfDestructPrefix, "transient_self_destruct", collections.PairKeyCodec(collections.Uint64Key, collections.BytesKey)),
TransientExecIndex: collections.NewSequence(tsb, types.TransientExecIndexPrefix, "transient_exec_index"),
TransientLogs: collections.NewMap(tsb, types.TransientLogsPrefix, "transient_logs", collections.PairKeyCodec(collections.Uint64Key, collections.Uint64Key), codec.CollValue[types.Log](cdc)),
TransientLogSize: collections.NewMap(tsb, types.TransientLogSizePrefix, "transient_log_size", collections.Uint64Key, collections.Uint64Value),
TransientAccessList: collections.NewKeySet(tsb, types.TransientAccessListPrefix, "transient_access_list", collections.PairKeyCodec(collections.Uint64Key, collections.BytesKey)),
Expand Down
60 changes: 29 additions & 31 deletions x/evm/state/statedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"math/big"
"math/rand"

"github.com/holiman/uint256"

Expand Down Expand Up @@ -45,7 +46,7 @@ type StateDB struct {
transientLogSize collections.Map[uint64, uint64]
transientAccessList collections.KeySet[collections.Pair[uint64, []byte]]
transientRefund collections.Map[uint64, uint64]
execIndex uint64
execID uint64

evm callableEVM
erc20ABI *abi.ABI
Expand All @@ -72,21 +73,18 @@ func NewStateDB(
transientLogSize collections.Map[uint64, uint64],
transientAccessList collections.KeySet[collections.Pair[uint64, []byte]],
transientRefund collections.Map[uint64, uint64],
transientExecIndexStore collections.Sequence,
// erc20 params
evm callableEVM,
erc20ABI *abi.ABI,
feeContractAddr common.Address,
) (*StateDB, error) {
execIndex, err := transientExecIndexStore.Next(ctx)
// transient store is not a part of consensus, so we can use random execID
execID := rand.Uint64()
err := transientLogSize.Set(ctx, execID, 0)
if err != nil {
return nil, err
}
err = transientLogSize.Set(ctx, execIndex, 0)
if err != nil {
return nil, err
}
err = transientRefund.Set(ctx, execIndex, 0)
err = transientRefund.Set(ctx, execID, 0)
if err != nil {
return nil, err
}
Expand All @@ -105,7 +103,7 @@ func NewStateDB(
transientLogSize: transientLogSize,
transientAccessList: transientAccessList,
transientRefund: transientRefund,
execIndex: execIndex,
execID: execID,

evm: evm,
erc20ABI: erc20ABI,
Expand Down Expand Up @@ -183,20 +181,20 @@ func (s *StateDB) GetBalance(addr common.Address) *uint256.Int {

// AddRefund implements vm.StateDB.
func (s *StateDB) AddRefund(gas uint64) {
refund, err := s.transientRefund.Get(s.ctx, s.execIndex)
refund, err := s.transientRefund.Get(s.ctx, s.execID)
if err != nil {
panic(err)
}

err = s.transientRefund.Set(s.ctx, s.execIndex, refund+gas)
err = s.transientRefund.Set(s.ctx, s.execID, refund+gas)
if err != nil {
panic(err)
}
}

// SubRefund implements vm.StateDB.
func (s *StateDB) SubRefund(gas uint64) {
refund, err := s.transientRefund.Get(s.ctx, s.execIndex)
refund, err := s.transientRefund.Get(s.ctx, s.execID)
if err != nil {
panic(err)
}
Expand All @@ -205,15 +203,15 @@ func (s *StateDB) SubRefund(gas uint64) {
panic(fmt.Sprintf("Refund counter below zero (gas: %d > refund: %d)", gas, refund))
}

err = s.transientRefund.Set(s.ctx, s.execIndex, refund-gas)
err = s.transientRefund.Set(s.ctx, s.execID, refund-gas)
if err != nil {
panic(err)
}
}

// AddAddressToAccessList adds the given address to the access list
func (s *StateDB) AddAddressToAccessList(addr common.Address) {
err := s.transientAccessList.Set(s.ctx, collections.Join(s.execIndex, addr.Bytes()))
err := s.transientAccessList.Set(s.ctx, collections.Join(s.execID, addr.Bytes()))
if err != nil {
panic(err)
}
Expand All @@ -226,15 +224,15 @@ func (s *StateDB) AddSlotToAccessList(addr common.Address, slot common.Hash) {
s.AddAddressToAccessList(addr)
}

err := s.transientAccessList.Set(s.ctx, collections.Join(s.execIndex, append(addr.Bytes(), slot[:]...)))
err := s.transientAccessList.Set(s.ctx, collections.Join(s.execID, append(addr.Bytes(), slot[:]...)))
if err != nil {
panic(err)
}
}

// AddressInAccessList returns true if the given address is in the access list
func (s *StateDB) AddressInAccessList(addr common.Address) bool {
ok, err := s.transientAccessList.Has(s.ctx, collections.Join(s.execIndex, addr.Bytes()))
ok, err := s.transientAccessList.Has(s.ctx, collections.Join(s.execID, addr.Bytes()))
if err != nil {
panic(err)
}
Expand All @@ -244,14 +242,14 @@ func (s *StateDB) AddressInAccessList(addr common.Address) bool {

// SlotInAccessList returns true if the given (address, slot)-tuple is in the access list
func (s *StateDB) SlotInAccessList(addr common.Address, slot common.Hash) (addressOk bool, slotOk bool) {
ok, err := s.transientAccessList.Has(s.ctx, collections.Join(s.execIndex, addr.Bytes()))
ok, err := s.transientAccessList.Has(s.ctx, collections.Join(s.execID, addr.Bytes()))
if err != nil {
panic(err)
} else if !ok {
return false, false
}

ok, err = s.transientAccessList.Has(s.ctx, collections.Join(s.execIndex, append(addr.Bytes(), slot[:]...)))
ok, err = s.transientAccessList.Has(s.ctx, collections.Join(s.execID, append(addr.Bytes(), slot[:]...)))
if err != nil {
panic(err)
}
Expand All @@ -267,7 +265,7 @@ func (s *StateDB) CreateAccount(addr common.Address) {

// CreateContract creates a contract account with the given address
func (s *StateDB) CreateContract(contractAddr common.Address) {
if err := s.transientCreated.Set(s.ctx, collections.Join(s.execIndex, contractAddr.Bytes())); err != nil {
if err := s.transientCreated.Set(s.ctx, collections.Join(s.execID, contractAddr.Bytes())); err != nil {
panic(err)
}

Expand Down Expand Up @@ -448,7 +446,7 @@ func (s *StateDB) SetNonce(addr common.Address, nonce uint64) {

// GetRefund returns the refund
func (s *StateDB) GetRefund() uint64 {
refund, err := s.transientRefund.Get(s.ctx, s.execIndex)
refund, err := s.transientRefund.Get(s.ctx, s.execID)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -477,7 +475,7 @@ func (s *StateDB) GetState(addr common.Address, slot common.Hash) common.Hash {
func (s *StateDB) HasSelfDestructed(addr common.Address) bool {
acc := s.getAccount(addr)
if acc != nil {
ok, err := s.transientSelfDestruct.Has(s.ctx, collections.Join(s.execIndex, addr.Bytes()))
ok, err := s.transientSelfDestruct.Has(s.ctx, collections.Join(s.execID, addr.Bytes()))
if err != nil {
panic(err)
}
Expand All @@ -500,7 +498,7 @@ func (s *StateDB) SelfDestruct(addr common.Address) {
}

// mark the account as self-destructed
if err := s.transientSelfDestruct.Set(s.ctx, collections.Join(s.execIndex, addr.Bytes())); err != nil {
if err := s.transientSelfDestruct.Set(s.ctx, collections.Join(s.execID, addr.Bytes())); err != nil {
panic(err)
}

Expand All @@ -515,7 +513,7 @@ func (s *StateDB) Selfdestruct6780(addr common.Address) {
return
}

ok, err := s.transientCreated.Has(s.ctx, collections.Join(s.execIndex, addr.Bytes()))
ok, err := s.transientCreated.Has(s.ctx, collections.Join(s.execID, addr.Bytes()))
if err != nil {
panic(err)
} else if ok {
Expand All @@ -537,14 +535,14 @@ func (s *StateDB) SetTransientState(addr common.Address, key, value common.Hash)
return
}

if err := s.transientVMStore.Set(s.ctx, collections.Join(s.execIndex, key[:]), value[:]); err != nil {
if err := s.transientVMStore.Set(s.ctx, collections.Join(s.execID, key[:]), value[:]); err != nil {
panic(err)
}
}

// GetTransientState gets transient storage for a given account.
func (s *StateDB) GetTransientState(addr common.Address, key common.Hash) common.Hash {
data, err := s.transientVMStore.Get(s.ctx, collections.Join(s.execIndex, key[:]))
data, err := s.transientVMStore.Get(s.ctx, collections.Join(s.execID, key[:]))
if err != nil && errors.Is(err, collections.ErrNotFound) {
return common.Hash{}
} else if err != nil {
Expand Down Expand Up @@ -633,7 +631,7 @@ func (s *StateDB) Commit() error {
s.ctx = s.initialCtx

// clear destructed accounts
err := s.transientSelfDestruct.Walk(s.ctx, collections.NewPrefixedPairRange[uint64, []byte](s.execIndex), func(key collections.Pair[uint64, []byte]) (stop bool, err error) {
err := s.transientSelfDestruct.Walk(s.ctx, collections.NewPrefixedPairRange[uint64, []byte](s.execID), func(key collections.Pair[uint64, []byte]) (stop bool, err error) {
addr := common.BytesToAddress(key.K2())
err = s.vmStore.Clear(s.ctx, new(collections.Range[[]byte]).Prefix(addr.Bytes()))

Expand All @@ -650,32 +648,32 @@ func (s *StateDB) Commit() error {

// AddLog implements vm.StateDB.
func (s *StateDB) AddLog(log *types.Log) {
logSize, err := s.transientLogSize.Get(s.ctx, s.execIndex)
logSize, err := s.transientLogSize.Get(s.ctx, s.execID)
if err != nil {
panic(err)
}

err = s.transientLogSize.Set(s.ctx, s.execIndex, logSize+1)
err = s.transientLogSize.Set(s.ctx, s.execID, logSize+1)
if err != nil {
panic(err)
}

err = s.transientLogs.Set(s.ctx, collections.Join(s.execIndex, logSize), evmtypes.NewLog(log))
err = s.transientLogs.Set(s.ctx, collections.Join(s.execID, logSize), evmtypes.NewLog(log))
if err != nil {
panic(err)
}
}

func (s *StateDB) Logs() evmtypes.Logs {
logSize, err := s.transientLogSize.Get(s.ctx, s.execIndex)
logSize, err := s.transientLogSize.Get(s.ctx, s.execID)
if err != nil {
panic(err)
} else if logSize == 0 {
return []evmtypes.Log{}
}

logs := make([]evmtypes.Log, logSize)
err = s.transientLogs.Walk(s.ctx, collections.NewPrefixedPairRange[uint64, uint64](s.execIndex), func(key collections.Pair[uint64, uint64], log evmtypes.Log) (stop bool, err error) {
err = s.transientLogs.Walk(s.ctx, collections.NewPrefixedPairRange[uint64, uint64](s.execID), func(key collections.Pair[uint64, uint64], log evmtypes.Log) (stop bool, err error) {
logs[key.K2()] = log
return false, nil
})
Expand Down

0 comments on commit 5389cc9

Please sign in to comment.