diff --git a/app/app.go b/app/app.go index 0eacd4282c..79f901e33e 100644 --- a/app/app.go +++ b/app/app.go @@ -675,6 +675,9 @@ func New( app.DistrKeeper, app.OracleKeeper, app.TransferKeeper, + app.IBCKeeper.ClientKeeper, + app.IBCKeeper.ConnectionKeeper, + app.IBCKeeper.ChannelKeeper, ); err != nil { panic(err) } diff --git a/precompiles/common/expected_keepers.go b/precompiles/common/expected_keepers.go index 3b85964315..f175d6719b 100644 --- a/precompiles/common/expected_keepers.go +++ b/precompiles/common/expected_keepers.go @@ -3,6 +3,10 @@ package common import ( "context" + connectiontypes "github.com/cosmos/ibc-go/v3/modules/core/03-connection/types" + "github.com/cosmos/ibc-go/v3/modules/core/04-channel/types" + "github.com/cosmos/ibc-go/v3/modules/core/exported" + sdk "github.com/cosmos/cosmos-sdk/types" banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" govtypes "github.com/cosmos/cosmos-sdk/x/gov/types" @@ -79,3 +83,16 @@ type TransferKeeper interface { timeoutTimestamp uint64, ) error } + +type ClientKeeper interface { + GetClientState(ctx sdk.Context, clientID string) (exported.ClientState, bool) + GetClientConsensusState(ctx sdk.Context, clientID string, height exported.Height) (exported.ConsensusState, bool) +} + +type ConnectionKeeper interface { + GetConnection(ctx sdk.Context, connectionID string) (connectiontypes.ConnectionEnd, bool) +} + +type ChannelKeeper interface { + GetChannel(ctx sdk.Context, portID, channelID string) (types.Channel, bool) +} diff --git a/precompiles/ibc/IBC.sol b/precompiles/ibc/IBC.sol index 649d3a3b60..892652a8ba 100644 --- a/precompiles/ibc/IBC.sol +++ b/precompiles/ibc/IBC.sol @@ -19,4 +19,12 @@ interface IBC { uint64 revisionHeight, uint64 timeoutTimestamp ) external returns (bool success); + + function transferWithDefaultTimeout( + string toAddress, + string memory port, + string memory channel, + string memory denom, + uint256 amount + ) external returns (bool success); } diff --git a/precompiles/ibc/abi.json b/precompiles/ibc/abi.json index a1ff1e9fc2..93db335e43 100644 --- a/precompiles/ibc/abi.json +++ b/precompiles/ibc/abi.json @@ -50,7 +50,46 @@ "type": "bool" } ], - "stateMutability": "view", + "stateMutability": "payable", + "type": "function" + }, + { + "inputs": [ + { + "internalType": "string", + "name": "toAddress", + "type": "string" + }, + { + "internalType": "string", + "name": "port", + "type": "string" + }, + { + "internalType": "string", + "name": "channel", + "type": "string" + }, + { + "internalType": "string", + "name": "denom", + "type": "string" + }, + { + "internalType": "uint256", + "name": "amount", + "type": "uint256" + } + ], + "name": "transferWithDefaultTimeout", + "outputs": [ + { + "internalType": "bool", + "name": "success", + "type": "bool" + } + ], + "stateMutability": "payable", "type": "function" } ] diff --git a/precompiles/ibc/ibc.go b/precompiles/ibc/ibc.go index c6f1921d47..9d84be672d 100644 --- a/precompiles/ibc/ibc.go +++ b/precompiles/ibc/ibc.go @@ -7,12 +7,15 @@ import ( "fmt" "math/big" + "github.com/cosmos/ibc-go/v3/modules/apps/transfer/types" + "github.com/cosmos/cosmos-sdk/types/bech32" "github.com/sei-protocol/sei-chain/utils" sdk "github.com/cosmos/cosmos-sdk/types" clienttypes "github.com/cosmos/ibc-go/v3/modules/core/02-client/types" + connectiontypes "github.com/cosmos/ibc-go/v3/modules/core/03-connection/types" "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/tracing" @@ -21,7 +24,8 @@ import ( ) const ( - TransferMethod = "transfer" + TransferMethod = "transfer" + TransferWithDefaultTimeoutMethod = "transferWithDefaultTimeout" ) const ( @@ -51,27 +55,41 @@ func GetABI() abi.ABI { type Precompile struct { pcommon.Precompile - address common.Address - transferKeeper pcommon.TransferKeeper - evmKeeper pcommon.EVMKeeper - - TransferID []byte + address common.Address + transferKeeper pcommon.TransferKeeper + evmKeeper pcommon.EVMKeeper + clientKeeper pcommon.ClientKeeper + connectionKeeper pcommon.ConnectionKeeper + channelKeeper pcommon.ChannelKeeper + + TransferID []byte + TransferWithDefaultTimeoutID []byte } -func NewPrecompile(transferKeeper pcommon.TransferKeeper, evmKeeper pcommon.EVMKeeper) (*Precompile, error) { +func NewPrecompile( + transferKeeper pcommon.TransferKeeper, + evmKeeper pcommon.EVMKeeper, + clientKeeper pcommon.ClientKeeper, + connectionKeeper pcommon.ConnectionKeeper, + channelKeeper pcommon.ChannelKeeper) (*Precompile, error) { newAbi := GetABI() p := &Precompile{ - Precompile: pcommon.Precompile{ABI: newAbi}, - address: common.HexToAddress(IBCAddress), - transferKeeper: transferKeeper, - evmKeeper: evmKeeper, + Precompile: pcommon.Precompile{ABI: newAbi}, + address: common.HexToAddress(IBCAddress), + transferKeeper: transferKeeper, + evmKeeper: evmKeeper, + clientKeeper: clientKeeper, + connectionKeeper: connectionKeeper, + channelKeeper: channelKeeper, } for name, m := range newAbi.Methods { switch name { case TransferMethod: p.TransferID = m.ID + case TransferWithDefaultTimeoutMethod: + p.TransferWithDefaultTimeoutID = m.ID } } @@ -116,6 +134,8 @@ func (p Precompile) RunAndCalculateGas(evm *vm.EVM, caller common.Address, calli switch method.Name { case TransferMethod: return p.transfer(ctx, method, args, caller) + case TransferWithDefaultTimeoutMethod: + return p.transferWithDefaultTimeout(ctx, method, args, caller) } return } @@ -138,61 +158,87 @@ func (p Precompile) transfer(ctx sdk.Context, method *abi.Method, args []interfa rerr = err return } - senderSeiAddr, ok := p.evmKeeper.GetSeiAddress(ctx, caller) - if !ok { - rerr = errors.New("caller is not a valid SEI address") + validatedArgs, err := p.validateCommonArgs(ctx, args, caller) + if err != nil { + rerr = err return } - receiverAddressString, ok := args[0].(string) - if !ok { - rerr = errors.New("receiverAddress is not a string") + if validatedArgs.amount.Cmp(big.NewInt(0)) == 0 { + // short circuit + remainingGas = pcommon.GetRemainingGas(ctx, p.evmKeeper) + ret, rerr = method.Outputs.Pack(true) return } - _, bz, err := bech32.DecodeAndConvert(receiverAddressString) - if err != nil { - rerr = err - return + + coin := sdk.Coin{ + Denom: validatedArgs.denom, + Amount: sdk.NewIntFromBigInt(validatedArgs.amount), } - err = sdk.VerifyAddressFormat(bz) - if err != nil { - rerr = err + + revisionNumber, ok := args[5].(uint64) + if !ok { + rerr = errors.New("revisionNumber is not a uint64") return } - port, ok := args[1].(string) + revisionHeight, ok := args[6].(uint64) if !ok { - rerr = errors.New("port is not a string") + rerr = errors.New("revisionHeight is not a uint64") return } - if port == "" { - rerr = errors.New("port cannot be empty") - return + + height := clienttypes.Height{ + RevisionNumber: revisionNumber, + RevisionHeight: revisionHeight, } - channelID, ok := args[2].(string) + timeoutTimestamp, ok := args[7].(uint64) if !ok { - rerr = errors.New("channelID is not a string") + rerr = errors.New("timeoutTimestamp is not a uint64") return } - if channelID == "" { - rerr = errors.New("channelID cannot be empty") + + err = p.transferKeeper.SendTransfer( + ctx, + validatedArgs.port, + validatedArgs.channelID, + coin, + validatedArgs.senderSeiAddr, + validatedArgs.receiverAddressString, + height, + timeoutTimestamp) + + if err != nil { + rerr = err return } + remainingGas = pcommon.GetRemainingGas(ctx, p.evmKeeper) + ret, rerr = method.Outputs.Pack(true) + return +} - denom := args[3].(string) - if denom == "" { - rerr = errors.New("invalid denom") +func (p Precompile) transferWithDefaultTimeout(ctx sdk.Context, method *abi.Method, args []interface{}, caller common.Address) (ret []byte, remainingGas uint64, rerr error) { + defer func() { + if err := recover(); err != nil { + ret = nil + remainingGas = 0 + rerr = fmt.Errorf("%s", err) + return + } + }() + + if err := pcommon.ValidateArgsLength(args, 5); err != nil { + rerr = err return } - - amount, ok := args[4].(*big.Int) - if !ok { - rerr = errors.New("amount is not a big.Int") + validatedArgs, err := p.validateCommonArgs(ctx, args, caller) + if err != nil { + rerr = err return } - if amount.Cmp(big.NewInt(0)) == 0 { + if validatedArgs.amount.Cmp(big.NewInt(0)) == 0 { // short circuit remainingGas = pcommon.GetRemainingGas(ctx, p.evmKeeper) ret, rerr = method.Outputs.Pack(true) @@ -200,34 +246,44 @@ func (p Precompile) transfer(ctx sdk.Context, method *abi.Method, args []interfa } coin := sdk.Coin{ - Denom: denom, - Amount: sdk.NewIntFromBigInt(amount), + Denom: validatedArgs.denom, + Amount: sdk.NewIntFromBigInt(validatedArgs.amount), } - revisionNumber, ok := args[5].(uint64) - if !ok { - rerr = errors.New("revisionNumber is not a uint64") + connection, err := p.getChannelConnection(ctx, validatedArgs.port, validatedArgs.channelID) + + if err != nil { + rerr = err return } - revisionHeight, ok := args[6].(uint64) - if !ok { - rerr = errors.New("revisionHeight is not a uint64") + latestConsensusHeight, err := p.getConsensusLatestHeight(ctx, *connection) + if err != nil { + rerr = err return } - height := clienttypes.Height{ - RevisionNumber: revisionNumber, - RevisionHeight: revisionHeight, + height, err := GetAdjustedHeight(*latestConsensusHeight) + if err != nil { + rerr = err + return } - timeoutTimestamp, ok := args[7].(uint64) - if !ok { - rerr = errors.New("timeoutTimestamp is not a uint64") + timeoutTimestamp, err := p.GetAdjustedTimestamp(ctx, connection.ClientId, *latestConsensusHeight) + if err != nil { + rerr = err return } - err = p.transferKeeper.SendTransfer(ctx, port, channelID, coin, senderSeiAddr, receiverAddressString, height, timeoutTimestamp) + err = p.transferKeeper.SendTransfer( + ctx, + validatedArgs.port, + validatedArgs.channelID, + coin, + validatedArgs.senderSeiAddr, + validatedArgs.receiverAddressString, + height, + timeoutTimestamp) if err != nil { rerr = err @@ -266,3 +322,127 @@ func (p Precompile) accAddressFromArg(ctx sdk.Context, arg interface{}) (sdk.Acc } return seiAddr, nil } + +func (p Precompile) getChannelConnection(ctx sdk.Context, port string, channelID string) (*connectiontypes.ConnectionEnd, error) { + channel, found := p.channelKeeper.GetChannel(ctx, port, channelID) + if !found { + return nil, errors.New("channel not found") + } + + connection, found := p.connectionKeeper.GetConnection(ctx, channel.ConnectionHops[0]) + + if !found { + return nil, errors.New("connection not found") + } + return &connection, nil +} + +func (p Precompile) getConsensusLatestHeight(ctx sdk.Context, connection connectiontypes.ConnectionEnd) (*clienttypes.Height, error) { + clientState, found := p.clientKeeper.GetClientState(ctx, connection.ClientId) + + if !found { + return nil, errors.New("could not get the client state") + } + + latestHeight := clientState.GetLatestHeight() + return &clienttypes.Height{ + RevisionNumber: latestHeight.GetRevisionNumber(), + RevisionHeight: latestHeight.GetRevisionHeight(), + }, nil +} + +func GetAdjustedHeight(latestConsensusHeight clienttypes.Height) (clienttypes.Height, error) { + defaultTimeoutHeight, err := clienttypes.ParseHeight(types.DefaultRelativePacketTimeoutHeight) + if err != nil { + return clienttypes.Height{}, err + } + + absoluteHeight := latestConsensusHeight + absoluteHeight.RevisionNumber += defaultTimeoutHeight.RevisionNumber + absoluteHeight.RevisionHeight += defaultTimeoutHeight.RevisionHeight + return absoluteHeight, nil +} + +func (p Precompile) GetAdjustedTimestamp(ctx sdk.Context, clientId string, height clienttypes.Height) (uint64, error) { + consensusState, found := p.clientKeeper.GetClientConsensusState(ctx, clientId, height) + var consensusStateTimestamp uint64 + if found { + consensusStateTimestamp = consensusState.GetTimestamp() + } + + defaultRelativePacketTimeoutTimestamp := types.DefaultRelativePacketTimeoutTimestamp + blockTime := ctx.BlockTime().UnixNano() + if blockTime > 0 { + now := uint64(blockTime) + if now > consensusStateTimestamp { + return now + defaultRelativePacketTimeoutTimestamp, nil + } else { + return consensusStateTimestamp + defaultRelativePacketTimeoutTimestamp, nil + } + } else { + return 0, errors.New("block time is not greater than Jan 1st, 1970 12:00 AM") + } +} + +type ValidatedArgs struct { + senderSeiAddr sdk.AccAddress + receiverAddressString string + port string + channelID string + denom string + amount *big.Int +} + +func (p Precompile) validateCommonArgs(ctx sdk.Context, args []interface{}, caller common.Address) (*ValidatedArgs, error) { + senderSeiAddr, ok := p.evmKeeper.GetSeiAddress(ctx, caller) + if !ok { + return nil, errors.New("caller is not a valid SEI address") + } + + receiverAddressString, ok := args[0].(string) + if !ok { + return nil, errors.New("receiverAddress is not a string") + } + _, bz, err := bech32.DecodeAndConvert(receiverAddressString) + if err != nil { + return nil, err + } + err = sdk.VerifyAddressFormat(bz) + if err != nil { + return nil, err + } + + port, ok := args[1].(string) + if !ok { + return nil, errors.New("port is not a string") + } + if port == "" { + return nil, errors.New("port cannot be empty") + } + + channelID, ok := args[2].(string) + if !ok { + return nil, errors.New("channelID is not a string") + } + if channelID == "" { + return nil, errors.New("channelID cannot be empty") + } + + denom := args[3].(string) + if denom == "" { + return nil, errors.New("invalid denom") + } + + amount, ok := args[4].(*big.Int) + if !ok { + return nil, errors.New("amount is not a big.Int") + } + return &ValidatedArgs{ + senderSeiAddr: senderSeiAddr, + receiverAddressString: receiverAddressString, + port: port, + channelID: channelID, + denom: denom, + amount: amount, + }, nil +} diff --git a/precompiles/ibc/ibc_test.go b/precompiles/ibc/ibc_test.go index fc07b9cb8a..8a6da8c149 100644 --- a/precompiles/ibc/ibc_test.go +++ b/precompiles/ibc/ibc_test.go @@ -2,9 +2,11 @@ package ibc_test import ( "errors" + "github.com/cosmos/ibc-go/v3/modules/core/exported" "math/big" "reflect" "testing" + "time" sdk "github.com/cosmos/cosmos-sdk/types" clienttypes "github.com/cosmos/ibc-go/v3/modules/core/02-client/types" @@ -36,7 +38,7 @@ func TestPrecompile_Run(t *testing.T) { senderSeiAddress, senderEvmAddress := testkeeper.MockAddressPair() receiverAddress := "cosmos1yykwxjzr2tv4mhx5tsf8090sdg96f2ax8fydk2" - pre, _ := ibc.NewPrecompile(nil, nil) + pre, _ := ibc.NewPrecompile(nil, nil, nil, nil, nil) testTransfer, _ := pre.ABI.MethodById(pre.TransferID) packedTrue, _ := testTransfer.Outputs.Pack(true) @@ -216,7 +218,7 @@ func TestPrecompile_Run(t *testing.T) { StateDB: stateDb, TxContext: vm.TxContext{Origin: senderEvmAddress}, } - p, _ := ibc.NewPrecompile(tt.fields.transferKeeper, k) + p, _ := ibc.NewPrecompile(tt.fields.transferKeeper, k, nil, nil, nil) transfer, err := p.ABI.MethodById(p.TransferID) require.Nil(t, err) inputs, err := transfer.Inputs.Pack(tt.args.input.receiverAddr, @@ -241,3 +243,364 @@ func TestPrecompile_Run(t *testing.T) { }) } } + +func TestTransferWithDefaultTimeoutPrecompile_Run(t *testing.T) { + senderSeiAddress, senderEvmAddress := testkeeper.MockAddressPair() + receiverAddress := "cosmos1yykwxjzr2tv4mhx5tsf8090sdg96f2ax8fydk2" + + type fields struct { + transferKeeper pcommon.TransferKeeper + clientKeeper pcommon.ClientKeeper + connectionKeeper pcommon.ConnectionKeeper + channelKeeper pcommon.ChannelKeeper + } + + type input struct { + receiverAddr string + sourcePort string + sourceChannel string + denom string + amount *big.Int + } + type args struct { + caller common.Address + callingContract common.Address + input *input + suppliedGas uint64 + value *big.Int + } + + commonArgs := args{ + caller: senderEvmAddress, + callingContract: senderEvmAddress, + input: &input{ + receiverAddr: receiverAddress, + sourcePort: "transfer", + sourceChannel: "channel-0", + denom: "denom", + amount: big.NewInt(100), + }, + suppliedGas: uint64(1000000), + value: nil, + } + + tests := []struct { + name string + fields fields + args args + wantBz []byte + wantRemainingGas uint64 + wantErr bool + wantErrMsg string + }{ + { + name: "failed transfer: caller not whitelisted", + fields: fields{transferKeeper: &MockTransferKeeper{}}, + args: args{caller: senderEvmAddress, callingContract: common.Address{}, input: commonArgs.input, suppliedGas: 1000000, value: nil}, + wantBz: nil, + wantErr: true, + wantErrMsg: "cannot delegatecall IBC", + }, + { + name: "failed transfer: empty sourcePort", + fields: fields{transferKeeper: &MockTransferKeeper{}}, + args: args{ + caller: senderEvmAddress, + callingContract: senderEvmAddress, + input: &input{ + receiverAddr: receiverAddress, + sourcePort: "", // empty sourcePort + sourceChannel: "channel-0", + denom: "denom", + amount: big.NewInt(100), + }, + suppliedGas: uint64(1000000), + value: nil, + }, + wantBz: nil, + wantErr: true, + wantErrMsg: "port cannot be empty", + }, + { + name: "failed transfer: empty sourceChannel", + fields: fields{transferKeeper: &MockTransferKeeper{}}, + args: args{ + caller: senderEvmAddress, + callingContract: senderEvmAddress, + input: &input{ + receiverAddr: receiverAddress, + sourcePort: "transfer", + sourceChannel: "", + denom: "denom", + amount: big.NewInt(100), + }, + suppliedGas: uint64(1000000), + value: nil, + }, + wantBz: nil, + wantErr: true, + wantErrMsg: "channelID cannot be empty", + }, + { + name: "failed transfer: invalid denom", + fields: fields{transferKeeper: &MockTransferKeeper{}}, + args: args{ + caller: senderEvmAddress, + callingContract: senderEvmAddress, + input: &input{ + receiverAddr: receiverAddress, + sourcePort: "transfer", + sourceChannel: "channel-0", + denom: "", + amount: big.NewInt(100), + }, + suppliedGas: uint64(1000000), + value: nil, + }, + wantBz: nil, + wantErr: true, + wantErrMsg: "invalid denom", + }, + { + name: "failed transfer: invalid receiver address", + fields: fields{transferKeeper: &MockTransferKeeper{}}, + args: args{ + caller: senderEvmAddress, + callingContract: senderEvmAddress, + input: &input{ + receiverAddr: "invalid", + sourcePort: "transfer", + sourceChannel: "channel-0", + denom: "", + amount: big.NewInt(100), + }, + suppliedGas: uint64(1000000), + value: nil, + }, + wantBz: nil, + wantErr: true, + wantErrMsg: "decoding bech32 failed: invalid bech32 string length 7", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testApp := testkeeper.EVMTestApp + ctx := testApp.NewContext(false, tmtypes.Header{}).WithBlockHeight(2) + k := &testApp.EvmKeeper + k.SetAddressMapping(ctx, senderSeiAddress, senderEvmAddress) + stateDb := state.NewDBImpl(ctx, k, true) + evm := vm.EVM{ + StateDB: stateDb, + TxContext: vm.TxContext{Origin: senderEvmAddress}, + } + + p, _ := ibc.NewPrecompile(tt.fields.transferKeeper, + k, tt.fields.clientKeeper, + tt.fields.connectionKeeper, + tt.fields.channelKeeper) + transfer, err := p.ABI.MethodById(p.TransferWithDefaultTimeoutID) + require.Nil(t, err) + inputs, err := transfer.Inputs.Pack(tt.args.input.receiverAddr, + tt.args.input.sourcePort, tt.args.input.sourceChannel, tt.args.input.denom, tt.args.input.amount) + require.Nil(t, err) + gotBz, gotRemainingGas, err := p.RunAndCalculateGas(&evm, + tt.args.caller, + tt.args.callingContract, + append(p.TransferWithDefaultTimeoutID, inputs...), + tt.args.suppliedGas, + tt.args.value, + nil, + false) + if (err != nil) != tt.wantErr { + t.Errorf("Run() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err != nil { + require.Equal(t, tt.wantErrMsg, err.Error()) + } + + if !reflect.DeepEqual(gotBz, tt.wantBz) { + t.Errorf("Run() gotBz = %v, want %v", gotBz, tt.wantBz) + } + if !reflect.DeepEqual(gotRemainingGas, tt.wantRemainingGas) { + t.Errorf("Run() gotRemainingGas = %v, want %v", gotRemainingGas, tt.wantRemainingGas) + } + }) + } +} + +func TestPrecompile_GetAdjustedHeight(t *testing.T) { + type args struct { + latestConsensusHeight clienttypes.Height + } + tests := []struct { + name string + args args + want clienttypes.Height + wantErr bool + }{ + { + name: "height is adjusted with defaults", + args: args{ + latestConsensusHeight: clienttypes.NewHeight(2, 3), + }, + want: clienttypes.Height{ + RevisionNumber: 2, + RevisionHeight: 1003, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ibc.GetAdjustedHeight(tt.args.latestConsensusHeight) + if (err != nil) != tt.wantErr { + t.Errorf("GetAdjustedHeight() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetAdjustedHeight() got = %v, want %v", got, tt.want) + } + }) + } +} + +type MockClientKeeper struct { + consensusState *MockConsensusState + returnConsensusState bool +} + +func (ck *MockClientKeeper) GetClientState(ctx sdk.Context, clientID string) (exported.ClientState, bool) { + return nil, false +} + +func (ck *MockClientKeeper) GetClientConsensusState(ctx sdk.Context, clientID string, height exported.Height) (exported.ConsensusState, bool) { + return ck.consensusState, ck.returnConsensusState +} + +type MockConsensusState struct { + timestamp uint64 +} + +func (m *MockConsensusState) Reset() { + panic("implement me") +} + +func (m *MockConsensusState) String() string { + panic("implement me") +} + +func (m *MockConsensusState) ProtoMessage() { + panic("implement me") +} + +func (m *MockConsensusState) ClientType() string { + return "mock" +} + +func (m *MockConsensusState) GetRoot() exported.Root { + return nil +} + +func (m *MockConsensusState) GetTimestamp() uint64 { + return m.timestamp +} + +func (m *MockConsensusState) ValidateBasic() error { + return nil +} + +func TestPrecompile_GetAdjustedTimestamp(t *testing.T) { + type fields struct { + transferKeeper pcommon.TransferKeeper + evmKeeper pcommon.EVMKeeper + clientKeeper pcommon.ClientKeeper + connectionKeeper pcommon.ConnectionKeeper + channelKeeper pcommon.ChannelKeeper + } + type args struct { + ctx sdk.Context + clientId string + height clienttypes.Height + } + timestampSeconds := 1714680155 + ctx := sdk.Context{} + tests := []struct { + name string + fields fields + args args + want uint64 + wantErr bool + }{ + { + name: "if consensus timestamp is less than the given time, return the given time adjusted with default", + fields: fields{ + clientKeeper: &MockClientKeeper{ + consensusState: &MockConsensusState{ + timestamp: uint64(timestampSeconds - 1), + }, + returnConsensusState: true, + }, + }, + args: args{ + ctx: ctx.WithBlockTime(time.Unix(int64(timestampSeconds), 0)), + }, + want: uint64(timestampSeconds)*1_000_000_000 + uint64((time.Duration(10) * time.Minute).Nanoseconds()), + wantErr: false, + }, + { + name: "if consensus state is not found, return the given time adjusted with default", + fields: fields{ + clientKeeper: &MockClientKeeper{ + returnConsensusState: false, + }, + }, + args: args{ + ctx: ctx.WithBlockTime(time.Unix(int64(timestampSeconds), 0)), + }, + want: uint64(timestampSeconds)*1_000_000_000 + uint64((time.Duration(10) * time.Minute).Nanoseconds()), + wantErr: false, + }, + { + name: "if time from local clock can not be retrieved, return error", + fields: fields{ + clientKeeper: &MockClientKeeper{ + returnConsensusState: false, + }, + }, + args: args{ + ctx: ctx.WithBlockTime(time.Unix(int64(0), 0)), + }, + wantErr: true, + }, + { + name: "if consensus timestamp is > than the given time, return the consensus time adjusted with default", + fields: fields{ + clientKeeper: &MockClientKeeper{ + consensusState: &MockConsensusState{ + timestamp: uint64(timestampSeconds+1) * 1_000_000_000, + }, + returnConsensusState: true, + }, + }, + args: args{ + ctx: ctx.WithBlockTime(time.Unix(int64(timestampSeconds), 0)), + }, + want: uint64(timestampSeconds+1)*1_000_000_000 + uint64((time.Duration(10) * time.Minute).Nanoseconds()), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p, _ := ibc.NewPrecompile(tt.fields.transferKeeper, tt.fields.evmKeeper, tt.fields.clientKeeper, tt.fields.connectionKeeper, tt.fields.channelKeeper) + got, err := p.GetAdjustedTimestamp(tt.args.ctx, tt.args.clientId, tt.args.height) + if (err != nil) != tt.wantErr { + t.Errorf("GetAdjustedTimestamp() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("GetAdjustedTimestamp() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/precompiles/setup.go b/precompiles/setup.go index 3cadd9792d..052c05e9d2 100644 --- a/precompiles/setup.go +++ b/precompiles/setup.go @@ -49,6 +49,9 @@ func InitializePrecompiles( distrKeeper common.DistributionKeeper, oracleKeeper common.OracleKeeper, transferKeeper common.TransferKeeper, + clientKeeper common.ClientKeeper, + connectionKeeper common.ConnectionKeeper, + channelKeeper common.ChannelKeeper, ) error { SetupMtx.Lock() defer SetupMtx.Unlock() @@ -87,7 +90,7 @@ func InitializePrecompiles( if err != nil { return err } - ibcp, err := ibc.NewPrecompile(transferKeeper, evmKeeper) + ibcp, err := ibc.NewPrecompile(transferKeeper, evmKeeper, clientKeeper, connectionKeeper, channelKeeper) if err != nil { return err } @@ -130,7 +133,7 @@ func InitializePrecompiles( func GetPrecompileInfo(name string) PrecompileInfo { if !Initialized { // Precompile Info does not require any keeper state - _ = InitializePrecompiles(true, nil, nil, nil, nil, nil, nil, nil, nil, nil) + _ = InitializePrecompiles(true, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) } i, ok := PrecompileNamesToInfo[name] if !ok {