Skip to content

Commit

Permalink
Reuse existing EVM instance in interop calls (sei-protocol#1731)
Browse files Browse the repository at this point in the history
* Reuse existing EVM instance in interop calls

* rebase
  • Loading branch information
codchen authored Jul 22, 2024
1 parent bdf72f2 commit f285e6f
Show file tree
Hide file tree
Showing 13 changed files with 152 additions and 51 deletions.
6 changes: 5 additions & 1 deletion evmrpc/simulate.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,11 @@ func (b *Backend) GetEVM(_ context.Context, msg *core.Message, stateDB vm.StateD
if blockCtx == nil {
blockCtx, _ = b.keeper.GetVMBlockContext(b.ctxProvider(LatestCtxHeight), core.GasPool(b.RPCGasCap()))
}
return vm.NewEVM(*blockCtx, txContext, stateDB, b.ChainConfig(), *vmConfig)
evm := vm.NewEVM(*blockCtx, txContext, stateDB, b.ChainConfig(), *vmConfig)
if dbImpl, ok := stateDB.(*state.DBImpl); ok {
dbImpl.SetEVM(evm)
}
return evm
}

func (b *Backend) CurrentHeader() *ethtypes.Header {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ require (
)

replace (
github.com/CosmWasm/wasmd => github.com/sei-protocol/sei-wasmd v0.2.0
github.com/CosmWasm/wasmd => github.com/sei-protocol/sei-wasmd v0.2.1
github.com/confio/ics23/go => github.com/cosmos/cosmos-sdk/ics23/go v0.8.0
github.com/cosmos/cosmos-sdk => github.com/sei-protocol/sei-cosmos v0.3.26
github.com/cosmos/iavl => github.com/sei-protocol/sei-iavl v0.1.9
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1359,8 +1359,8 @@ github.com/sei-protocol/sei-tendermint v0.3.4 h1:pAMXB2Cd0/rmmEkPgcEdIEjw7k64K7+
github.com/sei-protocol/sei-tendermint v0.3.4/go.mod h1:4LSlJdhl3nf3OmohliwRNUFLOB1XWlrmSodrIP7fLh4=
github.com/sei-protocol/sei-tm-db v0.0.5 h1:3WONKdSXEqdZZeLuWYfK5hP37TJpfaUa13vAyAlvaQY=
github.com/sei-protocol/sei-tm-db v0.0.5/go.mod h1:Cpa6rGyczgthq7/0pI31jys2Fw0Nfrc+/jKdP1prVqY=
github.com/sei-protocol/sei-wasmd v0.2.0 h1:DiR5u7ZRtRKMYjvGPsH+/nMnJAprcFovbaITLf1Et0Y=
github.com/sei-protocol/sei-wasmd v0.2.0/go.mod h1:EnQkqvUA3tYpdgXjqatHK8ym9LCm1z+lM7XMqR9SA3o=
github.com/sei-protocol/sei-wasmd v0.2.1 h1:2COAeomO22CAGQRTnAd3I1I3b4UhraQuV6Y0PASejc8=
github.com/sei-protocol/sei-wasmd v0.2.1/go.mod h1:EnQkqvUA3tYpdgXjqatHK8ym9LCm1z+lM7XMqR9SA3o=
github.com/sei-protocol/tm-db v0.0.4 h1:7Y4EU62Xzzg6wKAHEotm7SXQR0aPLcGhKHkh3qd0tnk=
github.com/sei-protocol/tm-db v0.0.4/go.mod h1:PWsIWOTwdwC7Ow/GUvx8HgUJTO691pBuorIQD8JvwAs=
github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
Expand Down
7 changes: 3 additions & 4 deletions wasmbinding/message_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package wasmbinding
import (
wasmkeeper "github.com/CosmWasm/wasmd/x/wasm/keeper"
wasmtypes "github.com/CosmWasm/wasmd/x/wasm/types"
"github.com/cosmos/cosmos-sdk/baseapp"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkacltypes "github.com/cosmos/cosmos-sdk/types/accesscontrol"
Expand All @@ -19,14 +18,14 @@ type CustomRouter struct {
evmKeeper *evmkeeper.Keeper
}

func (r *CustomRouter) Handler(msg sdk.Msg) baseapp.MsgServiceHandler {
func (r *CustomRouter) Handler(msg sdk.Msg) wasmkeeper.MsgHandler {
switch m := msg.(type) {
case *evmtypes.MsgInternalEVMCall:
return func(ctx sdk.Context, _ sdk.Msg) (*sdk.Result, error) {
return func(ctx sdk.Context, _ sdk.Msg) (sdk.Context, *sdk.Result, error) {
return r.evmKeeper.HandleInternalEVMCall(ctx, m)
}
case *evmtypes.MsgInternalEVMDelegateCall:
return func(ctx sdk.Context, _ sdk.Msg) (*sdk.Result, error) {
return func(ctx sdk.Context, _ sdk.Msg) (sdk.Context, *sdk.Result, error) {
return r.evmKeeper.HandleInternalEVMDelegateCall(ctx, m)
}
default:
Expand Down
19 changes: 17 additions & 2 deletions wasmbinding/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"github.com/CosmWasm/wasmd/x/wasm"
wasmkeeper "github.com/CosmWasm/wasmd/x/wasm/keeper"
wasmtypes "github.com/CosmWasm/wasmd/x/wasm/types"
"github.com/cosmos/cosmos-sdk/baseapp"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
sdk "github.com/cosmos/cosmos-sdk/types"
aclkeeper "github.com/cosmos/cosmos-sdk/x/accesscontrol/keeper"
authkeeper "github.com/cosmos/cosmos-sdk/x/auth/keeper"
epochwasm "github.com/sei-protocol/sei-chain/x/epoch/client/wasm"
Expand All @@ -17,12 +19,24 @@ import (
tokenfactorykeeper "github.com/sei-protocol/sei-chain/x/tokenfactory/keeper"
)

type routerWithContext struct {
router *baseapp.MsgServiceRouter
}

func (rc routerWithContext) Handler(msg sdk.Msg) wasmkeeper.MsgHandler {
h := rc.router.Handler(msg)
return func(ctx sdk.Context, req sdk.Msg) (sdk.Context, *sdk.Result, error) {
result, err := h(ctx, msg)
return ctx, result, err
}
}

func RegisterCustomPlugins(
oracle *oraclekeeper.Keeper,
epoch *epochkeeper.Keeper,
tokenfactory *tokenfactorykeeper.Keeper,
_ *authkeeper.AccountKeeper,
router wasmkeeper.MessageRouter,
router *baseapp.MsgServiceRouter,
channelKeeper wasmtypes.ChannelKeeper,
capabilityKeeper wasmtypes.CapabilityKeeper,
bankKeeper wasmtypes.Burner,
Expand All @@ -40,8 +54,9 @@ func RegisterCustomPlugins(
queryPluginOpt := wasmkeeper.WithQueryPlugins(&wasmkeeper.QueryPlugins{
Custom: CustomQuerier(wasmQueryPlugin),
})
routerWithCtx := routerWithContext{router}
messengerHandlerOpt := wasmkeeper.WithMessageHandler(
CustomMessageHandler(router, channelKeeper, capabilityKeeper, bankKeeper, evmKeeper, unpacker, portSource, aclKeeper),
CustomMessageHandler(routerWithCtx, channelKeeper, capabilityKeeper, bankKeeper, evmKeeper, unpacker, portSource, aclKeeper),
)

return []wasm.Option{
Expand Down
1 change: 1 addition & 0 deletions x/evm/ante/fee.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func (fc EVMFeeCheckDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate b
cfg := evmtypes.DefaultChainConfig().EthereumConfig(fc.evmKeeper.ChainID(ctx))
txCtx := core.NewEVMTxContext(emsg)
evmInstance := vm.NewEVM(*blockCtx, txCtx, stateDB, cfg, vm.Config{})
stateDB.SetEVM(evmInstance)
st := core.NewStateTransition(evmInstance, emsg, &gp, true)
// run stateless checks before charging gas (mimicking Geth behavior)
if !ctx.IsCheckTx() && !ctx.IsReCheckTx() {
Expand Down
67 changes: 43 additions & 24 deletions x/evm/keeper/evm.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,69 +23,87 @@ type EVMCallFunc func(caller vm.ContractRef, addr *common.Address, input []byte,

var MaxUint64BigInt = new(big.Int).SetUint64(math.MaxUint64)

func (k *Keeper) HandleInternalEVMCall(ctx sdk.Context, req *types.MsgInternalEVMCall) (*sdk.Result, error) {
func (k *Keeper) HandleInternalEVMCall(ctx sdk.Context, req *types.MsgInternalEVMCall) (sdk.Context, *sdk.Result, error) {
var to *common.Address
if req.To != "" {
addr := common.HexToAddress(req.To)
to = &addr
}
senderAddr, err := sdk.AccAddressFromBech32(req.Sender)
if err != nil {
return nil, err
return ctx, nil, err
}
ret, err := k.CallEVM(ctx, k.GetEVMAddressOrDefault(ctx, senderAddr), to, req.Value, req.Data)
retctx, ret, err := k.CallEVM(ctx, k.GetEVMAddressOrDefault(ctx, senderAddr), to, req.Value, req.Data)
if err != nil {
return nil, err
return ctx, nil, err
}
return &sdk.Result{Data: ret}, nil
return retctx, &sdk.Result{Data: ret}, nil
}

func (k *Keeper) HandleInternalEVMDelegateCall(ctx sdk.Context, req *types.MsgInternalEVMDelegateCall) (*sdk.Result, error) {
func (k *Keeper) HandleInternalEVMDelegateCall(ctx sdk.Context, req *types.MsgInternalEVMDelegateCall) (sdk.Context, *sdk.Result, error) {
var to *common.Address
if req.To != "" {
addr := common.HexToAddress(req.To)
to = &addr
} else {
return nil, errors.New("cannot use a CosmWasm contract to delegate-create an EVM contract")
return ctx, nil, errors.New("cannot use a CosmWasm contract to delegate-create an EVM contract")
}
addr, _, exists := k.GetPointerInfo(ctx, types.PointerReverseRegistryKey(common.BytesToAddress([]byte(req.FromContract))))
if !exists || common.BytesToAddress(addr).Cmp(*to) != 0 {
return nil, errors.New("only pointer contract can make delegatecalls")
return ctx, nil, errors.New("only pointer contract can make delegatecalls")
}
zeroInt := sdk.ZeroInt()
senderAddr, err := sdk.AccAddressFromBech32(req.Sender)
if err != nil {
return nil, err
return ctx, nil, err
}
// delegatecall caller must be associated; otherwise any state change on EVM contract will be lost
// after they asssociate.
senderEvmAddr, found := k.GetEVMAddress(ctx, senderAddr)
if !found {
err := types.NewAssociationMissingErr(req.Sender)
metrics.IncrementAssociationError("evm_handle_internal_evm_delegate_call", err)
return nil, err
return ctx, nil, err
}
ret, err := k.CallEVM(ctx, senderEvmAddr, to, &zeroInt, req.Data)
retctx, ret, err := k.CallEVM(ctx, senderEvmAddr, to, &zeroInt, req.Data)
if err != nil {
return nil, err
return ctx, nil, err
}
return &sdk.Result{Data: ret}, nil
return retctx, &sdk.Result{Data: ret}, nil
}

func (k *Keeper) CallEVM(ctx sdk.Context, from common.Address, to *common.Address, val *sdk.Int, data []byte) (retdata []byte, reterr error) {
if ctx.IsEVM() {
return nil, errors.New("sei does not support EVM->CW->EVM call pattern")
}
func (k *Keeper) CallEVM(ctx sdk.Context, from common.Address, to *common.Address, val *sdk.Int, data []byte) (retctx sdk.Context, retdata []byte, reterr error) {
if to == nil && len(data) > params.MaxInitCodeSize {
return nil, fmt.Errorf("%w: code size %v, limit %v", core.ErrMaxInitCodeSizeExceeded, len(data), params.MaxInitCodeSize)
return ctx, nil, fmt.Errorf("%w: code size %v, limit %v", core.ErrMaxInitCodeSizeExceeded, len(data), params.MaxInitCodeSize)
}
value := utils.Big0
if val != nil {
if val.IsNegative() {
return nil, sdkerrors.ErrInvalidCoins
return ctx, nil, sdkerrors.ErrInvalidCoins
}
value = val.BigInt()
}
evm := types.GetCtxEVM(ctx)
if evm != nil {
// This call is part of an existing StateTransition, so directly invoking `Call`
var f EVMCallFunc
if to == nil {
// contract creation
f = func(caller vm.ContractRef, _ *common.Address, input []byte, gas uint64, value *big.Int) ([]byte, uint64, error) {
ret, _, leftoverGas, err := evm.Create(caller, input, gas, value)
return ret, leftoverGas, err
}
} else {
f = func(caller vm.ContractRef, addr *common.Address, input []byte, gas uint64, value *big.Int) ([]byte, uint64, error) {
return evm.Call(caller, *addr, input, gas, value)
}
}
ret, err := k.callEVM(ctx, from, to, val, data, f)
if err != nil {
return ctx, ret, err
}
return evm.StateDB.(*state.DBImpl).Ctx(), ret, err
}
// This call was not part of an existing StateTransition, so it should trigger one
executionCtx := ctx.WithGasMeter(sdk.NewInfiniteGasMeterWithMultiplier(ctx)).WithIsEVM(true)
stateDB := state.NewDBImpl(executionCtx, k, false)
Expand All @@ -104,28 +122,29 @@ func (k *Keeper) CallEVM(ctx sdk.Context, from common.Address, to *common.Addres
}
res, err := k.applyEVMMessage(ctx, evmMsg, stateDB, gp)
if err != nil {
return nil, err
return ctx, nil, err
}
k.consumeEvmGas(ctx, res.UsedGas)
if res.Err != nil {
return nil, res.Err
return ctx, nil, res.Err
}
surplus, err := stateDB.Finalize()
if err != nil {
return nil, err
return ctx, nil, err
}
vmErr := ""
if res.Err != nil {
vmErr = res.Err.Error()
}
receipt, err := k.WriteReceipt(ctx, stateDB, evmMsg, ethtypes.LegacyTxType, ctx.TxSum(), res.UsedGas, vmErr)
if err != nil {
return nil, err
return ctx, nil, err
}
bloom := ethtypes.Bloom{}
bloom.SetBytes(receipt.LogsBloom)
k.AppendToEvmTxDeferredInfo(ctx, bloom, ctx.TxSum(), surplus)
return res.ReturnData, nil
ctx.EVMEventManager().EmitEvents(stateDB.GetAllLogs())
return stateDB.Ctx(), res.ReturnData, nil
}

func (k *Keeper) StaticCallEVM(ctx sdk.Context, from sdk.AccAddress, to *common.Address, data []byte) ([]byte, error) {
Expand Down
25 changes: 10 additions & 15 deletions x/evm/keeper/evm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,8 @@ func TestInternalCallCreateContract(t *testing.T) {
Sender: testAddr.String(),
Data: contractData,
}
// circular interop call
ctx = ctx.WithIsEVM(true)
_, err = k.HandleInternalEVMCall(ctx, req)
require.Equal(t, "sei does not support EVM->CW->EVM call pattern", err.Error())
ctx = ctx.WithIsEVM(false)
_, err = k.HandleInternalEVMCall(ctx, req)
_, _, err = k.HandleInternalEVMCall(ctx, req)
require.Nil(t, err)
receipt, err := k.GetTransientReceipt(ctx, [32]byte{1, 2, 3})
require.Nil(t, err)
Expand All @@ -65,17 +61,16 @@ func TestInternalCall(t *testing.T) {
Sender: testAddr.String(),
Data: contractData,
}
ctx = ctx.WithIsEVM(true)
_, err = k.HandleInternalEVMCall(ctx, req)
require.Equal(t, "sei does not support EVM->CW->EVM call pattern", err.Error())
ctx = ctx.WithIsEVM(false)
ret, err := k.HandleInternalEVMCall(ctx, req)
resCtx, ret, err := k.HandleInternalEVMCall(ctx, req)
require.Nil(t, err)
contractAddr := crypto.CreateAddress(senderEvmAddr, 0)
require.NotEmpty(t, k.GetCode(ctx, contractAddr))
require.Equal(t, ret.Data, k.GetCode(ctx, contractAddr))
k.SetERC20NativePointer(ctx, "test", contractAddr)

ctx = resCtx
require.NotNil(t, types.GetCtxEVM(ctx))
receiverAddr, evmAddr := testkeeper.MockAddressPair()
k.SetAddressMapping(ctx, receiverAddr, evmAddr)
args, err = abi.Pack("transfer", evmAddr, big.NewInt(1000))
Expand All @@ -89,9 +84,9 @@ func TestInternalCall(t *testing.T) {
Data: args,
Value: &val,
}
_, err = k.HandleInternalEVMCall(ctx, req)
resCtx, _, err = k.HandleInternalEVMCall(ctx, req)
require.Nil(t, err)
require.Equal(t, int64(1000), testkeeper.EVMTestApp.BankKeeper.GetBalance(ctx, receiverAddr, "test").Amount.Int64())
require.Equal(t, int64(1000), testkeeper.EVMTestApp.BankKeeper.GetBalance(resCtx, receiverAddr, "test").Amount.Int64())
}

func TestStaticCall(t *testing.T) {
Expand All @@ -113,7 +108,7 @@ func TestStaticCall(t *testing.T) {
Sender: testAddr.String(),
Data: contractData,
}
ret, err := k.HandleInternalEVMCall(ctx, req)
_, ret, err := k.HandleInternalEVMCall(ctx, req)
require.Nil(t, err)
contractAddr := crypto.CreateAddress(senderEvmAddr, 0)
require.NotEmpty(t, k.GetCode(ctx, contractAddr))
Expand Down Expand Up @@ -163,7 +158,7 @@ func TestNegativeTransfer(t *testing.T) {
require.Zero(t, preAttackerBal)
require.Equal(t, steal_amount, preVictimBal)

_, err := k.HandleInternalEVMCall(ctx, req)
_, _, err := k.HandleInternalEVMCall(ctx, req)
require.ErrorContains(t, err, "invalid coins")

// post verification
Expand All @@ -179,7 +174,7 @@ func TestNegativeTransfer(t *testing.T) {
Value: &zeroVal,
}

_, err = k.HandleInternalEVMCall(ctx, req2)
_, _, err = k.HandleInternalEVMCall(ctx, req2)
require.ErrorContains(t, err, "max initcode size exceeded")
}

Expand All @@ -205,6 +200,6 @@ func TestHandleInternalEVMDelegateCall_AssociationError(t *testing.T) {
FromContract: string(contractAddr.Bytes()),
To: castedAddr.Hex(),
}
_, err := k.HandleInternalEVMDelegateCall(ctx, req)
_, _, err := k.HandleInternalEVMDelegateCall(ctx, req)
require.Equal(t, err.Error(), types.NewAssociationMissingErr(testAddr.String()).Error())
}
1 change: 1 addition & 0 deletions x/evm/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ func (k Keeper) applyEVMMessage(ctx sdk.Context, msg *core.Message, stateDB *sta
cfg := types.DefaultChainConfig().EthereumConfig(k.ChainID(ctx))
txCtx := core.NewEVMTxContext(msg)
evmInstance := vm.NewEVM(*blockCtx, txCtx, stateDB, cfg, vm.Config{})
stateDB.SetEVM(evmInstance)
st := core.NewStateTransition(evmInstance, msg, &gp, true) // fee already charged in ante handler
res, err := st.TransitionDb()
return res, err
Expand Down
14 changes: 14 additions & 0 deletions x/evm/state/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/tracing"
"github.com/ethereum/go-ethereum/core/vm"
testkeeper "github.com/sei-protocol/sei-chain/testutil/keeper"
"github.com/sei-protocol/sei-chain/x/evm/state"
"github.com/sei-protocol/sei-chain/x/evm/types"
Expand Down Expand Up @@ -166,3 +167,16 @@ func TestSnapshot(t *testing.T) {
require.Equal(t, common.Hash{}, newStateDB.GetTransientState(evmAddr, tkey))
require.Equal(t, val, newStateDB.GetState(evmAddr, key))
}

func TestSetEVM(t *testing.T) {
k, ctx := testkeeper.MockEVMKeeper()
statedb := state.NewDBImpl(ctx, k, false)
rev1 := statedb.Snapshot()
statedb.SetEVM(&vm.EVM{})
rev2 := statedb.Snapshot()
require.NotNil(t, types.GetCtxEVM(statedb.Ctx()))
statedb.RevertToSnapshot(rev2)
require.NotNil(t, types.GetCtxEVM(statedb.Ctx()))
statedb.RevertToSnapshot(rev1)
require.NotNil(t, types.GetCtxEVM(statedb.Ctx()))
}
Loading

0 comments on commit f285e6f

Please sign in to comment.