diff --git a/CHANGELOG.md b/CHANGELOG.md index be86dd75526..f4cd4e0d114 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -308,6 +308,7 @@ - [9747](https://github.com/vegaprotocol/vega/issues/9747) - Return correct destination type - [9541](https://github.com/vegaprotocol/vega/issues/9731) - Add filtering for party to the referral fees API. - [9751](https://github.com/vegaprotocol/vega/issues/9751) - Make sure that LP fee party accounts exists. +- [9762](https://github.com/vegaprotocol/vega/issues/9762) - Referral fees API not filtering by party correctly. ## 0.72.1 diff --git a/datanode/sqlstore/referral_fee_stats.go b/datanode/sqlstore/referral_fee_stats.go index 9dd1b223e57..a217be0a44e 100644 --- a/datanode/sqlstore/referral_fee_stats.go +++ b/datanode/sqlstore/referral_fee_stats.go @@ -21,6 +21,8 @@ import ( "fmt" "strings" + eventspb "code.vegaprotocol.io/vega/protos/vega/events/v1" + "github.com/georgysavva/scany/pgxscan" "code.vegaprotocol.io/vega/datanode/metrics" @@ -120,9 +122,46 @@ func (rfs *ReferralFeeStats) GetFeeStats(ctx context.Context, marketID *entities return nil, errors.New("no referral fee stats found") } + // the query returns the full JSON object and doesn't filter for the referrer/referee, + // it only matches on the records where the json object contains the referrer/referee + if referrerID != nil { + // filter the results to only include the results for the given referrer + stats[0].TotalRewardsPaid = filterPartyAmounts(stats[0].TotalRewardsPaid, *referrerID) + stats[0].ReferrerRewardsGenerated = filterReferrerRewardsGenerated(stats[0].ReferrerRewardsGenerated, *referrerID) + } + + if refereeID != nil { + // filter the results to only include the results for the given referee + for i, rrg := range stats[0].ReferrerRewardsGenerated { + stats[0].ReferrerRewardsGenerated[i].GeneratedReward = filterPartyAmounts(rrg.GeneratedReward, *refereeID) + } + stats[0].RefereesDiscountApplied = filterPartyAmounts(stats[0].RefereesDiscountApplied, *refereeID) + stats[0].VolumeDiscountApplied = filterPartyAmounts(stats[0].VolumeDiscountApplied, *refereeID) + } + return &stats[0], err } +func filterPartyAmounts(totalRewardsPaid []*eventspb.PartyAmount, party string) []*eventspb.PartyAmount { + filteredTotalRewardsPaid := make([]*eventspb.PartyAmount, 0) + for _, reward := range totalRewardsPaid { + if strings.EqualFold(reward.Party, party) { + filteredTotalRewardsPaid = append(filteredTotalRewardsPaid, reward) + } + } + return filteredTotalRewardsPaid +} + +func filterReferrerRewardsGenerated(rewardsGenerated []*eventspb.ReferrerRewardsGenerated, referrer string) []*eventspb.ReferrerRewardsGenerated { + filteredReferrerRewardsGenerated := make([]*eventspb.ReferrerRewardsGenerated, 0) + for _, reward := range rewardsGenerated { + if strings.EqualFold(reward.Referrer, referrer) { + filteredReferrerRewardsGenerated = append(filteredReferrerRewardsGenerated, reward) + } + } + return filteredReferrerRewardsGenerated +} + func getPartyFilter(referrerID, refereeID *string) string { builder := strings.Builder{} if referrerID == nil && refereeID == nil { diff --git a/datanode/sqlstore/referral_fee_stats_test.go b/datanode/sqlstore/referral_fee_stats_test.go index ed1938c1f08..3886be6020c 100644 --- a/datanode/sqlstore/referral_fee_stats_test.go +++ b/datanode/sqlstore/referral_fee_stats_test.go @@ -478,7 +478,38 @@ func testGetFeeStatsForRefereeAndEpoch(t *testing.T) { stats := setupFeeStats(t, ctx, stores.fs) // get the stats for the first market and epoch - want := stats[1] + expected := stats[1] + want := entities.ReferralFeeStats{ + MarketID: entities.MarketID("deadbeef01"), + AssetID: entities.AssetID("deadbaad01"), + EpochSeq: 2, + TotalRewardsPaid: []*eventspb.PartyAmount{ + { + Party: "cafedaad01", + Amount: "1100000", + }, + }, + ReferrerRewardsGenerated: []*eventspb.ReferrerRewardsGenerated{ + { + Referrer: "cafedaad01", + GeneratedReward: []*eventspb.PartyAmount{ + { + Party: "cafed00d01", + Amount: "550000", + }, + }, + }, + }, + RefereesDiscountApplied: []*eventspb.PartyAmount{ + { + Party: "cafed00d01", + Amount: "110000", + }, + }, + VolumeDiscountApplied: []*eventspb.PartyAmount{}, + VegaTime: expected.VegaTime, + } + got, err := stores.fs.GetFeeStats(ctx, nil, &want.AssetID, ptr.From(want.EpochSeq), nil, &want.ReferrerRewardsGenerated[0].GeneratedReward[0].Party) require.NoError(t, err) @@ -503,7 +534,37 @@ func testGetFeeStatsForRefereeLatest(t *testing.T) { stats := setupFeeStats(t, ctx, stores.fs) // get the stats for the first market and epoch - want := stats[2] + expected := stats[2] + want := entities.ReferralFeeStats{ + MarketID: entities.MarketID("deadbeef01"), + AssetID: entities.AssetID("deadbaad01"), + EpochSeq: 3, + TotalRewardsPaid: []*eventspb.PartyAmount{ + { + Party: "cafedaad01", + Amount: "1200000", + }, + }, + ReferrerRewardsGenerated: []*eventspb.ReferrerRewardsGenerated{ + { + Referrer: "cafedaad01", + GeneratedReward: []*eventspb.PartyAmount{ + { + Party: "cafed00d01", + Amount: "600000", + }, + }, + }, + }, + RefereesDiscountApplied: []*eventspb.PartyAmount{ + { + Party: "cafed00d01", + Amount: "120000", + }, + }, + VolumeDiscountApplied: []*eventspb.PartyAmount{}, + VegaTime: expected.VegaTime, + } got, err := stores.fs.GetFeeStats(ctx, nil, &want.AssetID, nil, nil, &want.ReferrerRewardsGenerated[0].GeneratedReward[0].Party) require.NoError(t, err) @@ -516,7 +577,37 @@ func testGetFeeStatsReferee(t *testing.T) { stats := setupFeeStats(t, ctx, stores.fs) // get the stats for the first market and epoch - want := stats[2] + expected := stats[2] + want := entities.ReferralFeeStats{ + MarketID: entities.MarketID("deadbeef01"), + AssetID: entities.AssetID("deadbaad01"), + EpochSeq: 3, + TotalRewardsPaid: []*eventspb.PartyAmount{ + { + Party: "cafedaad01", + Amount: "1200000", + }, + }, + ReferrerRewardsGenerated: []*eventspb.ReferrerRewardsGenerated{ + { + Referrer: "cafedaad01", + GeneratedReward: []*eventspb.PartyAmount{ + { + Party: "cafed00d01", + Amount: "600000", + }, + }, + }, + }, + RefereesDiscountApplied: []*eventspb.PartyAmount{ + { + Party: "cafed00d01", + Amount: "120000", + }, + }, + VolumeDiscountApplied: []*eventspb.PartyAmount{}, + VegaTime: expected.VegaTime, + } got, err := stores.fs.GetFeeStats(ctx, nil, &want.AssetID, nil, nil, &want.ReferrerRewardsGenerated[0].GeneratedReward[0].Party) require.NoError(t, err)