From f285e6f43549bdba63eebde5f3662c31744a7d1e Mon Sep 17 00:00:00 2001 From: codchen Date: Mon, 22 Jul 2024 11:51:43 +0800 Subject: [PATCH] Reuse existing EVM instance in interop calls (#1731) * Reuse existing EVM instance in interop calls * rebase --- evmrpc/simulate.go | 6 +++- go.mod | 2 +- go.sum | 4 +-- wasmbinding/message_plugin.go | 7 ++-- wasmbinding/wasm.go | 19 ++++++++-- x/evm/ante/fee.go | 1 + x/evm/keeper/evm.go | 67 ++++++++++++++++++++++------------- x/evm/keeper/evm_test.go | 25 ++++++------- x/evm/keeper/msg_server.go | 1 + x/evm/state/state_test.go | 14 ++++++++ x/evm/state/statedb.go | 9 +++-- x/evm/types/context.go | 28 +++++++++++++++ x/evm/types/context_test.go | 20 +++++++++++ 13 files changed, 152 insertions(+), 51 deletions(-) create mode 100644 x/evm/types/context.go create mode 100644 x/evm/types/context_test.go diff --git a/evmrpc/simulate.go b/evmrpc/simulate.go index 17b431fdd5..8ae859a4be 100644 --- a/evmrpc/simulate.go +++ b/evmrpc/simulate.go @@ -362,7 +362,11 @@ func (b *Backend) GetEVM(_ context.Context, msg *core.Message, stateDB vm.StateD if blockCtx == nil { blockCtx, _ = b.keeper.GetVMBlockContext(b.ctxProvider(LatestCtxHeight), core.GasPool(b.RPCGasCap())) } - return vm.NewEVM(*blockCtx, txContext, stateDB, b.ChainConfig(), *vmConfig) + evm := vm.NewEVM(*blockCtx, txContext, stateDB, b.ChainConfig(), *vmConfig) + if dbImpl, ok := stateDB.(*state.DBImpl); ok { + dbImpl.SetEVM(evm) + } + return evm } func (b *Backend) CurrentHeader() *ethtypes.Header { diff --git a/go.mod b/go.mod index 59088b7806..8f9a0cc23a 100644 --- a/go.mod +++ b/go.mod @@ -344,7 +344,7 @@ require ( ) replace ( - github.com/CosmWasm/wasmd => github.com/sei-protocol/sei-wasmd v0.2.0 + github.com/CosmWasm/wasmd => github.com/sei-protocol/sei-wasmd v0.2.1 github.com/confio/ics23/go => github.com/cosmos/cosmos-sdk/ics23/go v0.8.0 github.com/cosmos/cosmos-sdk => github.com/sei-protocol/sei-cosmos v0.3.26 github.com/cosmos/iavl => github.com/sei-protocol/sei-iavl v0.1.9 diff --git a/go.sum b/go.sum index 8d13fd5101..d4571714c6 100644 --- a/go.sum +++ b/go.sum @@ -1359,8 +1359,8 @@ github.com/sei-protocol/sei-tendermint v0.3.4 h1:pAMXB2Cd0/rmmEkPgcEdIEjw7k64K7+ github.com/sei-protocol/sei-tendermint v0.3.4/go.mod h1:4LSlJdhl3nf3OmohliwRNUFLOB1XWlrmSodrIP7fLh4= github.com/sei-protocol/sei-tm-db v0.0.5 h1:3WONKdSXEqdZZeLuWYfK5hP37TJpfaUa13vAyAlvaQY= github.com/sei-protocol/sei-tm-db v0.0.5/go.mod h1:Cpa6rGyczgthq7/0pI31jys2Fw0Nfrc+/jKdP1prVqY= -github.com/sei-protocol/sei-wasmd v0.2.0 h1:DiR5u7ZRtRKMYjvGPsH+/nMnJAprcFovbaITLf1Et0Y= -github.com/sei-protocol/sei-wasmd v0.2.0/go.mod h1:EnQkqvUA3tYpdgXjqatHK8ym9LCm1z+lM7XMqR9SA3o= +github.com/sei-protocol/sei-wasmd v0.2.1 h1:2COAeomO22CAGQRTnAd3I1I3b4UhraQuV6Y0PASejc8= +github.com/sei-protocol/sei-wasmd v0.2.1/go.mod h1:EnQkqvUA3tYpdgXjqatHK8ym9LCm1z+lM7XMqR9SA3o= github.com/sei-protocol/tm-db v0.0.4 h1:7Y4EU62Xzzg6wKAHEotm7SXQR0aPLcGhKHkh3qd0tnk= github.com/sei-protocol/tm-db v0.0.4/go.mod h1:PWsIWOTwdwC7Ow/GUvx8HgUJTO691pBuorIQD8JvwAs= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= diff --git a/wasmbinding/message_plugin.go b/wasmbinding/message_plugin.go index 8a5cf96877..cb7cd10a4a 100644 --- a/wasmbinding/message_plugin.go +++ b/wasmbinding/message_plugin.go @@ -3,7 +3,6 @@ package wasmbinding import ( wasmkeeper "github.com/CosmWasm/wasmd/x/wasm/keeper" wasmtypes "github.com/CosmWasm/wasmd/x/wasm/types" - "github.com/cosmos/cosmos-sdk/baseapp" codectypes "github.com/cosmos/cosmos-sdk/codec/types" sdk "github.com/cosmos/cosmos-sdk/types" sdkacltypes "github.com/cosmos/cosmos-sdk/types/accesscontrol" @@ -19,14 +18,14 @@ type CustomRouter struct { evmKeeper *evmkeeper.Keeper } -func (r *CustomRouter) Handler(msg sdk.Msg) baseapp.MsgServiceHandler { +func (r *CustomRouter) Handler(msg sdk.Msg) wasmkeeper.MsgHandler { switch m := msg.(type) { case *evmtypes.MsgInternalEVMCall: - return func(ctx sdk.Context, _ sdk.Msg) (*sdk.Result, error) { + return func(ctx sdk.Context, _ sdk.Msg) (sdk.Context, *sdk.Result, error) { return r.evmKeeper.HandleInternalEVMCall(ctx, m) } case *evmtypes.MsgInternalEVMDelegateCall: - return func(ctx sdk.Context, _ sdk.Msg) (*sdk.Result, error) { + return func(ctx sdk.Context, _ sdk.Msg) (sdk.Context, *sdk.Result, error) { return r.evmKeeper.HandleInternalEVMDelegateCall(ctx, m) } default: diff --git a/wasmbinding/wasm.go b/wasmbinding/wasm.go index bc2d492f26..e8c5875bc9 100644 --- a/wasmbinding/wasm.go +++ b/wasmbinding/wasm.go @@ -4,7 +4,9 @@ import ( "github.com/CosmWasm/wasmd/x/wasm" wasmkeeper "github.com/CosmWasm/wasmd/x/wasm/keeper" wasmtypes "github.com/CosmWasm/wasmd/x/wasm/types" + "github.com/cosmos/cosmos-sdk/baseapp" codectypes "github.com/cosmos/cosmos-sdk/codec/types" + sdk "github.com/cosmos/cosmos-sdk/types" aclkeeper "github.com/cosmos/cosmos-sdk/x/accesscontrol/keeper" authkeeper "github.com/cosmos/cosmos-sdk/x/auth/keeper" epochwasm "github.com/sei-protocol/sei-chain/x/epoch/client/wasm" @@ -17,12 +19,24 @@ import ( tokenfactorykeeper "github.com/sei-protocol/sei-chain/x/tokenfactory/keeper" ) +type routerWithContext struct { + router *baseapp.MsgServiceRouter +} + +func (rc routerWithContext) Handler(msg sdk.Msg) wasmkeeper.MsgHandler { + h := rc.router.Handler(msg) + return func(ctx sdk.Context, req sdk.Msg) (sdk.Context, *sdk.Result, error) { + result, err := h(ctx, msg) + return ctx, result, err + } +} + func RegisterCustomPlugins( oracle *oraclekeeper.Keeper, epoch *epochkeeper.Keeper, tokenfactory *tokenfactorykeeper.Keeper, _ *authkeeper.AccountKeeper, - router wasmkeeper.MessageRouter, + router *baseapp.MsgServiceRouter, channelKeeper wasmtypes.ChannelKeeper, capabilityKeeper wasmtypes.CapabilityKeeper, bankKeeper wasmtypes.Burner, @@ -40,8 +54,9 @@ func RegisterCustomPlugins( queryPluginOpt := wasmkeeper.WithQueryPlugins(&wasmkeeper.QueryPlugins{ Custom: CustomQuerier(wasmQueryPlugin), }) + routerWithCtx := routerWithContext{router} messengerHandlerOpt := wasmkeeper.WithMessageHandler( - CustomMessageHandler(router, channelKeeper, capabilityKeeper, bankKeeper, evmKeeper, unpacker, portSource, aclKeeper), + CustomMessageHandler(routerWithCtx, channelKeeper, capabilityKeeper, bankKeeper, evmKeeper, unpacker, portSource, aclKeeper), ) return []wasm.Option{ diff --git a/x/evm/ante/fee.go b/x/evm/ante/fee.go index 9c161d1879..7edea9d010 100644 --- a/x/evm/ante/fee.go +++ b/x/evm/ante/fee.go @@ -70,6 +70,7 @@ func (fc EVMFeeCheckDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate b cfg := evmtypes.DefaultChainConfig().EthereumConfig(fc.evmKeeper.ChainID(ctx)) txCtx := core.NewEVMTxContext(emsg) evmInstance := vm.NewEVM(*blockCtx, txCtx, stateDB, cfg, vm.Config{}) + stateDB.SetEVM(evmInstance) st := core.NewStateTransition(evmInstance, emsg, &gp, true) // run stateless checks before charging gas (mimicking Geth behavior) if !ctx.IsCheckTx() && !ctx.IsReCheckTx() { diff --git a/x/evm/keeper/evm.go b/x/evm/keeper/evm.go index 3892b37d5f..1510468e3b 100644 --- a/x/evm/keeper/evm.go +++ b/x/evm/keeper/evm.go @@ -23,7 +23,7 @@ type EVMCallFunc func(caller vm.ContractRef, addr *common.Address, input []byte, var MaxUint64BigInt = new(big.Int).SetUint64(math.MaxUint64) -func (k *Keeper) HandleInternalEVMCall(ctx sdk.Context, req *types.MsgInternalEVMCall) (*sdk.Result, error) { +func (k *Keeper) HandleInternalEVMCall(ctx sdk.Context, req *types.MsgInternalEVMCall) (sdk.Context, *sdk.Result, error) { var to *common.Address if req.To != "" { addr := common.HexToAddress(req.To) @@ -31,31 +31,31 @@ func (k *Keeper) HandleInternalEVMCall(ctx sdk.Context, req *types.MsgInternalEV } senderAddr, err := sdk.AccAddressFromBech32(req.Sender) if err != nil { - return nil, err + return ctx, nil, err } - ret, err := k.CallEVM(ctx, k.GetEVMAddressOrDefault(ctx, senderAddr), to, req.Value, req.Data) + retctx, ret, err := k.CallEVM(ctx, k.GetEVMAddressOrDefault(ctx, senderAddr), to, req.Value, req.Data) if err != nil { - return nil, err + return ctx, nil, err } - return &sdk.Result{Data: ret}, nil + return retctx, &sdk.Result{Data: ret}, nil } -func (k *Keeper) HandleInternalEVMDelegateCall(ctx sdk.Context, req *types.MsgInternalEVMDelegateCall) (*sdk.Result, error) { +func (k *Keeper) HandleInternalEVMDelegateCall(ctx sdk.Context, req *types.MsgInternalEVMDelegateCall) (sdk.Context, *sdk.Result, error) { var to *common.Address if req.To != "" { addr := common.HexToAddress(req.To) to = &addr } else { - return nil, errors.New("cannot use a CosmWasm contract to delegate-create an EVM contract") + return ctx, nil, errors.New("cannot use a CosmWasm contract to delegate-create an EVM contract") } addr, _, exists := k.GetPointerInfo(ctx, types.PointerReverseRegistryKey(common.BytesToAddress([]byte(req.FromContract)))) if !exists || common.BytesToAddress(addr).Cmp(*to) != 0 { - return nil, errors.New("only pointer contract can make delegatecalls") + return ctx, nil, errors.New("only pointer contract can make delegatecalls") } zeroInt := sdk.ZeroInt() senderAddr, err := sdk.AccAddressFromBech32(req.Sender) if err != nil { - return nil, err + return ctx, nil, err } // delegatecall caller must be associated; otherwise any state change on EVM contract will be lost // after they asssociate. @@ -63,29 +63,47 @@ func (k *Keeper) HandleInternalEVMDelegateCall(ctx sdk.Context, req *types.MsgIn if !found { err := types.NewAssociationMissingErr(req.Sender) metrics.IncrementAssociationError("evm_handle_internal_evm_delegate_call", err) - return nil, err + return ctx, nil, err } - ret, err := k.CallEVM(ctx, senderEvmAddr, to, &zeroInt, req.Data) + retctx, ret, err := k.CallEVM(ctx, senderEvmAddr, to, &zeroInt, req.Data) if err != nil { - return nil, err + return ctx, nil, err } - return &sdk.Result{Data: ret}, nil + return retctx, &sdk.Result{Data: ret}, nil } -func (k *Keeper) CallEVM(ctx sdk.Context, from common.Address, to *common.Address, val *sdk.Int, data []byte) (retdata []byte, reterr error) { - if ctx.IsEVM() { - return nil, errors.New("sei does not support EVM->CW->EVM call pattern") - } +func (k *Keeper) CallEVM(ctx sdk.Context, from common.Address, to *common.Address, val *sdk.Int, data []byte) (retctx sdk.Context, retdata []byte, reterr error) { if to == nil && len(data) > params.MaxInitCodeSize { - return nil, fmt.Errorf("%w: code size %v, limit %v", core.ErrMaxInitCodeSizeExceeded, len(data), params.MaxInitCodeSize) + return ctx, nil, fmt.Errorf("%w: code size %v, limit %v", core.ErrMaxInitCodeSizeExceeded, len(data), params.MaxInitCodeSize) } value := utils.Big0 if val != nil { if val.IsNegative() { - return nil, sdkerrors.ErrInvalidCoins + return ctx, nil, sdkerrors.ErrInvalidCoins } value = val.BigInt() } + evm := types.GetCtxEVM(ctx) + if evm != nil { + // This call is part of an existing StateTransition, so directly invoking `Call` + var f EVMCallFunc + if to == nil { + // contract creation + f = func(caller vm.ContractRef, _ *common.Address, input []byte, gas uint64, value *big.Int) ([]byte, uint64, error) { + ret, _, leftoverGas, err := evm.Create(caller, input, gas, value) + return ret, leftoverGas, err + } + } else { + f = func(caller vm.ContractRef, addr *common.Address, input []byte, gas uint64, value *big.Int) ([]byte, uint64, error) { + return evm.Call(caller, *addr, input, gas, value) + } + } + ret, err := k.callEVM(ctx, from, to, val, data, f) + if err != nil { + return ctx, ret, err + } + return evm.StateDB.(*state.DBImpl).Ctx(), ret, err + } // This call was not part of an existing StateTransition, so it should trigger one executionCtx := ctx.WithGasMeter(sdk.NewInfiniteGasMeterWithMultiplier(ctx)).WithIsEVM(true) stateDB := state.NewDBImpl(executionCtx, k, false) @@ -104,15 +122,15 @@ func (k *Keeper) CallEVM(ctx sdk.Context, from common.Address, to *common.Addres } res, err := k.applyEVMMessage(ctx, evmMsg, stateDB, gp) if err != nil { - return nil, err + return ctx, nil, err } k.consumeEvmGas(ctx, res.UsedGas) if res.Err != nil { - return nil, res.Err + return ctx, nil, res.Err } surplus, err := stateDB.Finalize() if err != nil { - return nil, err + return ctx, nil, err } vmErr := "" if res.Err != nil { @@ -120,12 +138,13 @@ func (k *Keeper) CallEVM(ctx sdk.Context, from common.Address, to *common.Addres } receipt, err := k.WriteReceipt(ctx, stateDB, evmMsg, ethtypes.LegacyTxType, ctx.TxSum(), res.UsedGas, vmErr) if err != nil { - return nil, err + return ctx, nil, err } bloom := ethtypes.Bloom{} bloom.SetBytes(receipt.LogsBloom) k.AppendToEvmTxDeferredInfo(ctx, bloom, ctx.TxSum(), surplus) - return res.ReturnData, nil + ctx.EVMEventManager().EmitEvents(stateDB.GetAllLogs()) + return stateDB.Ctx(), res.ReturnData, nil } func (k *Keeper) StaticCallEVM(ctx sdk.Context, from sdk.AccAddress, to *common.Address, data []byte) ([]byte, error) { diff --git a/x/evm/keeper/evm_test.go b/x/evm/keeper/evm_test.go index afd667b6ef..4ae2ad62e7 100644 --- a/x/evm/keeper/evm_test.go +++ b/x/evm/keeper/evm_test.go @@ -34,12 +34,8 @@ func TestInternalCallCreateContract(t *testing.T) { Sender: testAddr.String(), Data: contractData, } - // circular interop call - ctx = ctx.WithIsEVM(true) - _, err = k.HandleInternalEVMCall(ctx, req) - require.Equal(t, "sei does not support EVM->CW->EVM call pattern", err.Error()) ctx = ctx.WithIsEVM(false) - _, err = k.HandleInternalEVMCall(ctx, req) + _, _, err = k.HandleInternalEVMCall(ctx, req) require.Nil(t, err) receipt, err := k.GetTransientReceipt(ctx, [32]byte{1, 2, 3}) require.Nil(t, err) @@ -65,17 +61,16 @@ func TestInternalCall(t *testing.T) { Sender: testAddr.String(), Data: contractData, } - ctx = ctx.WithIsEVM(true) - _, err = k.HandleInternalEVMCall(ctx, req) - require.Equal(t, "sei does not support EVM->CW->EVM call pattern", err.Error()) ctx = ctx.WithIsEVM(false) - ret, err := k.HandleInternalEVMCall(ctx, req) + resCtx, ret, err := k.HandleInternalEVMCall(ctx, req) require.Nil(t, err) contractAddr := crypto.CreateAddress(senderEvmAddr, 0) require.NotEmpty(t, k.GetCode(ctx, contractAddr)) require.Equal(t, ret.Data, k.GetCode(ctx, contractAddr)) k.SetERC20NativePointer(ctx, "test", contractAddr) + ctx = resCtx + require.NotNil(t, types.GetCtxEVM(ctx)) receiverAddr, evmAddr := testkeeper.MockAddressPair() k.SetAddressMapping(ctx, receiverAddr, evmAddr) args, err = abi.Pack("transfer", evmAddr, big.NewInt(1000)) @@ -89,9 +84,9 @@ func TestInternalCall(t *testing.T) { Data: args, Value: &val, } - _, err = k.HandleInternalEVMCall(ctx, req) + resCtx, _, err = k.HandleInternalEVMCall(ctx, req) require.Nil(t, err) - require.Equal(t, int64(1000), testkeeper.EVMTestApp.BankKeeper.GetBalance(ctx, receiverAddr, "test").Amount.Int64()) + require.Equal(t, int64(1000), testkeeper.EVMTestApp.BankKeeper.GetBalance(resCtx, receiverAddr, "test").Amount.Int64()) } func TestStaticCall(t *testing.T) { @@ -113,7 +108,7 @@ func TestStaticCall(t *testing.T) { Sender: testAddr.String(), Data: contractData, } - ret, err := k.HandleInternalEVMCall(ctx, req) + _, ret, err := k.HandleInternalEVMCall(ctx, req) require.Nil(t, err) contractAddr := crypto.CreateAddress(senderEvmAddr, 0) require.NotEmpty(t, k.GetCode(ctx, contractAddr)) @@ -163,7 +158,7 @@ func TestNegativeTransfer(t *testing.T) { require.Zero(t, preAttackerBal) require.Equal(t, steal_amount, preVictimBal) - _, err := k.HandleInternalEVMCall(ctx, req) + _, _, err := k.HandleInternalEVMCall(ctx, req) require.ErrorContains(t, err, "invalid coins") // post verification @@ -179,7 +174,7 @@ func TestNegativeTransfer(t *testing.T) { Value: &zeroVal, } - _, err = k.HandleInternalEVMCall(ctx, req2) + _, _, err = k.HandleInternalEVMCall(ctx, req2) require.ErrorContains(t, err, "max initcode size exceeded") } @@ -205,6 +200,6 @@ func TestHandleInternalEVMDelegateCall_AssociationError(t *testing.T) { FromContract: string(contractAddr.Bytes()), To: castedAddr.Hex(), } - _, err := k.HandleInternalEVMDelegateCall(ctx, req) + _, _, err := k.HandleInternalEVMDelegateCall(ctx, req) require.Equal(t, err.Error(), types.NewAssociationMissingErr(testAddr.String()).Error()) } diff --git a/x/evm/keeper/msg_server.go b/x/evm/keeper/msg_server.go index 1f3209af30..a8f6aee32e 100644 --- a/x/evm/keeper/msg_server.go +++ b/x/evm/keeper/msg_server.go @@ -203,6 +203,7 @@ func (k Keeper) applyEVMMessage(ctx sdk.Context, msg *core.Message, stateDB *sta cfg := types.DefaultChainConfig().EthereumConfig(k.ChainID(ctx)) txCtx := core.NewEVMTxContext(msg) evmInstance := vm.NewEVM(*blockCtx, txCtx, stateDB, cfg, vm.Config{}) + stateDB.SetEVM(evmInstance) st := core.NewStateTransition(evmInstance, msg, &gp, true) // fee already charged in ante handler res, err := st.TransitionDb() return res, err diff --git a/x/evm/state/state_test.go b/x/evm/state/state_test.go index 1b165d63ff..ad7c9b1c3d 100644 --- a/x/evm/state/state_test.go +++ b/x/evm/state/state_test.go @@ -7,6 +7,7 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/tracing" + "github.com/ethereum/go-ethereum/core/vm" testkeeper "github.com/sei-protocol/sei-chain/testutil/keeper" "github.com/sei-protocol/sei-chain/x/evm/state" "github.com/sei-protocol/sei-chain/x/evm/types" @@ -166,3 +167,16 @@ func TestSnapshot(t *testing.T) { require.Equal(t, common.Hash{}, newStateDB.GetTransientState(evmAddr, tkey)) require.Equal(t, val, newStateDB.GetState(evmAddr, key)) } + +func TestSetEVM(t *testing.T) { + k, ctx := testkeeper.MockEVMKeeper() + statedb := state.NewDBImpl(ctx, k, false) + rev1 := statedb.Snapshot() + statedb.SetEVM(&vm.EVM{}) + rev2 := statedb.Snapshot() + require.NotNil(t, types.GetCtxEVM(statedb.Ctx())) + statedb.RevertToSnapshot(rev2) + require.NotNil(t, types.GetCtxEVM(statedb.Ctx())) + statedb.RevertToSnapshot(rev1) + require.NotNil(t, types.GetCtxEVM(statedb.Ctx())) +} diff --git a/x/evm/state/statedb.go b/x/evm/state/statedb.go index 7cd888e821..205854af49 100644 --- a/x/evm/state/statedb.go +++ b/x/evm/state/statedb.go @@ -7,6 +7,7 @@ import ( ethtypes "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" "github.com/sei-protocol/sei-chain/utils" + "github.com/sei-protocol/sei-chain/x/evm/types" ) // Initialized for each transaction individually @@ -72,8 +73,10 @@ func (s *DBImpl) SetLogger(logger *tracing.Hooks) { s.logger = logger } -// for interface compliance -func (s *DBImpl) SetEVM(evm *vm.EVM) {} +func (s *DBImpl) SetEVM(evm *vm.EVM) { + s.ctx = types.SetCtxEVM(s.ctx, evm) + s.snapshottedCtxs = utils.Map(s.snapshottedCtxs, func(ctx sdk.Context) sdk.Context { return types.SetCtxEVM(ctx, evm) }) +} // AddPreimage records a SHA3 preimage seen by the VM. // AddPreimage performs a no-op since the EnablePreimageRecording flag is disabled @@ -105,6 +108,8 @@ func (s *DBImpl) Finalize() (surplus sdk.Int, err error) { for i := len(s.snapshottedCtxs) - 1; i > 0; i-- { s.flushCtx(s.snapshottedCtxs[i]) } + s.ctx = s.snapshottedCtxs[0] + s.snapshottedCtxs = []sdk.Context{} surplus = s.tempStateCurrent.surplus for _, ts := range s.tempStatesHist { diff --git a/x/evm/types/context.go b/x/evm/types/context.go new file mode 100644 index 0000000000..e3a30f7a74 --- /dev/null +++ b/x/evm/types/context.go @@ -0,0 +1,28 @@ +package types + +import ( + "context" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/ethereum/go-ethereum/core/vm" +) + +type CtxEVMKeyType string + +const CtxEVMKey = CtxEVMKeyType("evm") + +func SetCtxEVM(ctx sdk.Context, evm *vm.EVM) sdk.Context { + return ctx.WithContext(context.WithValue(ctx.Context(), CtxEVMKey, evm)) +} + +func GetCtxEVM(ctx sdk.Context) *vm.EVM { + rawVal := ctx.Context().Value(CtxEVMKey) + if rawVal == nil { + return nil + } + evm, ok := rawVal.(*vm.EVM) + if !ok { + return nil + } + return evm +} diff --git a/x/evm/types/context_test.go b/x/evm/types/context_test.go new file mode 100644 index 0000000000..40e6244cea --- /dev/null +++ b/x/evm/types/context_test.go @@ -0,0 +1,20 @@ +package types_test + +import ( + "context" + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/ethereum/go-ethereum/core/vm" + "github.com/sei-protocol/sei-chain/x/evm/types" + "github.com/stretchr/testify/require" +) + +func TestCtxEvm(t *testing.T) { + ctx := sdk.Context{}.WithContext(context.Background()) + require.Nil(t, types.GetCtxEVM(ctx)) + ctx = types.SetCtxEVM(ctx, &vm.EVM{}) + require.NotNil(t, types.GetCtxEVM(ctx)) + ctx = ctx.WithContext(context.WithValue(ctx.Context(), types.CtxEVMKey, 123)) + require.Nil(t, types.GetCtxEVM(ctx)) +}