diff --git a/contracts/distribution/DistributionService.sol b/contracts/distribution/DistributionService.sol index 90acba221..67670c14b 100644 --- a/contracts/distribution/DistributionService.sol +++ b/contracts/distribution/DistributionService.sol @@ -191,9 +191,9 @@ contract DistributionService is (NftId distributionNftId, IInstance instance) = _getAndVerifyActiveDistribution(); if (bytes(code).length == 0) { - revert ErrorDistributionServiceInvalidReferral(code); + revert ErrorDistributionServiceInvalidReferral(); } - if (expiryAt.eqz() || expiryAt.lte(TimestampLib.blockTimestamp())) { + if (expiryAt.eqz() || expiryAt < TimestampLib.blockTimestamp()) { revert ErrorDistributionServiceExpirationInvalid(expiryAt); } @@ -208,13 +208,13 @@ contract DistributionService is IDistribution.DistributorTypeInfo memory distributorTypeData = instanceReader.getDistributorTypeInfo(distributorType); if (distributorTypeData.maxReferralCount < maxReferrals) { - revert ErrorDistributionServiceMaxReferralsExceeded(distributorTypeData.maxReferralCount); + revert ErrorDistributionServiceMaxReferralsExceeded(distributorTypeData.maxReferralCount, maxReferrals); } if (distributorTypeData.minDiscountPercentage > discountPercentage) { - revert ErrorDistributionServiceDiscountTooLow(distributorTypeData.minDiscountPercentage.toInt(), discountPercentage.toInt()); + revert ErrorDistributionServiceDiscountTooLow(distributorTypeData.minDiscountPercentage, discountPercentage); } if (distributorTypeData.maxDiscountPercentage < discountPercentage) { - revert ErrorDistributionServiceDiscountTooHigh(distributorTypeData.maxDiscountPercentage.toInt(), discountPercentage.toInt()); + revert ErrorDistributionServiceDiscountTooHigh(distributorTypeData.maxDiscountPercentage, discountPercentage); } if (expiryAt.toInt() - TimestampLib.blockTimestamp().toInt() > distributorTypeData.maxReferralLifetime.toInt()) { revert ErrorDistributionServiceExpiryTooLong(distributorTypeData.maxReferralLifetime, expiryAt); diff --git a/contracts/distribution/IDistributionService.sol b/contracts/distribution/IDistributionService.sol index d3dc5273e..ced5f4d76 100644 --- a/contracts/distribution/IDistributionService.sol +++ b/contracts/distribution/IDistributionService.sol @@ -19,11 +19,11 @@ interface IDistributionService is IService { error ErrorDistributionServiceParentNftIdNotInstance(NftId nftId, NftId parentNftId); error ErrorDistributionServiceCallerNotDistributor(address caller); error ErrorDistributionServiceInvalidReferralId(ReferralId referralId); - error ErrorDistributionServiceMaxReferralsExceeded(uint256 maxReferrals); - error ErrorDistributionServiceDiscountTooLow(uint256 minDiscountPercentage, uint256 discountPercentage); - error ErrorDistributionServiceDiscountTooHigh(uint256 maxDiscountPercentage, uint256 discountPercentage); + error ErrorDistributionServiceMaxReferralsExceeded(uint256 limit, uint256 maxReferrals); + error ErrorDistributionServiceDiscountTooLow(UFixed minDiscountPercentage, UFixed discountPercentage); + error ErrorDistributionServiceDiscountTooHigh(UFixed maxDiscountPercentage, UFixed discountPercentage); error ErrorDistributionServiceExpiryTooLong(Seconds maxReferralLifetime, Timestamp expiryAt); - error ErrorDistributionServiceInvalidReferral(string code); + error ErrorDistributionServiceInvalidReferral(); error ErrorDistributionServiceExpirationInvalid(Timestamp expiryAt); error ErrorDistributionServiceCommissionTooHigh(uint256 commissionPercentage, uint256 maxCommissionPercentage); error ErrorDistributionServiceMinFeeTooHigh(uint256 minFee, uint256 limit); diff --git a/test/component/distribution/Referral.t.sol b/test/component/distribution/Referral.t.sol index c0b72c7f9..a38e62cbb 100644 --- a/test/component/distribution/Referral.t.sol +++ b/test/component/distribution/Referral.t.sol @@ -7,6 +7,7 @@ import {console} from "../../../lib/forge-std/src/Test.sol"; import {FeeLib} from "../../../contracts/type/Fee.sol"; import {IComponents} from "../../../contracts/instance/module/IComponents.sol"; import {IDistribution} from "../../../contracts/instance/module/IDistribution.sol"; +import {IDistributionService} from "../../../contracts/distribution/IDistributionService.sol"; import {IPolicy} from "../../../contracts/instance/module/IPolicy.sol"; import {NftId, NftIdLib} from "../../../contracts/type/NftId.sol"; import {POLICY} from "../../../contracts/type/ObjectType.sol"; @@ -15,8 +16,8 @@ import {ReferralTestBase} from "./ReferralTestBase.sol"; import {RiskId, RiskIdLib} from "../../../contracts/type/RiskId.sol"; import {SecondsLib} from "../../../contracts/type/Seconds.sol"; import {SimpleDistribution} from "../../../contracts/examples/unpermissioned/SimpleDistribution.sol"; -import {TimestampLib} from "../../../contracts/type/Timestamp.sol"; -import {UFixedLib} from "../../../contracts/type/UFixed.sol"; +import {Timestamp, TimestampLib} from "../../../contracts/type/Timestamp.sol"; +import {UFixed, UFixedLib} from "../../../contracts/type/UFixed.sol"; contract ReferralTest is ReferralTestBase { @@ -307,6 +308,140 @@ contract ReferralTest is ReferralTestBase { vm.stopPrank(); } + function test_createReferral_codeEmpty() public { + // GIVEN + _setupTestData(true); + vm.startPrank(customer); + + // THEN + vm.expectRevert(abi.encodeWithSelector( + IDistributionService.ErrorDistributionServiceInvalidReferral.selector)); + + // WHEN + referralId = distribution.createReferral( + distributorNftId, + "", + discountPercentage, + maxReferrals, + expiryAt, + referralData); + } + + function test_createReferral_expirationInvalid() public { + // GIVEN + vm.warp(500); + _setupTestData(true); + vm.startPrank(customer); + Timestamp tsZero = TimestampLib.zero(); + + // THEN + vm.expectRevert(abi.encodeWithSelector( + IDistributionService.ErrorDistributionServiceExpirationInvalid.selector, + 0)); + + // WHEN + referralId = distribution.createReferral( + distributorNftId, + "CODE", + discountPercentage, + maxReferrals, + tsZero, + referralData); + + Timestamp exp = TimestampLib.toTimestamp(TimestampLib.blockTimestamp().toInt() - 10); + + // THEN + vm.expectRevert(abi.encodeWithSelector( + IDistributionService.ErrorDistributionServiceExpirationInvalid.selector, + exp)); + + // WHEN + referralId = distribution.createReferral( + distributorNftId, + "CODE", + discountPercentage, + maxReferrals, + exp, + referralData); + + exp = TimestampLib.toTimestamp(TimestampLib.blockTimestamp().toInt() + maxReferralLifetime.toInt() + 10); + + // THEN + vm.expectRevert(abi.encodeWithSelector( + IDistributionService.ErrorDistributionServiceExpiryTooLong.selector, + maxReferralLifetime, + exp)); + + // WHEN + referralId = distribution.createReferral( + distributorNftId, + "CODE", + discountPercentage, + maxReferrals, + exp, + referralData); + } + + function test_createReferral_referralCountInvalid() public { + // GIVEN + _setupTestData(true); + vm.startPrank(customer); + Timestamp exp = TimestampLib.blockTimestamp().addSeconds(SecondsLib.toSeconds(10)); + + // THEN + vm.expectRevert(abi.encodeWithSelector( + IDistributionService.ErrorDistributionServiceMaxReferralsExceeded.selector, + 20, + 42)); + + // WHEN + referralId = distribution.createReferral( + distributorNftId, + "CODE", + discountPercentage, + 42, + exp, + referralData); + } + + function test_createReferral_discoundInvalid() public { + // GIVEN + _setupTestData(true); + vm.startPrank(customer); + Timestamp exp = TimestampLib.blockTimestamp().addSeconds(SecondsLib.toSeconds(10)); + UFixed discount = UFixedLib.toUFixed(3, -2); + + // THEN + vm.expectRevert(abi.encodeWithSelector( + IDistributionService.ErrorDistributionServiceDiscountTooLow.selector, + minDiscountPercentage, + discount)); + + // WHEN + referralId = distribution.createReferral( + distributorNftId, + "CODE", + discount, + 15, + exp, + referralData); + + // THEN + discount = UFixedLib.toUFixed(22, -2); + vm.expectRevert(abi.encodeWithSelector( + IDistributionService.ErrorDistributionServiceDiscountTooHigh.selector, + maxDiscountPercentage, + discount)); + + // WHEN + referralId = distribution.createReferral( + distributorNftId, + "CODE", + discount, + 15, + exp, + referralData); + } function _setupBundle(uint256 bundleAmount) internal {