Skip to content

Commit

Permalink
fix address collision checking
Browse files Browse the repository at this point in the history
  • Loading branch information
beer-1 committed Apr 11, 2024
1 parent 37b533f commit fbe6cfc
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 44 deletions.
54 changes: 24 additions & 30 deletions x/evm/keeper/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ import (
)

// convertToEVMAddress converts a cosmos address to an EVM address
// check if the shorthand has been registered and if so return error
// else register the shorthand address as an account.
func (k Keeper) convertToEVMAddress(ctx context.Context, addr sdk.AccAddress) (common.Address, error) {
// check if the shorthand has been registered and if so, check the
// registered account's origin address is same with given address.
//
// Also we create shorthand account if the address is not registered yet and is a signer.
func (k Keeper) convertToEVMAddress(ctx context.Context, addr sdk.AccAddress, isSigner bool) (common.Address, error) {
if len(addr) == common.AddressLength {
return common.BytesToAddress(addr.Bytes()), nil
}
Expand All @@ -21,40 +23,32 @@ func (k Keeper) convertToEVMAddress(ctx context.Context, addr sdk.AccAddress) (c
account := k.accountKeeper.GetAccount(ctx, shorthandAddr.Bytes())

// if the account is empty account, convert it to shorthand account
if types.IsEmptyAccount(account) {
shorthandAccount, err := types.NewShorthandAccountWithAddress(k.ac, addr)
if err != nil {
return common.Address{}, err
if !types.IsEmptyAccount(account) {

// check if the account is shorthand account, and if so, check if the original address is the same
if shorthandAccount, isShorthandAccount := account.(types.ShorthandAccountI); isShorthandAccount {
if originAddr, err := shorthandAccount.GetOriginalAddress(k.ac); err != nil {
return common.Address{}, err
} else if originAddr.Equals(addr) {
return shorthandAddr, nil
}
}

shorthandAccount.AccountNumber = account.GetAccountNumber()
k.accountKeeper.SetAccount(ctx, shorthandAccount)

return shorthandAddr, nil
return common.Address{}, types.ErrAddressAlreadyExists.Wrapf("failed to create shorthand account of `%s`: `%s`", addr, shorthandAddr)
}
}

// check if the account is shorthand account, and if so, check if the original address is the same
shorthandAccount, isShorthandAccount := account.(types.ShorthandAccountI)
if isShorthandAccount {
if originAddr, err := shorthandAccount.GetOriginalAddress(k.ac); err != nil {
return common.Address{}, err
} else if originAddr.Equals(addr) {
return shorthandAddr, nil
}
if isSigner {
// create shorthand account
shorthandAccount, err := types.NewShorthandAccountWithAddress(k.ac, addr)
if err != nil {
return common.Address{}, err
}

return common.Address{}, types.ErrAddressAlreadyExists.Wrapf("failed to create shorthand account of `%s`: `%s`", addr, shorthandAddr)
// register shorthand account
shorthandAccount.AccountNumber = k.accountKeeper.NextAccountNumber(ctx)
k.accountKeeper.SetAccount(ctx, shorthandAccount)
}

// create shorthand account
shorthandAccount, err := types.NewShorthandAccountWithAddress(k.ac, addr)
if err != nil {
return common.Address{}, err
}

// register shorthand account
shorthandAccount.AccountNumber = k.accountKeeper.NextAccountNumber(ctx)
k.accountKeeper.SetAccount(ctx, shorthandAccount)

return shorthandAddr, nil
}
22 changes: 18 additions & 4 deletions x/evm/keeper/address_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,31 @@ func Test_AllowLongCosmosAddress(t *testing.T) {
))
require.NoError(t, err)

// should be failed with already existing address
// should be be allowed because the address is not taken yet
err = erc20Keeper.SendCoins(ctx, addr, addr4, sdk.NewCoins(
sdk.NewCoin("bar", math.NewInt(100)),
sdk.NewCoin(fooDenom, math.NewInt(50)),
))
require.ErrorContains(t, err, types.ErrAddressAlreadyExists.Error())
require.NoError(t, err)

// but still can send to the same address
err = erc20Keeper.SendCoins(ctx, addr, addr3, sdk.NewCoins(
// take the address ownership
err = erc20Keeper.SendCoins(ctx, addr3, addr, sdk.NewCoins(
sdk.NewCoin("bar", math.NewInt(100)),
sdk.NewCoin(fooDenom, math.NewInt(50)),
))
require.NoError(t, err)

// then other account can't use this address
err = erc20Keeper.SendCoins(ctx, addr4, addr, sdk.NewCoins(
sdk.NewCoin("bar", math.NewInt(100)),
sdk.NewCoin(fooDenom, math.NewInt(50)),
))
require.ErrorContains(t, err, types.ErrAddressAlreadyExists.Error())

// also can't use the address as a receive
err = erc20Keeper.SendCoins(ctx, addr, addr4, sdk.NewCoins(
sdk.NewCoin("bar", math.NewInt(100)),
sdk.NewCoin(fooDenom, math.NewInt(50)),
))
require.ErrorContains(t, err, types.ErrAddressAlreadyExists.Error())
}
14 changes: 7 additions & 7 deletions x/evm/keeper/erc20.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func NewERC20Keeper(k *Keeper) (types.IERC20Keeper, error) {

// BurnCoins implements IERC20Keeper.
func (k ERC20Keeper) BurnCoins(ctx context.Context, addr sdk.AccAddress, amount sdk.Coins) error {
evmAddr, err := k.convertToEVMAddress(ctx, addr)
evmAddr, err := k.convertToEVMAddress(ctx, addr, false)
if err != nil {
return err
}
Expand Down Expand Up @@ -81,7 +81,7 @@ func (k ERC20Keeper) BurnCoins(ctx context.Context, addr sdk.AccAddress, amount

// GetBalance implements IERC20Keeper.
func (k ERC20Keeper) GetBalance(ctx context.Context, addr sdk.AccAddress, denom string) (math.Int, error) {
evmAddr, err := k.convertToEVMAddress(ctx, addr)
evmAddr, err := k.convertToEVMAddress(ctx, addr, false)
if err != nil {
return math.ZeroInt(), err
}
Expand Down Expand Up @@ -156,7 +156,7 @@ func (k ERC20Keeper) GetMetadata(ctx context.Context, denom string) (banktypes.M

// GetPaginatedBalances implements IERC20Keeper.
func (k ERC20Keeper) GetPaginatedBalances(ctx context.Context, pageReq *query.PageRequest, addr sdk.AccAddress) (sdk.Coins, *query.PageResponse, error) {
evmAddr, err := k.convertToEVMAddress(ctx, addr)
evmAddr, err := k.convertToEVMAddress(ctx, addr, false)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -222,7 +222,7 @@ func (k ERC20Keeper) HasSupply(ctx context.Context, denom string) (bool, error)

// IterateAccountBalances implements IERC20Keeper.
func (k ERC20Keeper) IterateAccountBalances(ctx context.Context, addr sdk.AccAddress, cb func(sdk.Coin) (bool, error)) error {
evmAddr, err := k.convertToEVMAddress(ctx, addr)
evmAddr, err := k.convertToEVMAddress(ctx, addr, false)
if err != nil {
return err
}
Expand Down Expand Up @@ -267,7 +267,7 @@ func (k ERC20Keeper) IterateSupply(ctx context.Context, cb func(supply sdk.Coin)
// MintCoins implements IERC20Keeper.
func (k ERC20Keeper) MintCoins(ctx context.Context, addr sdk.AccAddress, amount sdk.Coins) error {
sdkCtx := sdk.UnwrapSDKContext(ctx)
evmAddr, err := k.convertToEVMAddress(ctx, addr)
evmAddr, err := k.convertToEVMAddress(ctx, addr, false)
if err != nil {
return err
}
Expand Down Expand Up @@ -338,11 +338,11 @@ func (k ERC20Keeper) MintCoins(ctx context.Context, addr sdk.AccAddress, amount

// SendCoins implements IERC20Keeper.
func (k ERC20Keeper) SendCoins(ctx context.Context, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) error {
evmFromAddr, err := k.convertToEVMAddress(ctx, fromAddr)
evmFromAddr, err := k.convertToEVMAddress(ctx, fromAddr, true)
if err != nil {
return err
}
evmToAddr, err := k.convertToEVMAddress(ctx, toAddr)
evmToAddr, err := k.convertToEVMAddress(ctx, toAddr, false)
if err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions x/evm/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (ms *msgServerImpl) Create(ctx context.Context, msg *types.MsgCreate) (*typ
}

// argument validation
caller, err := ms.convertToEVMAddress(ctx, sender)
caller, err := ms.convertToEVMAddress(ctx, sender, true)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -81,7 +81,7 @@ func (ms *msgServerImpl) Create2(ctx context.Context, msg *types.MsgCreate2) (*t
}

// argument validation
caller, err := ms.convertToEVMAddress(ctx, sender)
caller, err := ms.convertToEVMAddress(ctx, sender, true)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -140,7 +140,7 @@ func (ms *msgServerImpl) Call(ctx context.Context, msg *types.MsgCall) (*types.M
}

// argument validation
caller, err := ms.convertToEVMAddress(ctx, sender)
caller, err := ms.convertToEVMAddress(ctx, sender, true)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit fbe6cfc

Please sign in to comment.