From fbe6cfc9c8b21051b74507b8746e6bcf05ef0a3a Mon Sep 17 00:00:00 2001 From: beer-1 <147697694+beer-1@users.noreply.github.com> Date: Thu, 11 Apr 2024 11:09:29 +0900 Subject: [PATCH] fix address collision checking --- x/evm/keeper/address.go | 54 ++++++++++++++++-------------------- x/evm/keeper/address_test.go | 22 ++++++++++++--- x/evm/keeper/erc20.go | 14 +++++----- x/evm/keeper/msg_server.go | 6 ++-- 4 files changed, 52 insertions(+), 44 deletions(-) diff --git a/x/evm/keeper/address.go b/x/evm/keeper/address.go index 141dbb8..288583c 100644 --- a/x/evm/keeper/address.go +++ b/x/evm/keeper/address.go @@ -9,9 +9,11 @@ import ( ) // convertToEVMAddress converts a cosmos address to an EVM address -// check if the shorthand has been registered and if so return error -// else register the shorthand address as an account. -func (k Keeper) convertToEVMAddress(ctx context.Context, addr sdk.AccAddress) (common.Address, error) { +// check if the shorthand has been registered and if so, check the +// registered account's origin address is same with given address. +// +// Also we create shorthand account if the address is not registered yet and is a signer. +func (k Keeper) convertToEVMAddress(ctx context.Context, addr sdk.AccAddress, isSigner bool) (common.Address, error) { if len(addr) == common.AddressLength { return common.BytesToAddress(addr.Bytes()), nil } @@ -21,40 +23,32 @@ func (k Keeper) convertToEVMAddress(ctx context.Context, addr sdk.AccAddress) (c account := k.accountKeeper.GetAccount(ctx, shorthandAddr.Bytes()) // if the account is empty account, convert it to shorthand account - if types.IsEmptyAccount(account) { - shorthandAccount, err := types.NewShorthandAccountWithAddress(k.ac, addr) - if err != nil { - return common.Address{}, err + if !types.IsEmptyAccount(account) { + + // check if the account is shorthand account, and if so, check if the original address is the same + if shorthandAccount, isShorthandAccount := account.(types.ShorthandAccountI); isShorthandAccount { + if originAddr, err := shorthandAccount.GetOriginalAddress(k.ac); err != nil { + return common.Address{}, err + } else if originAddr.Equals(addr) { + return shorthandAddr, nil + } } - shorthandAccount.AccountNumber = account.GetAccountNumber() - k.accountKeeper.SetAccount(ctx, shorthandAccount) - - return shorthandAddr, nil + return common.Address{}, types.ErrAddressAlreadyExists.Wrapf("failed to create shorthand account of `%s`: `%s`", addr, shorthandAddr) } + } - // check if the account is shorthand account, and if so, check if the original address is the same - shorthandAccount, isShorthandAccount := account.(types.ShorthandAccountI) - if isShorthandAccount { - if originAddr, err := shorthandAccount.GetOriginalAddress(k.ac); err != nil { - return common.Address{}, err - } else if originAddr.Equals(addr) { - return shorthandAddr, nil - } + if isSigner { + // create shorthand account + shorthandAccount, err := types.NewShorthandAccountWithAddress(k.ac, addr) + if err != nil { + return common.Address{}, err } - return common.Address{}, types.ErrAddressAlreadyExists.Wrapf("failed to create shorthand account of `%s`: `%s`", addr, shorthandAddr) + // register shorthand account + shorthandAccount.AccountNumber = k.accountKeeper.NextAccountNumber(ctx) + k.accountKeeper.SetAccount(ctx, shorthandAccount) } - // create shorthand account - shorthandAccount, err := types.NewShorthandAccountWithAddress(k.ac, addr) - if err != nil { - return common.Address{}, err - } - - // register shorthand account - shorthandAccount.AccountNumber = k.accountKeeper.NextAccountNumber(ctx) - k.accountKeeper.SetAccount(ctx, shorthandAccount) - return shorthandAddr, nil } diff --git a/x/evm/keeper/address_test.go b/x/evm/keeper/address_test.go index b1f4095..03d6cfc 100644 --- a/x/evm/keeper/address_test.go +++ b/x/evm/keeper/address_test.go @@ -45,17 +45,31 @@ func Test_AllowLongCosmosAddress(t *testing.T) { )) require.NoError(t, err) - // should be failed with already existing address + // should be be allowed because the address is not taken yet err = erc20Keeper.SendCoins(ctx, addr, addr4, sdk.NewCoins( sdk.NewCoin("bar", math.NewInt(100)), sdk.NewCoin(fooDenom, math.NewInt(50)), )) - require.ErrorContains(t, err, types.ErrAddressAlreadyExists.Error()) + require.NoError(t, err) - // but still can send to the same address - err = erc20Keeper.SendCoins(ctx, addr, addr3, sdk.NewCoins( + // take the address ownership + err = erc20Keeper.SendCoins(ctx, addr3, addr, sdk.NewCoins( sdk.NewCoin("bar", math.NewInt(100)), sdk.NewCoin(fooDenom, math.NewInt(50)), )) require.NoError(t, err) + + // then other account can't use this address + err = erc20Keeper.SendCoins(ctx, addr4, addr, sdk.NewCoins( + sdk.NewCoin("bar", math.NewInt(100)), + sdk.NewCoin(fooDenom, math.NewInt(50)), + )) + require.ErrorContains(t, err, types.ErrAddressAlreadyExists.Error()) + + // also can't use the address as a receive + err = erc20Keeper.SendCoins(ctx, addr, addr4, sdk.NewCoins( + sdk.NewCoin("bar", math.NewInt(100)), + sdk.NewCoin(fooDenom, math.NewInt(50)), + )) + require.ErrorContains(t, err, types.ErrAddressAlreadyExists.Error()) } diff --git a/x/evm/keeper/erc20.go b/x/evm/keeper/erc20.go index c7d44e9..f9ae3d7 100644 --- a/x/evm/keeper/erc20.go +++ b/x/evm/keeper/erc20.go @@ -42,7 +42,7 @@ func NewERC20Keeper(k *Keeper) (types.IERC20Keeper, error) { // BurnCoins implements IERC20Keeper. func (k ERC20Keeper) BurnCoins(ctx context.Context, addr sdk.AccAddress, amount sdk.Coins) error { - evmAddr, err := k.convertToEVMAddress(ctx, addr) + evmAddr, err := k.convertToEVMAddress(ctx, addr, false) if err != nil { return err } @@ -81,7 +81,7 @@ func (k ERC20Keeper) BurnCoins(ctx context.Context, addr sdk.AccAddress, amount // GetBalance implements IERC20Keeper. func (k ERC20Keeper) GetBalance(ctx context.Context, addr sdk.AccAddress, denom string) (math.Int, error) { - evmAddr, err := k.convertToEVMAddress(ctx, addr) + evmAddr, err := k.convertToEVMAddress(ctx, addr, false) if err != nil { return math.ZeroInt(), err } @@ -156,7 +156,7 @@ func (k ERC20Keeper) GetMetadata(ctx context.Context, denom string) (banktypes.M // GetPaginatedBalances implements IERC20Keeper. func (k ERC20Keeper) GetPaginatedBalances(ctx context.Context, pageReq *query.PageRequest, addr sdk.AccAddress) (sdk.Coins, *query.PageResponse, error) { - evmAddr, err := k.convertToEVMAddress(ctx, addr) + evmAddr, err := k.convertToEVMAddress(ctx, addr, false) if err != nil { return nil, nil, err } @@ -222,7 +222,7 @@ func (k ERC20Keeper) HasSupply(ctx context.Context, denom string) (bool, error) // IterateAccountBalances implements IERC20Keeper. func (k ERC20Keeper) IterateAccountBalances(ctx context.Context, addr sdk.AccAddress, cb func(sdk.Coin) (bool, error)) error { - evmAddr, err := k.convertToEVMAddress(ctx, addr) + evmAddr, err := k.convertToEVMAddress(ctx, addr, false) if err != nil { return err } @@ -267,7 +267,7 @@ func (k ERC20Keeper) IterateSupply(ctx context.Context, cb func(supply sdk.Coin) // MintCoins implements IERC20Keeper. func (k ERC20Keeper) MintCoins(ctx context.Context, addr sdk.AccAddress, amount sdk.Coins) error { sdkCtx := sdk.UnwrapSDKContext(ctx) - evmAddr, err := k.convertToEVMAddress(ctx, addr) + evmAddr, err := k.convertToEVMAddress(ctx, addr, false) if err != nil { return err } @@ -338,11 +338,11 @@ func (k ERC20Keeper) MintCoins(ctx context.Context, addr sdk.AccAddress, amount // SendCoins implements IERC20Keeper. func (k ERC20Keeper) SendCoins(ctx context.Context, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) error { - evmFromAddr, err := k.convertToEVMAddress(ctx, fromAddr) + evmFromAddr, err := k.convertToEVMAddress(ctx, fromAddr, true) if err != nil { return err } - evmToAddr, err := k.convertToEVMAddress(ctx, toAddr) + evmToAddr, err := k.convertToEVMAddress(ctx, toAddr, false) if err != nil { return err } diff --git a/x/evm/keeper/msg_server.go b/x/evm/keeper/msg_server.go index 8bebd2d..fcd70b8 100644 --- a/x/evm/keeper/msg_server.go +++ b/x/evm/keeper/msg_server.go @@ -27,7 +27,7 @@ func (ms *msgServerImpl) Create(ctx context.Context, msg *types.MsgCreate) (*typ } // argument validation - caller, err := ms.convertToEVMAddress(ctx, sender) + caller, err := ms.convertToEVMAddress(ctx, sender, true) if err != nil { return nil, err } @@ -81,7 +81,7 @@ func (ms *msgServerImpl) Create2(ctx context.Context, msg *types.MsgCreate2) (*t } // argument validation - caller, err := ms.convertToEVMAddress(ctx, sender) + caller, err := ms.convertToEVMAddress(ctx, sender, true) if err != nil { return nil, err } @@ -140,7 +140,7 @@ func (ms *msgServerImpl) Call(ctx context.Context, msg *types.MsgCall) (*types.M } // argument validation - caller, err := ms.convertToEVMAddress(ctx, sender) + caller, err := ms.convertToEVMAddress(ctx, sender, true) if err != nil { return nil, err }