Skip to content

Commit

Permalink
feat: approve before hook execution (#56)
Browse files Browse the repository at this point in the history
* approve before hook execution

* remove comment
  • Loading branch information
beer-1 authored Sep 2, 2024
1 parent e193872 commit 0823a2d
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 28 deletions.
1 change: 0 additions & 1 deletion app/ibc-hooks/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,6 @@ func _createTestInput(
faucet := NewTestFaucet(t, ctx, bankKeeper, authtypes.Minter, initialTotalSupply()...)

// ibc middleware setup

mockIBCMiddleware := mockIBCMiddleware{}
evmHooks := evmhooks.NewEVMHooks(appCodec, ac, evmKeeper)

Expand Down
78 changes: 74 additions & 4 deletions app/ibc-hooks/receive.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package evm_hooks

import (
"fmt"
"math/big"

sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"

transfertypes "github.com/cosmos/ibc-go/v8/modules/apps/transfer/types"
channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types"
Expand Down Expand Up @@ -43,7 +46,7 @@ func (h EVMHooks) onRecvIcs20Packet(
}

// Calculate the receiver / contract caller based on the packet's channel and sender
intermediateSender := deriveIntermediateSender(packet.GetDestChannel(), data.GetSender())
intermediateSender := DeriveIntermediateSender(packet.GetDestChannel(), data.GetSender())

// The funds sent on this packet need to be transferred to the intermediary account for the sender.
// For this, we override the ICS20 packet's Receiver (essentially hijacking the funds to this new address)
Expand All @@ -60,6 +63,11 @@ func (h EVMHooks) onRecvIcs20Packet(
}

msg.Sender = intermediateSender
localDenom := LocalDenom(packet, data.Denom)
_, err = h.approveERC20(ctx, intermediateSender, common.HexToAddress(msg.ContractAddr), localDenom, data.Amount)
if err != nil {
return newEmitErrorAcknowledgement(err)
}
_, err = h.execMsg(ctx, msg)
if err != nil {
return newEmitErrorAcknowledgement(err)
Expand All @@ -68,6 +76,33 @@ func (h EVMHooks) onRecvIcs20Packet(
return ack
}

func (h EVMHooks) approveERC20(ctx sdk.Context, intermediateSender string, contractAddr common.Address, denom, amount string) (*evmtypes.MsgCallResponse, error) {
amt, ok := new(big.Int).SetString(amount, 10)
if !ok {
return nil, fmt.Errorf("failed to parse amount %s", amount)
}

erc20ABI := h.evmKeeper.ERC20Keeper().GetERC20ABI()
inputBz, err := erc20ABI.Pack("approve", contractAddr, amt)
if err != nil {
return nil, err
}

erc20Addr, err := h.evmKeeper.GetContractAddrByDenom(ctx, denom)
if err != nil {
return nil, err
}

msg := &evmtypes.MsgCall{
Sender: intermediateSender,
ContractAddr: erc20Addr.Hex(),
Input: hexutil.Encode(inputBz),
}

evmMsgServer := evmkeeper.NewMsgServerImpl(h.evmKeeper)
return evmMsgServer.Call(ctx, msg)
}

func (h EVMHooks) onRecvIcs721Packet(
ctx sdk.Context,
im ibchooks.IBCMiddleware,
Expand Down Expand Up @@ -95,7 +130,7 @@ func (h EVMHooks) onRecvIcs721Packet(
}

// Calculate the receiver / contract caller based on the packet's channel and sender
intermediateSender := deriveIntermediateSender(packet.GetDestChannel(), data.GetSender())
intermediateSender := DeriveIntermediateSender(packet.GetDestChannel(), data.GetSender())

// The funds sent on this packet need to be transferred to the intermediary account for the sender.
// For this, we override the ICS721 packet's Receiver (essentially hijacking the funds to this new address)
Expand All @@ -111,7 +146,15 @@ func (h EVMHooks) onRecvIcs721Packet(
return ack
}

// approve the transfer of the NFT to the contract
msg.Sender = intermediateSender
localClassId := LocalClassId(packet, data.ClassId)
for _, tokenId := range data.TokenIds {
_, err = h.approveERC721(ctx, intermediateSender, common.HexToAddress(msg.ContractAddr), localClassId, tokenId)
if err != nil {
return newEmitErrorAcknowledgement(err)
}
}
_, err = h.execMsg(ctx, msg)
if err != nil {
return newEmitErrorAcknowledgement(err)
Expand All @@ -120,7 +163,34 @@ func (h EVMHooks) onRecvIcs721Packet(
return ack
}

func (im EVMHooks) execMsg(ctx sdk.Context, msg *evmtypes.MsgCall) (*evmtypes.MsgCallResponse, error) {
evmMsgServer := evmkeeper.NewMsgServerImpl(im.evmKeeper)
func (h EVMHooks) execMsg(ctx sdk.Context, msg *evmtypes.MsgCall) (*evmtypes.MsgCallResponse, error) {
evmMsgServer := evmkeeper.NewMsgServerImpl(h.evmKeeper)
return evmMsgServer.Call(ctx, msg)
}

func (h EVMHooks) approveERC721(ctx sdk.Context, intermediateSender string, contractAddr common.Address, classId, tokenId string) (*evmtypes.MsgCallResponse, error) {
tid, ok := evmtypes.TokenIdToBigInt(classId, tokenId)
if !ok {
return nil, evmtypes.ErrInvalidTokenId
}

erc721ABI := h.evmKeeper.ERC721Keeper().GetERC721ABI()
inputBz, err := erc721ABI.Pack("approve", contractAddr, tid)
if err != nil {
return nil, err
}

erc721Addr, err := h.evmKeeper.GetContractAddrByClassId(ctx, classId)
if err != nil {
return nil, err
}

msg := &evmtypes.MsgCall{
Sender: intermediateSender,
ContractAddr: erc721Addr.Hex(),
Input: hexutil.Encode(inputBz),
}

evmMsgServer := evmkeeper.NewMsgServerImpl(h.evmKeeper)
return evmMsgServer.Call(ctx, msg)
}
117 changes: 95 additions & 22 deletions app/ibc-hooks/receive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@ import (
"github.com/holiman/uint256"
"github.com/stretchr/testify/require"

sdk "github.com/cosmos/cosmos-sdk/types"
transfertypes "github.com/cosmos/ibc-go/v8/modules/apps/transfer/types"
channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types"

nfttransfertypes "github.com/initia-labs/initia/x/ibc/nft-transfer/types"
evm_hooks "github.com/initia-labs/minievm/app/ibc-hooks"
"github.com/initia-labs/minievm/x/evm/contracts/counter"
evmtypes "github.com/initia-labs/minievm/x/evm/types"
)

func Test_onReceiveIcs20Packet_noMemo(t *testing.T) {
Expand Down Expand Up @@ -76,29 +79,44 @@ func Test_onReceiveIcs20Packet_memo(t *testing.T) {
dataBz, err := json.Marshal(&data)
require.NoError(t, err)

pk := channeltypes.Packet{
Data: dataBz,
DestinationPort: "transfer-1",
DestinationChannel: "channel-1",
}

// mint for approval test
localDenom := evm_hooks.LocalDenom(pk, data.Denom)
intermediateSender := sdk.MustAccAddressFromBech32(evm_hooks.DeriveIntermediateSender(pk.DestinationChannel, data.Sender))
input.Faucet.Fund(ctx, intermediateSender, sdk.NewInt64Coin(localDenom, 1000000000))

// failed to due to acl
ack := input.IBCHooksMiddleware.OnRecvPacket(ctx, channeltypes.Packet{
Data: dataBz,
}, addr)
ack := input.IBCHooksMiddleware.OnRecvPacket(ctx, pk, addr)
require.False(t, ack.Success())

// set acl
require.NoError(t, input.IBCHooksKeeper.SetAllowed(ctx, contractAddr[:], true))

// success
ack = input.IBCHooksMiddleware.OnRecvPacket(ctx, channeltypes.Packet{
Data: dataBz,
}, addr)
ack = input.IBCHooksMiddleware.OnRecvPacket(ctx, pk, addr)
require.True(t, ack.Success())

queryInputBz, err := abi.Pack("count")
require.NoError(t, err)

// check the contract state
queryRes, logs, err := input.EVMKeeper.EVMCall(ctx, evmAddr, contractAddr, queryInputBz, nil)
queryRes, err := input.EVMKeeper.EVMStaticCall(ctx, evmAddr, contractAddr, queryInputBz)
require.NoError(t, err)
require.Equal(t, uint256.NewInt(1).Bytes32(), [32]byte(queryRes))
require.Empty(t, logs)

// check allowance
erc20Addr, err := input.EVMKeeper.GetContractAddrByDenom(ctx, localDenom)
require.NoError(t, err)
queryInputBz, err = input.EVMKeeper.ERC20Keeper().GetERC20ABI().Pack("allowance", common.BytesToAddress(intermediateSender.Bytes()), contractAddr)
require.NoError(t, err)
queryRes, err = input.EVMKeeper.EVMStaticCall(ctx, evmtypes.StdAddress, erc20Addr, queryInputBz)
require.NoError(t, err)
require.Equal(t, uint256.NewInt(10000).Bytes32(), [32]byte(queryRes))
}

func Test_OnReceivePacket_ICS721(t *testing.T) {
Expand Down Expand Up @@ -168,19 +186,36 @@ func Test_onReceivePacket_memo_ICS721(t *testing.T) {
dataBz, err := json.Marshal(&data)
require.NoError(t, err)

pk := channeltypes.Packet{
Data: dataBz,
DestinationPort: "nfttransfer-1",
DestinationChannel: "channel-1",
}

// mint for approval test
localClassId := evm_hooks.LocalClassId(pk, data.ClassId)
intermediateSender := sdk.MustAccAddressFromBech32(evm_hooks.DeriveIntermediateSender(pk.DestinationChannel, data.Sender))
err = input.EVMKeeper.ERC721Keeper().CreateOrUpdateClass(ctx, localClassId, data.ClassUri, data.ClassData)
require.NoError(t, err)
err = input.EVMKeeper.ERC721Keeper().Mints(
ctx,
intermediateSender,
localClassId,
[]string{"tokenId"},
[]string{"tokenUri"},
[]string{"tokenData"},
)
require.NoError(t, err)

// failed to due to acl
ack := input.IBCHooksMiddleware.OnRecvPacket(ctx, channeltypes.Packet{
Data: dataBz,
}, addr)
ack := input.IBCHooksMiddleware.OnRecvPacket(ctx, pk, addr)
require.False(t, ack.Success())

// set acl
require.NoError(t, input.IBCHooksKeeper.SetAllowed(ctx, contractAddr[:], true))

// success
ack = input.IBCHooksMiddleware.OnRecvPacket(ctx, channeltypes.Packet{
Data: dataBz,
}, addr)
ack = input.IBCHooksMiddleware.OnRecvPacket(ctx, pk, addr)
require.True(t, ack.Success())

queryInputBz, err := abi.Pack("count")
Expand All @@ -191,6 +226,17 @@ func Test_onReceivePacket_memo_ICS721(t *testing.T) {
require.NoError(t, err)
require.Equal(t, uint256.NewInt(1).Bytes32(), [32]byte(queryRes))
require.Empty(t, logs)

// check allowance
tokenId, ok := evmtypes.TokenIdToBigInt(localClassId, data.TokenIds[0])
require.True(t, ok)
erc721Addr, err := input.EVMKeeper.GetContractAddrByClassId(ctx, localClassId)
require.NoError(t, err)
queryInputBz, err = input.EVMKeeper.ERC721Keeper().GetERC721ABI().Pack("getApproved", tokenId)
require.NoError(t, err)
queryRes, err = input.EVMKeeper.EVMStaticCall(ctx, evmtypes.StdAddress, erc721Addr, queryInputBz)
require.NoError(t, err)
require.Equal(t, contractAddr.Bytes(), common.HexToAddress(hexutil.Encode(queryRes)).Bytes())
}

func Test_onReceivePacket_memo_ICS721_Wasm(t *testing.T) {
Expand Down Expand Up @@ -233,21 +279,37 @@ func Test_onReceivePacket_memo_ICS721_Wasm(t *testing.T) {
dataBz, err := json.Marshal(&data)
require.NoError(t, err)

pk := channeltypes.Packet{
SourcePort: "wasm.contract_address",
Data: dataBz,
DestinationPort: "nfttransfer-1",
DestinationChannel: "channel-1",
}

// mint for approval test
localClassId := evm_hooks.LocalClassId(pk, data.ClassId)
intermediateSender := sdk.MustAccAddressFromBech32(evm_hooks.DeriveIntermediateSender(pk.DestinationChannel, data.Sender))
err = input.EVMKeeper.ERC721Keeper().CreateOrUpdateClass(ctx, localClassId, data.ClassUri, data.ClassData)
require.NoError(t, err)
err = input.EVMKeeper.ERC721Keeper().Mints(
ctx,
intermediateSender,
localClassId,
[]string{"tokenId"},
[]string{"tokenUri"},
[]string{"tokenData"},
)
require.NoError(t, err)

// failed to due to acl
ack := input.IBCHooksMiddleware.OnRecvPacket(ctx, channeltypes.Packet{
SourcePort: "wasm.contract_address",
Data: dataBz,
}, addr)
ack := input.IBCHooksMiddleware.OnRecvPacket(ctx, pk, addr)
require.False(t, ack.Success())

// set acl
require.NoError(t, input.IBCHooksKeeper.SetAllowed(ctx, contractAddr[:], true))

// success
ack = input.IBCHooksMiddleware.OnRecvPacket(ctx, channeltypes.Packet{
SourcePort: "wasm.contract_address",
Data: dataBz,
}, addr)
ack = input.IBCHooksMiddleware.OnRecvPacket(ctx, pk, addr)
require.True(t, ack.Success())

queryInputBz, err := abi.Pack("count")
Expand All @@ -258,4 +320,15 @@ func Test_onReceivePacket_memo_ICS721_Wasm(t *testing.T) {
require.NoError(t, err)
require.Equal(t, uint256.NewInt(1).Bytes32(), [32]byte(queryRes))
require.Empty(t, logs)

// check allowance
tokenId, ok := evmtypes.TokenIdToBigInt(localClassId, data.TokenIds[0])
require.True(t, ok)
erc721Addr, err := input.EVMKeeper.GetContractAddrByClassId(ctx, localClassId)
require.NoError(t, err)
queryInputBz, err = input.EVMKeeper.ERC721Keeper().GetERC721ABI().Pack("getApproved", tokenId)
require.NoError(t, err)
queryRes, err = input.EVMKeeper.EVMStaticCall(ctx, evmtypes.StdAddress, erc721Addr, queryInputBz)
require.NoError(t, err)
require.Equal(t, contractAddr.Bytes(), common.HexToAddress(hexutil.Encode(queryRes)).Bytes())
}
Loading

0 comments on commit 0823a2d

Please sign in to comment.