From d9d48d8d2db23a1b221a78d2f94ee39a2f277db0 Mon Sep 17 00:00:00 2001 From: beer-1 Date: Thu, 12 Dec 2024 17:47:24 +0900 Subject: [PATCH] refactor with bugfix to keep account number on shorthand account creation --- x/evm/keeper/address.go | 10 ++- x/evm/keeper/address_test.go | 36 ++++++++++ x/evm/keeper/msg_server.go | 130 ++++++++++++++++------------------- 3 files changed, 105 insertions(+), 71 deletions(-) diff --git a/x/evm/keeper/address.go b/x/evm/keeper/address.go index 288583c0..c0f98c4c 100644 --- a/x/evm/keeper/address.go +++ b/x/evm/keeper/address.go @@ -18,6 +18,7 @@ func (k Keeper) convertToEVMAddress(ctx context.Context, addr sdk.AccAddress, is return common.BytesToAddress(addr.Bytes()), nil } + accountNumber := uint64(0) shorthandAddr := common.BytesToAddress(addr.Bytes()) if found := k.accountKeeper.HasAccount(ctx, shorthandAddr.Bytes()); found { account := k.accountKeeper.GetAccount(ctx, shorthandAddr.Bytes()) @@ -36,9 +37,16 @@ func (k Keeper) convertToEVMAddress(ctx context.Context, addr sdk.AccAddress, is return common.Address{}, types.ErrAddressAlreadyExists.Wrapf("failed to create shorthand account of `%s`: `%s`", addr, shorthandAddr) } + + accountNumber = account.GetAccountNumber() } if isSigner { + // if account number is not set, get next account number + if accountNumber == 0 { + accountNumber = k.accountKeeper.NextAccountNumber(ctx) + } + // create shorthand account shorthandAccount, err := types.NewShorthandAccountWithAddress(k.ac, addr) if err != nil { @@ -46,7 +54,7 @@ func (k Keeper) convertToEVMAddress(ctx context.Context, addr sdk.AccAddress, is } // register shorthand account - shorthandAccount.AccountNumber = k.accountKeeper.NextAccountNumber(ctx) + shorthandAccount.AccountNumber = accountNumber k.accountKeeper.SetAccount(ctx, shorthandAccount) } diff --git a/x/evm/keeper/address_test.go b/x/evm/keeper/address_test.go index 99eb5f99..6d94e06d 100644 --- a/x/evm/keeper/address_test.go +++ b/x/evm/keeper/address_test.go @@ -73,3 +73,39 @@ func Test_AllowLongCosmosAddress(t *testing.T) { )) require.ErrorContains(t, err, types.ErrAddressAlreadyExists.Error()) } + +func Test_AllowLongCosmosAddress_ConvertEmptyAccount(t *testing.T) { + ctx, input := createDefaultTestInput(t) + _, _, addr := keyPubAddr() + _, _, addr2 := keyPubAddr() + evmAddr := common.BytesToAddress(addr.Bytes()) + evmAddr2 := common.BytesToAddress(addr2.Bytes()) + + addr3 := append([]byte{0}, addr2.Bytes()...) + + erc20Keeper, err := keeper.NewERC20Keeper(&input.EVMKeeper) + require.NoError(t, err) + + // deploy erc20 contract + fooContractAddr := deployERC20(t, ctx, input, evmAddr, "foo") + fooDenom, err := types.ContractAddrToDenom(ctx, &input.EVMKeeper, fooContractAddr) + require.NoError(t, err) + require.Equal(t, "evm/"+fooContractAddr.Hex()[2:], fooDenom) + + // mint erc20 + mintERC20(t, ctx, input, evmAddr, evmAddr, sdk.NewCoin(fooDenom, math.NewInt(100)), false) + + // create empty account + mintERC20(t, ctx, input, evmAddr, evmAddr2, sdk.NewCoin(fooDenom, math.NewInt(100)), false) + expectedAccNum := input.AccountKeeper.GetAccount(ctx, addr2).GetAccountNumber() + + // take the address ownership + err = erc20Keeper.SendCoins(ctx, addr3, addr, sdk.NewCoins( + sdk.NewCoin(fooDenom, math.NewInt(50)), + )) + require.NoError(t, err) + + // account number should be the same + accNum := input.AccountKeeper.GetAccount(ctx, addr2).GetAccountNumber() + require.Equal(t, expectedAccNum, accNum) +} diff --git a/x/evm/keeper/msg_server.go b/x/evm/keeper/msg_server.go index e0dd20b8..67bbcbff 100644 --- a/x/evm/keeper/msg_server.go +++ b/x/evm/keeper/msg_server.go @@ -4,11 +4,14 @@ import ( "context" "errors" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/core/tracing" + coretypes "github.com/ethereum/go-ethereum/core/types" "github.com/holiman/uint256" "cosmossdk.io/collections" + "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" govtypes "github.com/cosmos/cosmos-sdk/x/gov/types" @@ -39,44 +42,17 @@ func (ms *msgServerImpl) Create(ctx context.Context, msg *types.MsgCreate) (*typ } // argument validation - caller, err := ms.convertToEVMAddress(ctx, sender, true) + caller, codeBz, value, accessList, err := ms.validateArguments(ctx, sender, msg.Code, msg.Value, msg.AccessList, true) if err != nil { return nil, err } - if len(msg.Code) == 0 { - return nil, sdkerrors.ErrInvalidRequest.Wrap("empty code bytes") - } - codeBz, err := hexutil.Decode(msg.Code) - if err != nil { - return nil, types.ErrInvalidHexString.Wrap(err.Error()) - } - value, overflow := uint256.FromBig(msg.Value.BigInt()) - if overflow { - return nil, types.ErrInvalidValue.Wrap("value is out of range") - } - accessList := types.ConvertCosmosAccessListToEth(msg.AccessList) + // check the sender is allowed publisher - params, err := ms.Params.Get(ctx) + err = ms.assertAllowedPublishers(ctx, msg.Sender) if err != nil { return nil, err } - // assert deploy authorization - if len(params.AllowedPublishers) != 0 { - allowed := false - for _, publisher := range params.AllowedPublishers { - if msg.Sender == publisher { - allowed = true - - break - } - } - - if !allowed { - return nil, sdkerrors.ErrUnauthorized.Wrapf("`%s` is not allowed to deploy a contract", msg.Sender) - } - } - // deploy a contract retBz, contractAddr, logs, err := ms.EVMCreate(ctx, caller, codeBz, value, accessList) if err != nil { @@ -104,44 +80,17 @@ func (ms *msgServerImpl) Create2(ctx context.Context, msg *types.MsgCreate2) (*t } // argument validation - caller, err := ms.convertToEVMAddress(ctx, sender, true) + caller, codeBz, value, accessList, err := ms.validateArguments(ctx, sender, msg.Code, msg.Value, msg.AccessList, true) if err != nil { return nil, err } - if len(msg.Code) == 0 { - return nil, sdkerrors.ErrInvalidRequest.Wrap("empty code bytes") - } - codeBz, err := hexutil.Decode(msg.Code) - if err != nil { - return nil, types.ErrInvalidHexString.Wrap(err.Error()) - } - value, overflow := uint256.FromBig(msg.Value.BigInt()) - if overflow { - return nil, types.ErrInvalidValue.Wrap("value is out of range") - } - accessList := types.ConvertCosmosAccessListToEth(msg.AccessList) + // check the sender is allowed publisher - params, err := ms.Params.Get(ctx) + err = ms.assertAllowedPublishers(ctx, msg.Sender) if err != nil { return nil, err } - // assert deploy authorization - if len(params.AllowedPublishers) != 0 { - allowed := false - for _, publisher := range params.AllowedPublishers { - if msg.Sender == publisher { - allowed = true - - break - } - } - - if !allowed { - return nil, sdkerrors.ErrUnauthorized.Wrapf("`%s` is not allowed to deploy a contract", msg.Sender) - } - } - // deploy a contract retBz, contractAddr, logs, err := ms.EVMCreate2(ctx, caller, codeBz, value, msg.Salt, accessList) if err != nil { @@ -174,19 +123,10 @@ func (ms *msgServerImpl) Call(ctx context.Context, msg *types.MsgCall) (*types.M } // argument validation - caller, err := ms.convertToEVMAddress(ctx, sender, true) + caller, inputBz, value, accessList, err := ms.validateArguments(ctx, sender, msg.Input, msg.Value, msg.AccessList, false) if err != nil { return nil, err } - inputBz, err := hexutil.Decode(msg.Input) - if err != nil { - return nil, types.ErrInvalidHexString.Wrap(err.Error()) - } - value, overflow := uint256.FromBig(msg.Value.BigInt()) - if overflow { - return nil, types.ErrInvalidValue.Wrap("value is out of range") - } - accessList := types.ConvertCosmosAccessListToEth(msg.AccessList) retBz, logs, err := ms.EVMCall(ctx, caller, contractAddr, inputBz, value, accessList) if err != nil { @@ -291,3 +231,53 @@ func (k *msgServerImpl) handleSequenceIncremented(ctx context.Context, sender sd return nil } + +// validateArguments validates the arguments of create, create2, and call messages. +func (ms *msgServerImpl) validateArguments( + ctx context.Context, sender []byte, data string, + value math.Int, accessList []types.AccessTuple, isCreate bool, +) (common.Address, []byte, *uint256.Int, coretypes.AccessList, error) { + caller, err := ms.convertToEVMAddress(ctx, sender, true) + if err != nil { + return common.Address{}, nil, nil, nil, err + } + if isCreate && len(data) == 0 { + return common.Address{}, nil, nil, nil, sdkerrors.ErrInvalidRequest.Wrap("empty code bytes") + } + dataBz, err := hexutil.Decode(data) + if err != nil { + return common.Address{}, nil, nil, nil, types.ErrInvalidHexString.Wrap(err.Error()) + } + val, overflow := uint256.FromBig(value.BigInt()) + if overflow { + return common.Address{}, nil, nil, nil, types.ErrInvalidValue.Wrap("value is out of range") + } + + return caller, dataBz, val, types.ConvertCosmosAccessListToEth(accessList), nil +} + +// assertAllowedPublishers asserts the sender is allowed to deploy a contract. +func (ms *msgServerImpl) assertAllowedPublishers(ctx context.Context, sender string) error { + params, err := ms.Params.Get(ctx) + if err != nil { + return err + } + + // assert deploy authorization + if len(params.AllowedPublishers) != 0 { + allowed := false + for _, publisher := range params.AllowedPublishers { + if sender == publisher { + allowed = true + + break + } + } + + if !allowed { + return sdkerrors.ErrUnauthorized.Wrapf("`%s` is not allowed to deploy a contract", sender) + } + } + + return nil +}