diff --git a/app/app.go b/app/app.go index ba484238..548904b3 100644 --- a/app/app.go +++ b/app/app.go @@ -678,6 +678,13 @@ func NewApp( keys[epochsmoduletypes.MemStoreKey], app.GetSubspace(epochsmoduletypes.ModuleName), ) + + app.EpochsKeeper.SetHooks( + epochsmoduletypes.NewMultiEpochHooks( + app.RegistryKeeper.Hooks(), + // insert hooks here + )) + epochsModule := epochsmodule.NewAppModule(appCodec, app.EpochsKeeper, app.AccountKeeper, app.BankKeeper) app.RegistryKeeper = *registrymodulekeeper.NewKeeper( @@ -726,12 +733,6 @@ func NewApp( ), ) - app.EpochsKeeper.SetHooks( - epochsmoduletypes.NewMultiEpochHooks( - app.RegistryKeeper.Hooks(), - // insert hooks here - )) - /**** Module Options ****/ // NOTE: we may consider parsing `appOpts` inside module constructors. For the moment diff --git a/x/epochs/keeper/hooks.go b/x/epochs/keeper/hooks.go index 8fcc7a2e..5200cbf6 100644 --- a/x/epochs/keeper/hooks.go +++ b/x/epochs/keeper/hooks.go @@ -4,12 +4,18 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" ) -// AfterEpochEnd executes the indicated hook after epochs ends +// AfterEpochEnd epoch hook func (k Keeper) AfterEpochEnd(ctx sdk.Context, identifier string, epochNumber int64) { + if k.hooks == nil { + panic("hooks not set in keeper") + } k.hooks.AfterEpochEnd(ctx, identifier, epochNumber) } -// BeforeEpochStart executes the indicated hook before the epochs +// BeforeEpochStart epoch hook func (k Keeper) BeforeEpochStart(ctx sdk.Context, identifier string, epochNumber int64) { + if k.hooks == nil { + panic("hooks not set in keeper") + } k.hooks.BeforeEpochStart(ctx, identifier, epochNumber) } diff --git a/x/epochs/keeper/hooks_test.go b/x/epochs/keeper/hooks_test.go index 0050a05c..156aad26 100644 --- a/x/epochs/keeper/hooks_test.go +++ b/x/epochs/keeper/hooks_test.go @@ -58,6 +58,11 @@ func (suite *KeeperTestSuite) TestAfterEpochHooks() { }{ { expBeforeEpochStartEvents: []ExpEvent{ + { + EpochNumber: "2", + }, + }, + expAfterEpochEndEvents: []ExpEvent{ { EpochNumber: "1", }, @@ -70,15 +75,23 @@ func (suite *KeeperTestSuite) TestAfterEpochHooks() { // Check if curent epoch is expected epochInfo, found := suite.app.EpochsKeeper.GetEpochInfo(suite.ctx, types.DayEpochId) suite.Require().True(found) - suite.Require().Equal(int64(1), epochInfo.CurrentEpoch) + suite.Require().Equal(int64(2), epochInfo.CurrentEpoch) }, }, { expBeforeEpochStartEvents: []ExpEvent{ { - EpochNumber: "1", + EpochNumber: "2", }, { + EpochNumber: "3", + }, + }, + expAfterEpochEndEvents: []ExpEvent{ + { + EpochNumber: "1", + }, +{ EpochNumber: "2", }, }, @@ -90,11 +103,16 @@ func (suite *KeeperTestSuite) TestAfterEpochHooks() { // Check if curent epoch is expected epochInfo, found := suite.app.EpochsKeeper.GetEpochInfo(suite.ctx, types.DayEpochId) suite.Require().True(found) - suite.Require().Equal(int64(1), epochInfo.CurrentEpoch) + suite.Require().Equal(int64(2), epochInfo.CurrentEpoch) // Begin second block suite.ctx = suite.ctx.WithBlockHeight(3).WithBlockTime(now.Add(oneDayDuration)) suite.app.EpochsKeeper.BeginBlocker(suite.ctx) + + // Check if curent epoch is expected + epochInfo, found = suite.app.EpochsKeeper.GetEpochInfo(suite.ctx, types.DayEpochId) + suite.Require().True(found) + suite.Require().Equal(int64(3), epochInfo.CurrentEpoch) }, }, } @@ -131,7 +149,7 @@ func (suite *KeeperTestSuite) TestAfterEpochHooks() { if len(tc.expAfterEpochEndEvents) != 0 { afterEpochEndEvents, found := testutil.FindEventsByType(suite.ctx.EventManager().Events(), AfterEpochEndEventType) suite.Require().True(found) - for i, expEvent := range tc.expBeforeEpochStartEvents { + for i, expEvent := range tc.expAfterEpochEndEvents { event := afterEpochEndEvents[i] suite.Require().Equal(AfterEpochEndEventType, event.Type) suite.Require().Equal(EpochIdentifier, event.Attributes[0].Value)