Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: refund remaining rewards to the service's address when a plan ends #101

Merged
merged 1 commit into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions x/rewards/keeper/abci.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,15 @@ import (
"context"
)

// BeginBlocker allocates restaking rewards for the previous block.
// BeginBlocker is called every block and is used to terminate ended rewards
// plans and allocate restaking rewards for the previous block.
func (k *Keeper) BeginBlocker(ctx context.Context) error {
err := k.AllocateRewards(ctx)
err := k.TerminateEndedRewardsPlans(ctx)
if err != nil {
return err
}

err = k.AllocateRewards(ctx)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion x/rewards/keeper/allocation.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func (k *Keeper) AllocateRewardsByPlan(
rewards = sdk.NewDecCoinsFromCoins(rewardsTruncated...)

// Check if the rewards pool has enough coins to allocate rewards.
planRewardsPoolAddr := plan.MustGetRewardsPoolAddress()
planRewardsPoolAddr := plan.MustGetRewardsPoolAddress(k.accountKeeper.AddressCodec())
balances := k.bankKeeper.GetAllBalances(ctx, planRewardsPoolAddr)
sdkCtx := sdk.UnwrapSDKContext(ctx)
if !balances.IsAllGTE(rewardsTruncated) {
Expand Down
70 changes: 69 additions & 1 deletion x/rewards/keeper/rewards_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package keeper

import (
"context"
"fmt"
"time"

"cosmossdk.io/errors"
Expand Down Expand Up @@ -75,7 +76,7 @@ func (k *Keeper) CreateRewardsPlan(
// types.UsersDistributionTypeBasic only which doesn't need a validation.

// Create a rewards pool account if it doesn't exist
k.createAccountIfNotExists(ctx, plan.MustGetRewardsPoolAddress())
k.createAccountIfNotExists(ctx, plan.MustGetRewardsPoolAddress(k.accountKeeper.AddressCodec()))

// Store the rewards plan
err = k.RewardsPlans.Set(ctx, planID, plan)
Expand Down Expand Up @@ -112,3 +113,70 @@ func (k *Keeper) validateDistributionDelegationTargets(ctx context.Context, dist
func (k *Keeper) GetRewardsPlan(ctx context.Context, planID uint64) (types.RewardsPlan, error) {
return k.RewardsPlans.Get(ctx, planID)
}

// terminateRewardsPlan removes a rewards plan and transfers the remaining
// rewards in the plan's rewards pool to the service's address.
func (k *Keeper) terminateRewardsPlan(ctx context.Context, plan types.RewardsPlan) error {
sdkCtx := sdk.UnwrapSDKContext(ctx)

// Transfer remaining rewards in the plan's rewards pool to the service's
// address.
rewardsPoolAddr := plan.MustGetRewardsPoolAddress(k.accountKeeper.AddressCodec())
remaining := k.bankKeeper.GetAllBalances(ctx, rewardsPoolAddr)
if remaining.IsAllPositive() {
// Get the service's address.
service, found := k.servicesKeeper.GetService(sdkCtx, plan.ServiceID)
if !found {
return servicestypes.ErrServiceNotFound
}
serviceAddr, err := k.accountKeeper.AddressCodec().StringToBytes(service.Address)
if err != nil {
return err
}

// Transfer all the remaining rewards to the service's address.
err = k.bankKeeper.SendCoins(ctx, rewardsPoolAddr, serviceAddr, remaining)
if err != nil {
return err
}
}

// Remove the plan.
err := k.RewardsPlans.Remove(ctx, plan.ID)
if err != nil {
return err
}

sdkCtx.EventManager().EmitEvents(sdk.Events{
sdk.NewEvent(
types.EventTypeTerminateRewardsPlan,
sdk.NewAttribute(types.AttributeKeyRewardsPlanID, fmt.Sprint(plan.ID)),
sdk.NewAttribute(types.AttributeKeyRemainingRewards, remaining.String()),
),
})

return nil
}

// TerminateEndedRewardsPlans terminates all rewards plans that have ended.
func (k *Keeper) TerminateEndedRewardsPlans(ctx context.Context) error {
sdkCtx := sdk.UnwrapSDKContext(ctx)
// Get the current block time
blockTime := sdkCtx.BlockTime()

// Iterate over all rewards plans
err := k.RewardsPlans.Walk(ctx, nil, func(planID uint64, plan types.RewardsPlan) (stop bool, err error) {
// If the plan has already ended, terminate it
if !blockTime.Before(plan.EndTime) {
err = k.terminateRewardsPlan(ctx, plan)
if err != nil {
return false, err
}
}
return false, nil
})
if err != nil {
return err
}
return nil
}
44 changes: 44 additions & 0 deletions x/rewards/keeper/rewards_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package keeper_test
import (
"time"

"cosmossdk.io/collections"

"github.com/milkyway-labs/milkyway/app/testutil"
"github.com/milkyway-labs/milkyway/utils"
rewardskeeper "github.com/milkyway-labs/milkyway/x/rewards/keeper"
Expand Down Expand Up @@ -66,3 +68,45 @@ func (suite *KeeperTestSuite) TestCreateRewardsPlan_PoolOrOperatorNotFound() {
))
suite.Require().EqualError(err, "cannot get delegation target 2: operator not found: not found")
}

func (suite *KeeperTestSuite) TestTerminateEndedRewardsPlans() {
// Cache the context to avoid errors
ctx, _ := suite.Ctx.CacheContext()

service, _ := suite.setupSampleServiceAndOperator(ctx)

// Create an active rewards plan.
plan := suite.CreateBasicRewardsPlan(
ctx,
service.ID,
utils.MustParseCoins("100_000000service"),
time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
utils.MustParseCoins("10000_000000service"),
)

rewardsPoolAddr := plan.MustGetRewardsPoolAddress(suite.App.AccountKeeper.AddressCodec())
remaining := suite.App.BankKeeper.GetAllBalances(ctx, rewardsPoolAddr)
suite.Require().Equal("10000000000service", remaining.String())

// Change the block time so that the plan becomes no more active.
ctx = ctx.WithBlockTime(time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC))

// Terminate the ended rewards plans
err := suite.keeper.TerminateEndedRewardsPlans(ctx)
suite.Require().NoError(err)

// The plan is removed.
_, err = suite.keeper.GetRewardsPlan(ctx, plan.ID)
suite.Require().ErrorIs(err, collections.ErrNotFound)

// All remaining rewards are transferred to the service's address.
remaining = suite.App.BankKeeper.GetAllBalances(ctx, rewardsPoolAddr)
suite.Require().True(remaining.IsZero())

// Check the service's address balances.
serviceAddr, err := suite.App.AccountKeeper.AddressCodec().StringToBytes(service.Address)
suite.Require().NoError(err)
serviceBalances := suite.App.BankKeeper.GetAllBalances(ctx, serviceAddr)
suite.Require().Equal("10000000000service", serviceBalances.String())
}
14 changes: 8 additions & 6 deletions x/rewards/types/events.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
package types

const (
EventTypeCreateRewardsPlan = "create_rewards_plan"
EventTypeSetWithdrawAddress = "set_withdraw_address"
EventTypeRewards = "rewards"
EventTypeCommission = "commission"
EventTypeWithdrawRewards = "withdraw_rewards"
EventTypeWithdrawCommission = "withdraw_commission"
EventTypeCreateRewardsPlan = "create_rewards_plan"
EventTypeSetWithdrawAddress = "set_withdraw_address"
EventTypeRewards = "rewards"
EventTypeCommission = "commission"
EventTypeWithdrawRewards = "withdraw_rewards"
EventTypeWithdrawCommission = "withdraw_commission"
EventTypeTerminateRewardsPlan = "terminate_rewards_plan"

AttributeKeyRewardsPlanID = "rewards_plan_id"
AttributeKeyWithdrawAddress = "withdraw_address"
AttributeKeyDelegationType = "delegation_type"
AttributeKeyDelegationTargetID = "delegation_target_id"
AttributeKeyRemainingRewards = "remaining_rewards"

// AttributeKeyAmountPerPool represents the amount of rewards per pool (per denom).
// See https://github.com/initia-labs/initia/blob/v0.2.10/x/distribution/types/events.go#L3-L6
Expand Down
1 change: 1 addition & 0 deletions x/rewards/types/expected_keepers.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type AccountKeeper interface {

type BankKeeper interface {
GetAllBalances(ctx context.Context, addr sdk.AccAddress) sdk.Coins
SendCoins(ctx context.Context, fromAddr, toAddr sdk.AccAddress, amt sdk.Coins) error
SendCoinsFromModuleToAccount(ctx context.Context, moduleName string, addr sdk.AccAddress, amt sdk.Coins) error
SendCoinsFromAccountToModule(ctx context.Context, addr sdk.AccAddress, moduleName string, amt sdk.Coins) error
BlockedAddr(addr sdk.AccAddress) bool
Expand Down
5 changes: 3 additions & 2 deletions x/rewards/types/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"time"

coreaddress "cosmossdk.io/core/address"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/types/address"
Expand Down Expand Up @@ -63,8 +64,8 @@ func (plan RewardsPlan) IsActiveAt(t time.Time) bool {
}

// MustGetRewardsPoolAddress returns the rewards pool address.
func (plan RewardsPlan) MustGetRewardsPoolAddress() sdk.AccAddress {
addr, err := sdk.AccAddressFromBech32(plan.RewardsPool)
func (plan RewardsPlan) MustGetRewardsPoolAddress(addressCodec coreaddress.Codec) sdk.AccAddress {
addr, err := addressCodec.StringToBytes(plan.RewardsPool)
if err != nil {
panic(err)
}
Expand Down
Loading