Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): Wire packet handler to core message server #7091

Merged
merged 9 commits into from
Aug 8, 2024
6 changes: 6 additions & 0 deletions modules/core/04-channel/keeper/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package keeper
import (
sdk "github.com/cosmos/cosmos-sdk/types"

capabilitytypes "github.com/cosmos/ibc-go/modules/capability/types"
"github.com/cosmos/ibc-go/v9/modules/core/04-channel/types"
)

Expand Down Expand Up @@ -34,3 +35,8 @@ func (k *Keeper) SetUpgradeErrorReceipt(ctx sdk.Context, portID, channelID strin
func (k *Keeper) SetRecvStartSequence(ctx sdk.Context, portID, channelID string, sequence uint64) {
k.setRecvStartSequence(ctx, portID, channelID, sequence)
}

// TimeoutExecuted is a wrapper around timeoutExecuted to allow the function to be directly called in tests.
func (k *Keeper) TimeoutExecuted(ctx sdk.Context, capability *capabilitytypes.Capability, packet types.Packet) error {
return k.timeoutExecuted(ctx, capability, packet)
}
15 changes: 11 additions & 4 deletions modules/core/04-channel/keeper/timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
// ante handler.
func (k *Keeper) TimeoutPacket(
ctx sdk.Context,
chanCap *capabilitytypes.Capability,
packet types.Packet,
proof []byte,
proofHeight exported.Height,
Expand Down Expand Up @@ -119,18 +120,21 @@ func (k *Keeper) TimeoutPacket(
return "", err
}

// NOTE: the remaining code is located in the TimeoutExecuted function
if err = k.timeoutExecuted(ctx, chanCap, packet); err != nil {
return "", err
}

return channel.Version, nil
}

// TimeoutExecuted deletes the commitment send from this chain after it verifies timeout.
// timeoutExecuted deletes the commitment send from this chain after it verifies timeout.
// If the timed-out packet came from an ORDERED channel then this channel will be closed.
// If the channel is in the FLUSHING state and there is a counterparty upgrade, then the
// upgrade will be aborted if the upgrade has timed out. Otherwise, if there are no more inflight packets,
// then the channel will be set to the FLUSHCOMPLETE state.
//
// CONTRACT: this function must be called in the IBC handler
func (k *Keeper) TimeoutExecuted(
func (k *Keeper) timeoutExecuted(
ctx sdk.Context,
chanCap *capabilitytypes.Capability,
packet types.Packet,
Expand Down Expand Up @@ -298,6 +302,9 @@ func (k *Keeper) TimeoutOnClose(
return "", err
}

// NOTE: the remaining code is located in the TimeoutExecuted function
if err = k.timeoutExecuted(ctx, chanCap, packet); err != nil {
return "", err
}

return channel.Version, nil
}
7 changes: 6 additions & 1 deletion modules/core/04-channel/keeper/timeout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ func (suite *KeeperTestSuite) TestTimeoutPacket() {
var (
path *ibctesting.Path
packet types.Packet
chanCap *capabilitytypes.Capability
nextSeqRecv uint64
ordered bool
expError *errorsmod.Error
Expand All @@ -47,6 +48,8 @@ func (suite *KeeperTestSuite) TestTimeoutPacket() {
// need to update chainA's client representing chainB to prove missing ack
err = path.EndpointA.UpdateClient()
suite.Require().NoError(err)

chanCap = suite.chainA.GetChannelCapability(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
}, true},
{"success: UNORDERED", func() {
ordered = false
Expand All @@ -60,6 +63,7 @@ func (suite *KeeperTestSuite) TestTimeoutPacket() {
// need to update chainA's client representing chainB to prove missing ack
err = path.EndpointA.UpdateClient()
suite.Require().NoError(err)
chanCap = suite.chainA.GetChannelCapability(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
}, true},
{"packet already timed out: ORDERED", func() {
expError = types.ErrNoOpMsg
Expand Down Expand Up @@ -144,6 +148,7 @@ func (suite *KeeperTestSuite) TestTimeoutPacket() {
timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().UnixNano())

sequence, err := path.EndpointA.SendPacket(defaultTimeoutHeight, timeoutTimestamp, ibctesting.MockPacketData)

suite.Require().NoError(err)
packet = types.NewPacket(ibctesting.MockPacketData, sequence, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, defaultTimeoutHeight, timeoutTimestamp)
err = path.EndpointA.UpdateClient()
Expand Down Expand Up @@ -220,7 +225,7 @@ func (suite *KeeperTestSuite) TestTimeoutPacket() {
}
}

channelVersion, err := suite.chainA.App.GetIBCKeeper().ChannelKeeper.TimeoutPacket(suite.chainA.GetContext(), packet, proof, proofHeight, nextSeqRecv)
channelVersion, err := suite.chainA.App.GetIBCKeeper().ChannelKeeper.TimeoutPacket(suite.chainA.GetContext(), chanCap, packet, proof, proofHeight, nextSeqRecv)

if tc.expPass {
suite.Require().NoError(err)
Expand Down
43 changes: 43 additions & 0 deletions modules/core/keeper/expected_keeper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package keeper

import (
sdk "github.com/cosmos/cosmos-sdk/types"

capabilitytypes "github.com/cosmos/ibc-go/modules/capability/types"
channeltypes "github.com/cosmos/ibc-go/v9/modules/core/04-channel/types"
"github.com/cosmos/ibc-go/v9/modules/core/exported"
)

type PacketHandler interface {
RecvPacket(
ctx sdk.Context,
chanCap *capabilitytypes.Capability,
packet channeltypes.Packet,
proof []byte,
proofHeight exported.Height) (string, error)

WriteAcknowledgement(
ctx sdk.Context,
chanCap *capabilitytypes.Capability,
packet exported.PacketI,
acknowledgement exported.Acknowledgement,
) error

AcknowledgePacket(
ctx sdk.Context,
chanCap *capabilitytypes.Capability,
packet channeltypes.Packet,
acknowledgement []byte,
proof []byte,
proofHeight exported.Height,
) (string, error)

TimeoutPacket(
ctx sdk.Context,
chanCap *capabilitytypes.Capability,
Copy link
Member

@AdityaSripal AdityaSripal Aug 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm if there wasn't originally a chanCapability in this signature why do we need to add it only to then remove it later?

Copy link
Contributor Author

@DimitrisJim DimitrisJim Aug 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because I inlined timeoutExecuted and need to have the signatures matching (TimeoutPacket/OnClose didn't require capability previously), unfortunate but pile it on the list of clean-ups after port-router!

packet channeltypes.Packet,
proof []byte,
proofHeight exported.Height,
nextSequenceRecv uint64,
) (string, error)
}
24 changes: 14 additions & 10 deletions modules/core/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@ import (
channelkeeper "github.com/cosmos/ibc-go/v9/modules/core/04-channel/keeper"
portkeeper "github.com/cosmos/ibc-go/v9/modules/core/05-port/keeper"
porttypes "github.com/cosmos/ibc-go/v9/modules/core/05-port/types"
packetserver "github.com/cosmos/ibc-go/v9/modules/core/packet-server/keeper"
"github.com/cosmos/ibc-go/v9/modules/core/types"
)

// Keeper defines each ICS keeper for IBC
type Keeper struct {
ClientKeeper *clientkeeper.Keeper
ConnectionKeeper *connectionkeeper.Keeper
ChannelKeeper *channelkeeper.Keeper
PortKeeper *portkeeper.Keeper
ClientKeeper *clientkeeper.Keeper
ConnectionKeeper *connectionkeeper.Keeper
ChannelKeeper *channelkeeper.Keeper
PacketServerKeeper *packetserver.Keeper
PortKeeper *portkeeper.Keeper

cdc codec.BinaryCodec

Expand Down Expand Up @@ -54,14 +56,16 @@ func NewKeeper(
connectionKeeper := connectionkeeper.NewKeeper(cdc, key, paramSpace, clientKeeper)
portKeeper := portkeeper.NewKeeper(scopedKeeper)
channelKeeper := channelkeeper.NewKeeper(cdc, key, clientKeeper, connectionKeeper, portKeeper, scopedKeeper)
packetKeeper := packetserver.NewKeeper(cdc, channelKeeper, clientKeeper)

return &Keeper{
cdc: cdc,
ClientKeeper: clientKeeper,
ConnectionKeeper: connectionKeeper,
ChannelKeeper: channelKeeper,
PortKeeper: portKeeper,
authority: authority,
cdc: cdc,
ClientKeeper: clientKeeper,
ConnectionKeeper: connectionKeeper,
ChannelKeeper: channelKeeper,
PacketServerKeeper: packetKeeper,
PortKeeper: portKeeper,
authority: authority,
}
}

Expand Down
98 changes: 69 additions & 29 deletions modules/core/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package keeper
import (
"context"
"errors"
"fmt"

errorsmod "cosmossdk.io/errors"

sdk "github.com/cosmos/cosmos-sdk/types"

capabilitytypes "github.com/cosmos/ibc-go/modules/capability/types"
clienttypes "github.com/cosmos/ibc-go/v9/modules/core/02-client/types"
connectiontypes "github.com/cosmos/ibc-go/v9/modules/core/03-connection/types"
"github.com/cosmos/ibc-go/v9/modules/core/04-channel/keeper"
Expand Down Expand Up @@ -460,6 +462,11 @@ func (k *Keeper) ChannelCloseConfirm(goCtx context.Context, msg *channeltypes.Ms

// RecvPacket defines a rpc handler method for MsgRecvPacket.
func (k *Keeper) RecvPacket(goCtx context.Context, msg *channeltypes.MsgRecvPacket) (*channeltypes.MsgRecvPacketResponse, error) {
var (
packetHandler PacketHandler
module string
capability *capabilitytypes.Capability
)
ctx := sdk.UnwrapSDKContext(goCtx)

relayer, err := sdk.AccAddressFromBech32(msg.Signer)
Expand All @@ -468,11 +475,22 @@ func (k *Keeper) RecvPacket(goCtx context.Context, msg *channeltypes.MsgRecvPack
return nil, errorsmod.Wrap(err, "Invalid address for msg Signer")
}

// Lookup module by channel capability
module, capability, err := k.ChannelKeeper.LookupModuleByChannel(ctx, msg.Packet.DestinationPort, msg.Packet.DestinationChannel)
if err != nil {
ctx.Logger().Error("receive packet failed", "port-id", msg.Packet.SourcePort, "channel-id", msg.Packet.SourceChannel, "error", errorsmod.Wrap(err, "could not retrieve module from port-id"))
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
switch msg.Packet.ProtocolVersion {
case channeltypes.IBC_VERSION_UNSPECIFIED, channeltypes.IBC_VERSION_1:
packetHandler = k.ChannelKeeper

// Lookup module by channel capability
module, capability, err = k.ChannelKeeper.LookupModuleByChannel(ctx, msg.Packet.DestinationPort, msg.Packet.DestinationChannel)
if err != nil {
ctx.Logger().Error("acknowledgement failed", "port-id", msg.Packet.SourcePort, "channel-id", msg.Packet.SourceChannel, "error", errorsmod.Wrap(err, "could not retrieve module from port-id"))
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
}

case channeltypes.IBC_VERSION_2:
packetHandler = k.PacketServerKeeper
module = msg.Packet.DestinationPort
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was not sure if we want to do fallback yet. can tweak and add tho

default:
panic(fmt.Errorf("unsupported protocol version %d", msg.Packet.ProtocolVersion))
}

// Retrieve callbacks from router
Expand All @@ -487,7 +505,7 @@ func (k *Keeper) RecvPacket(goCtx context.Context, msg *channeltypes.MsgRecvPack
// If the packet was already received, perform a no-op
// Use a cached context to prevent accidental state changes
cacheCtx, writeFn := ctx.CacheContext()
channelVersion, err := k.ChannelKeeper.RecvPacket(cacheCtx, capability, msg.Packet, msg.ProofCommitment, msg.ProofHeight)
channelVersion, err := packetHandler.RecvPacket(cacheCtx, capability, msg.Packet, msg.ProofCommitment, msg.ProofHeight)

switch err {
case nil:
Expand Down Expand Up @@ -518,7 +536,7 @@ func (k *Keeper) RecvPacket(goCtx context.Context, msg *channeltypes.MsgRecvPack
// NOTE: IBC applications modules may call the WriteAcknowledgement asynchronously if the
// acknowledgement is nil.
if ack != nil {
if err := k.ChannelKeeper.WriteAcknowledgement(ctx, capability, msg.Packet, ack); err != nil {
if err := packetHandler.WriteAcknowledgement(ctx, capability, msg.Packet, ack); err != nil {
return nil, err
}
}
Expand All @@ -532,6 +550,11 @@ func (k *Keeper) RecvPacket(goCtx context.Context, msg *channeltypes.MsgRecvPack

// Timeout defines a rpc handler method for MsgTimeout.
func (k *Keeper) Timeout(goCtx context.Context, msg *channeltypes.MsgTimeout) (*channeltypes.MsgTimeoutResponse, error) {
var (
packetHandler PacketHandler
module string
capability *capabilitytypes.Capability
)
ctx := sdk.UnwrapSDKContext(goCtx)

relayer, err := sdk.AccAddressFromBech32(msg.Signer)
Expand All @@ -540,11 +563,22 @@ func (k *Keeper) Timeout(goCtx context.Context, msg *channeltypes.MsgTimeout) (*
return nil, errorsmod.Wrap(err, "Invalid address for msg Signer")
}

// Lookup module by channel capability
module, capability, err := k.ChannelKeeper.LookupModuleByChannel(ctx, msg.Packet.SourcePort, msg.Packet.SourceChannel)
if err != nil {
ctx.Logger().Error("timeout failed", "port-id", msg.Packet.SourcePort, "channel-id", msg.Packet.SourceChannel, "error", errorsmod.Wrap(err, "could not retrieve module from port-id"))
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
switch msg.Packet.ProtocolVersion {
case channeltypes.IBC_VERSION_UNSPECIFIED, channeltypes.IBC_VERSION_1:
packetHandler = k.ChannelKeeper

// Lookup module by channel capability
module, capability, err = k.ChannelKeeper.LookupModuleByChannel(ctx, msg.Packet.SourcePort, msg.Packet.SourceChannel)
if err != nil {
ctx.Logger().Error("acknowledgement failed", "port-id", msg.Packet.SourcePort, "channel-id", msg.Packet.SourceChannel, "error", errorsmod.Wrap(err, "could not retrieve module from port-id"))
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
}

case channeltypes.IBC_VERSION_2:
packetHandler = k.PacketServerKeeper
module = msg.Packet.SourcePort
default:
panic(fmt.Errorf("unsupported protocol version %d", msg.Packet.ProtocolVersion))
}

// Retrieve callbacks from router
Expand All @@ -559,7 +593,7 @@ func (k *Keeper) Timeout(goCtx context.Context, msg *channeltypes.MsgTimeout) (*
// If the timeout was already received, perform a no-op
// Use a cached context to prevent accidental state changes
cacheCtx, writeFn := ctx.CacheContext()
channelVersion, err := k.ChannelKeeper.TimeoutPacket(cacheCtx, msg.Packet, msg.ProofUnreceived, msg.ProofHeight, msg.NextSequenceRecv)
channelVersion, err := packetHandler.TimeoutPacket(cacheCtx, capability, msg.Packet, msg.ProofUnreceived, msg.ProofHeight, msg.NextSequenceRecv)

switch err {
case nil:
Expand All @@ -573,11 +607,6 @@ func (k *Keeper) Timeout(goCtx context.Context, msg *channeltypes.MsgTimeout) (*
return nil, errorsmod.Wrap(err, "timeout packet verification failed")
}

// Delete packet commitment
if err = k.ChannelKeeper.TimeoutExecuted(ctx, capability, msg.Packet); err != nil {
return nil, err
}

// Perform application logic callback
err = cbs.OnTimeoutPacket(ctx, channelVersion, msg.Packet, relayer)
if err != nil {
Expand Down Expand Up @@ -635,11 +664,6 @@ func (k *Keeper) TimeoutOnClose(goCtx context.Context, msg *channeltypes.MsgTime
return nil, errorsmod.Wrap(err, "timeout on close packet verification failed")
}

// Delete packet commitment
if err = k.ChannelKeeper.TimeoutExecuted(ctx, capability, msg.Packet); err != nil {
return nil, err
}

// Perform application logic callback
//
// NOTE: MsgTimeout and MsgTimeoutOnClose use the same "OnTimeoutPacket"
Expand All @@ -659,6 +683,11 @@ func (k *Keeper) TimeoutOnClose(goCtx context.Context, msg *channeltypes.MsgTime

// Acknowledgement defines a rpc handler method for MsgAcknowledgement.
func (k *Keeper) Acknowledgement(goCtx context.Context, msg *channeltypes.MsgAcknowledgement) (*channeltypes.MsgAcknowledgementResponse, error) {
var (
packetHandler PacketHandler
module string
capability *capabilitytypes.Capability
)
ctx := sdk.UnwrapSDKContext(goCtx)

relayer, err := sdk.AccAddressFromBech32(msg.Signer)
Expand All @@ -667,11 +696,22 @@ func (k *Keeper) Acknowledgement(goCtx context.Context, msg *channeltypes.MsgAck
return nil, errorsmod.Wrap(err, "Invalid address for msg Signer")
}

// Lookup module by channel capability
module, capability, err := k.ChannelKeeper.LookupModuleByChannel(ctx, msg.Packet.SourcePort, msg.Packet.SourceChannel)
if err != nil {
ctx.Logger().Error("acknowledgement failed", "port-id", msg.Packet.SourcePort, "channel-id", msg.Packet.SourceChannel, "error", errorsmod.Wrap(err, "could not retrieve module from port-id"))
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
switch msg.Packet.ProtocolVersion {
case channeltypes.IBC_VERSION_UNSPECIFIED, channeltypes.IBC_VERSION_1:
packetHandler = k.ChannelKeeper

// Lookup module by channel capability
module, capability, err = k.ChannelKeeper.LookupModuleByChannel(ctx, msg.Packet.SourcePort, msg.Packet.SourceChannel)
if err != nil {
ctx.Logger().Error("acknowledgement failed", "port-id", msg.Packet.SourcePort, "channel-id", msg.Packet.SourceChannel, "error", errorsmod.Wrap(err, "could not retrieve module from port-id"))
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
}

case channeltypes.IBC_VERSION_2:
packetHandler = k.PacketServerKeeper
module = msg.Packet.SourcePort
default:
panic(fmt.Errorf("unsupported protocol version %d", msg.Packet.ProtocolVersion))
}

// Retrieve callbacks from router
Expand All @@ -686,7 +726,7 @@ func (k *Keeper) Acknowledgement(goCtx context.Context, msg *channeltypes.MsgAck
// If the acknowledgement was already received, perform a no-op
// Use a cached context to prevent accidental state changes
cacheCtx, writeFn := ctx.CacheContext()
channelVersion, err := k.ChannelKeeper.AcknowledgePacket(cacheCtx, capability, msg.Packet, msg.Acknowledgement, msg.ProofAcked, msg.ProofHeight)
channelVersion, err := packetHandler.AcknowledgePacket(cacheCtx, capability, msg.Packet, msg.Acknowledgement, msg.ProofAcked, msg.ProofHeight)

switch err {
case nil:
Expand Down
Loading
Loading