From a7fecd5139597ecf8948e01e4815d2658460068d Mon Sep 17 00:00:00 2001 From: 0xchin Date: Mon, 18 Nov 2024 21:20:35 -0300 Subject: [PATCH] feat: add performance fee on withdrawal from vault --- src/contracts/Grateful.sol | 199 ++++++++++++------- src/interfaces/IGrateful.sol | 80 ++++++-- test/integration/Grateful.t.sol | 277 +++++++++++++-------------- test/integration/IntegrationBase.sol | 73 ++++++- 4 files changed, 398 insertions(+), 231 deletions(-) diff --git a/src/contracts/Grateful.sol b/src/contracts/Grateful.sol index 3754b43..443cd09 100644 --- a/src/contracts/Grateful.sol +++ b/src/contracts/Grateful.sol @@ -13,14 +13,16 @@ import {AaveV3Vault} from "contracts/vaults/AaveV3Vault.sol"; import {IGrateful} from "interfaces/IGrateful.sol"; import {Bytes32AddressLib} from "solmate/utils/Bytes32AddressLib.sol"; +import {FixedPointMathLib} from "solmate/utils/FixedPointMathLib.sol"; import {IPool} from "yield-daddy/aave-v3/AaveV3ERC4626.sol"; /** * @title Grateful Contract - * @notice Allows payments in whitelisted tokens with optional yield via AAVE, including payment splitting functionality. + * @notice Allows payments in whitelisted tokens with optional yield via AAVE */ contract Grateful is IGrateful, Ownable2Step { using Bytes32AddressLib for bytes32; + using FixedPointMathLib for uint256; using SafeERC20 for IERC20; /*////////////////////////////////////////////////////////////// @@ -39,6 +41,9 @@ contract Grateful is IGrateful, Ownable2Step { /// @inheritdoc IGrateful mapping(address => mapping(address => uint256)) public shares; + /// @inheritdoc IGrateful + mapping(address => mapping(address => uint256)) public userDeposits; + /// @inheritdoc IGrateful mapping(address => bool) public oneTimePayments; @@ -51,6 +56,9 @@ contract Grateful is IGrateful, Ownable2Step { /// @inheritdoc IGrateful mapping(uint256 => bool) public paymentIds; + /// @inheritdoc IGrateful + uint256 public performanceFeeRate = 500; // 5% fee + /*////////////////////////////////////////////////////////////// MODIFIERS //////////////////////////////////////////////////////////////*/ @@ -72,7 +80,7 @@ contract Grateful is IGrateful, Ownable2Step { } /*////////////////////////////////////////////////////////////// - CONSTRUCTOR + CONSTRUCTOR //////////////////////////////////////////////////////////////*/ /** @@ -86,20 +94,20 @@ contract Grateful is IGrateful, Ownable2Step { fee = _initialFee; for (uint256 i = 0; i < _tokens.length; i++) { tokensWhitelisted[_tokens[i]] = true; - IERC20 _token = IERC20(_tokens[i]); - _token.forceApprove(address(_aavePool), type(uint256).max); + IERC20 token = IERC20(_tokens[i]); // Renamed from _token to token + token.forceApprove(address(_aavePool), type(uint256).max); } } /*////////////////////////////////////////////////////////////// - PUBLIC FUNCTIONS + PUBLIC FUNCTIONS //////////////////////////////////////////////////////////////*/ /// @inheritdoc IGrateful function calculateAssets(address _merchant, address _token) public view returns (uint256 assets) { - AaveV3Vault _vault = vaults[_token]; - uint256 _shares = shares[_merchant][_token]; - assets = _vault.convertToAssets(_shares); + AaveV3Vault vault = vaults[_token]; // Renamed from _vault to vault + uint256 sharesAmount = shares[_merchant][_token]; // Renamed from _shares to sharesAmount + assets = vault.convertToAssets(sharesAmount); } /// @inheritdoc IGrateful @@ -127,8 +135,28 @@ contract Grateful is IGrateful, Ownable2Step { return super.owner(); } + /// @inheritdoc IGrateful + function calculateProfit(address _user, address _token) public view returns (uint256 profit) { + AaveV3Vault vault = vaults[_token]; // Renamed from _vault to vault + uint256 sharesAmount = shares[_user][_token]; // Renamed from _shares to sharesAmount + uint256 assets = vault.previewRedeem(sharesAmount); // Current value of user's shares + uint256 initialDeposit = userDeposits[_user][_token]; // User's initial deposit + if (assets > initialDeposit) { + profit = assets - initialDeposit; + } else { + profit = 0; + } + } + + /// @inheritdoc IGrateful + function calculatePerformanceFee( + uint256 _profit + ) public view returns (uint256 feeAmount) { + feeAmount = (_profit * performanceFeeRate) / 10_000; + } + /*////////////////////////////////////////////////////////////// - EXTERNAL FUNCTIONS + EXTERNAL FUNCTIONS //////////////////////////////////////////////////////////////*/ /// @inheritdoc IGrateful @@ -136,7 +164,8 @@ contract Grateful is IGrateful, Ownable2Step { address _token ) external onlyOwner { tokensWhitelisted[_token] = true; - IERC20(_token).forceApprove(address(aavePool), type(uint256).max); + IERC20 token = IERC20(_token); // Renamed from _token to token + token.forceApprove(address(aavePool), type(uint256).max); emit TokenAdded(_token); } @@ -148,15 +177,17 @@ contract Grateful is IGrateful, Ownable2Step { revert Grateful_TokenOrVaultNotFound(); } delete tokensWhitelisted[_token]; - IERC20(_token).forceApprove(address(aavePool), 0); - IERC20(_token).forceApprove(address(vaults[_token]), 0); + IERC20 token = IERC20(_token); // Renamed from _token to token + token.forceApprove(address(aavePool), 0); + token.forceApprove(address(vaults[_token]), 0); emit TokenRemoved(_token); } /// @inheritdoc IGrateful function addVault(address _token, address _vault) external onlyOwner onlyWhenTokenWhitelisted(_token) { vaults[_token] = AaveV3Vault(_vault); - IERC20(_token).safeIncreaseAllowance(address(_vault), type(uint256).max); + IERC20 token = IERC20(_token); // Renamed from _token to token + token.safeIncreaseAllowance(address(_vault), type(uint256).max); emit VaultAdded(_token, _vault); } @@ -168,7 +199,8 @@ contract Grateful is IGrateful, Ownable2Step { if (address(vault) == address(0)) { revert Grateful_TokenOrVaultNotFound(); } - IERC20(_token).forceApprove(address(vault), 0); + IERC20 token = IERC20(_token); // Renamed from _token to token + token.forceApprove(address(vault), 0); emit VaultRemoved(_token, address(vault)); delete vaults[_token]; } @@ -192,9 +224,9 @@ contract Grateful is IGrateful, Ownable2Step { uint256 _salt, uint256 _paymentId, bool _yieldFunds, - address precomputed + address _precomputed ) external onlyWhenTokensWhitelisted(_tokens) returns (OneTime oneTime) { - oneTimePayments[precomputed] = true; + oneTimePayments[_precomputed] = true; oneTime = new OneTime{salt: bytes32(_salt)}(IGrateful(address(this)), _tokens, _merchant, _amount, _paymentId, _yieldFunds); emit OneTimePaymentCreated(_merchant, _tokens, _amount); @@ -236,28 +268,12 @@ contract Grateful is IGrateful, Ownable2Step { function withdraw( address _token ) external onlyWhenTokenWhitelisted(_token) { - AaveV3Vault vault = vaults[_token]; - if (address(vault) == address(0)) { - revert Grateful_usdcVaultNotSet(); - } - uint256 _shares = shares[msg.sender][_token]; - shares[msg.sender][_token] = 0; - vault.redeem(_shares, msg.sender, address(this)); + _withdraw(_token, 0, true); } /// @inheritdoc IGrateful function withdraw(address _token, uint256 _assets) external onlyWhenTokenWhitelisted(_token) { - AaveV3Vault vault = vaults[_token]; - if (address(vault) == address(0)) { - revert Grateful_usdcVaultNotSet(); - } - uint256 _totalShares = shares[msg.sender][_token]; - uint256 _sharesToWithdraw = vault.convertToShares(_assets); - if (_sharesToWithdraw > _totalShares) { - revert Grateful_WithdrawExceedsShares(); - } - shares[msg.sender][_token] = _totalShares - _sharesToWithdraw; - vault.withdraw(_assets, msg.sender, address(this)); + _withdraw(_token, _assets, false); } /// @inheritdoc IGrateful @@ -266,16 +282,7 @@ contract Grateful is IGrateful, Ownable2Step { ) external onlyWhenTokensWhitelisted(_tokens) { uint256 tokensLength = _tokens.length; for (uint256 i = 0; i < tokensLength; i++) { - address _token = _tokens[i]; - AaveV3Vault vault = vaults[_token]; - if (address(vault) == address(0)) { - revert Grateful_usdcVaultNotSet(); - } - uint256 _shares = shares[msg.sender][_token]; - if (_shares > 0) { - shares[msg.sender][_token] = 0; - vault.redeem(_shares, msg.sender, address(this)); - } + _withdraw(_tokens[i], 0, true); } } @@ -289,19 +296,7 @@ contract Grateful is IGrateful, Ownable2Step { revert Grateful_MismatchedArrays(); } for (uint256 i = 0; i < tokensLength; i++) { - address _token = _tokens[i]; - uint256 _assetsToWithdraw = _assets[i]; - AaveV3Vault vault = vaults[_token]; - if (address(vault) == address(0)) { - revert Grateful_usdcVaultNotSet(); - } - uint256 _totalShares = shares[msg.sender][_token]; - uint256 _sharesToWithdraw = vault.convertToShares(_assetsToWithdraw); - if (_sharesToWithdraw > _totalShares) { - revert Grateful_WithdrawExceedsShares(); - } - shares[msg.sender][_token] = _totalShares - _sharesToWithdraw; - vault.withdraw(_assetsToWithdraw, msg.sender, address(this)); + _withdraw(_tokens[i], _assets[i], false); } } @@ -313,6 +308,17 @@ contract Grateful is IGrateful, Ownable2Step { emit FeeUpdated(_newFee); } + /// @inheritdoc IGrateful + function setPerformanceFeeRate( + uint256 _newPerformanceFeeRate + ) external onlyOwner { + if (_newPerformanceFeeRate > 5000) { + revert Grateful_FeeRateTooHigh(); + } + performanceFeeRate = _newPerformanceFeeRate; + emit PerformanceFeeRateUpdated(_newPerformanceFeeRate); + } + /// @inheritdoc IGrateful function setCustomFee(uint256 _newFee, address _merchant) external onlyOwner { customFees[_merchant] = CustomFee({isSet: true, fee: _newFee}); @@ -328,7 +334,7 @@ contract Grateful is IGrateful, Ownable2Step { } /*////////////////////////////////////////////////////////////// - INTERNAL FUNCTIONS + PRIVATE FUNCTIONS //////////////////////////////////////////////////////////////*/ /** @@ -359,9 +365,10 @@ contract Grateful is IGrateful, Ownable2Step { uint256 _amount, uint256 _paymentId, bool _yieldFunds - ) internal { + ) private { // Transfer the full amount from the sender to this contract - IERC20(_token).safeTransferFrom(_sender, address(this), _amount); + IERC20 token = IERC20(_token); // Renamed from _token to token + token.safeTransferFrom(_sender, address(this), _amount); // Check payment id if (paymentIds[_paymentId]) { @@ -374,21 +381,83 @@ contract Grateful is IGrateful, Ownable2Step { uint256 amountWithFee = applyFee(_merchant, _amount); // Transfer fee to owner - IERC20(_token).safeTransfer(owner(), _amount - amountWithFee); + uint256 feeAmount = _amount - amountWithFee; + token.safeTransfer(owner(), feeAmount); if (_yieldFunds) { AaveV3Vault vault = vaults[_token]; if (address(vault) == address(0)) { - IERC20(_token).safeTransfer(_merchant, amountWithFee); + token.safeTransfer(_merchant, amountWithFee); } else { - uint256 _shares = vault.deposit(amountWithFee, address(this)); - shares[_merchant][_token] += _shares; + uint256 sharesAmount = vault.deposit(amountWithFee, address(this)); + shares[_merchant][_token] += sharesAmount; + userDeposits[_merchant][_token] += amountWithFee; } } else { // Transfer tokens to merchant - IERC20(_token).safeTransfer(_merchant, amountWithFee); + token.safeTransfer(_merchant, amountWithFee); } emit PaymentProcessed(_sender, _merchant, _token, _amount, _yieldFunds, _paymentId); } + + /** + * @dev Internal function to handle withdrawals. + * @param _token The address of the token to withdraw. + * @param _assets The amount of assets to withdraw (ignored if full withdrawal). + * @param _isFullWithdrawal Indicates if it's a full withdrawal. + */ + function _withdraw(address _token, uint256 _assets, bool _isFullWithdrawal) internal { + AaveV3Vault vault = vaults[_token]; + if (address(vault) == address(0)) { + revert Grateful_VaultNotSet(); + } + + uint256 totalShares = shares[msg.sender][_token]; + uint256 sharesToWithdraw; + uint256 assetsToWithdraw; + + if (_isFullWithdrawal) { + sharesToWithdraw = totalShares; + assetsToWithdraw = vault.previewRedeem(sharesToWithdraw); + } else { + sharesToWithdraw = vault.previewWithdraw(_assets); + if (sharesToWithdraw > totalShares) { + revert Grateful_WithdrawExceedsShares(); + } + assetsToWithdraw = _assets; + } + + uint256 totalAssets = vault.previewRedeem(totalShares); + uint256 initialDeposit = userDeposits[msg.sender][_token]; + + // Calculate proportion of withdrawal + uint256 proportion = assetsToWithdraw.divWadDown(totalAssets); + uint256 initialDepositToWithdraw = initialDeposit.mulWadDown(proportion); + + // Calculate profit and performance fee + uint256 profit = 0; + uint256 performanceFeeAmount = 0; + if (assetsToWithdraw > initialDepositToWithdraw) { + profit = assetsToWithdraw - initialDepositToWithdraw; + performanceFeeAmount = calculatePerformanceFee(profit); + assetsToWithdraw -= performanceFeeAmount; // Deduct fee from assets + + // Withdraw performance fee to fee recipient (owner) + vault.withdraw(performanceFeeAmount, owner(), address(this)); + } + + // Update user's shares and deposits + shares[msg.sender][_token] = totalShares - sharesToWithdraw; + userDeposits[msg.sender][_token] = initialDeposit - initialDepositToWithdraw; + + // Ensure balances are zero in case of full withdrawal to handle rounding errors + if (_isFullWithdrawal) { + shares[msg.sender][_token] = 0; + userDeposits[msg.sender][_token] = 0; + } + + // Withdraw assets to user + vault.withdraw(assetsToWithdraw, msg.sender, address(this)); + } } diff --git a/src/interfaces/IGrateful.sol b/src/interfaces/IGrateful.sol index d7a7b64..8058991 100644 --- a/src/interfaces/IGrateful.sol +++ b/src/interfaces/IGrateful.sol @@ -13,7 +13,7 @@ import {IPool} from "yield-daddy/aave-v3/AaveV3ERC4626.sol"; */ interface IGrateful { /*/////////////////////////////////////////////////////////////// - STRUCTS + STRUCTS //////////////////////////////////////////////////////////////*/ struct CustomFee { @@ -22,7 +22,7 @@ interface IGrateful { } /*/////////////////////////////////////////////////////////////// - EVENTS + EVENTS //////////////////////////////////////////////////////////////*/ /** @@ -65,6 +65,12 @@ interface IGrateful { */ event FeeUpdated(uint256 newFee); + /** + * @notice Emitted when the performance fee rate is updated. + * @param newRate The new performance fee rate in basis points. + */ + event PerformanceFeeRateUpdated(uint256 newRate); + /** * @notice Emitted when a custom fee is set for a merchant. * @param merchant Address of the merchant. @@ -105,7 +111,7 @@ interface IGrateful { event VaultRemoved(address indexed token, address indexed vault); /*/////////////////////////////////////////////////////////////// - ERRORS + ERRORS //////////////////////////////////////////////////////////////*/ /// @notice Thrown when the token is not whitelisted. error Grateful_TokenNotWhitelisted(); @@ -113,14 +119,8 @@ interface IGrateful { /// @notice Thrown when array lengths mismatch. error Grateful_MismatchedArrays(); - /// @notice Thrown when the total percentage is invalid. - error Grateful_InvalidTotalPercentage(); - /// @notice Thrown when the vault for a token is not set. - error Grateful_usdcVaultNotSet(); - - /// @notice Thrown when a token transfer fails. - error Grateful_TransferFailed(); + error Grateful_VaultNotSet(); /// @notice Thrown when the one-time payment is not found. error Grateful_OneTimeNotFound(); @@ -134,8 +134,11 @@ interface IGrateful { /// @notice Thrown when attempting to remove a token or vault that does not exist. error Grateful_TokenOrVaultNotFound(); + /// @notice Thrown when the fee rate is too high. + error Grateful_FeeRateTooHigh(); + /*/////////////////////////////////////////////////////////////// - VARIABLES + VARIABLES //////////////////////////////////////////////////////////////*/ /// @notice Returns the owner of the contract. @@ -164,6 +167,12 @@ interface IGrateful { /// @return Amount of shares. function shares(address _merchant, address _token) external view returns (uint256); + /// @notice Returns the user deposit amount for a merchant and token. + /// @param _merchant Address of the merchant. + /// @param _token Address of the token. + /// @return Amount of initial deposit. + function userDeposits(address _merchant, address _token) external view returns (uint256); + /// @notice Checks if an address is a registered one-time payment. /// @param _address Address to check. /// @return True if it's a registered one-time payment, false otherwise. @@ -175,6 +184,10 @@ interface IGrateful { /// @return Fee in basis points (10000 = 100%). function fee() external view returns (uint256); + /// @notice Returns the performance fee rate. + /// @return Performance fee rate in basis points. + function performanceFeeRate() external view returns (uint256); + /// @notice Returns the custom fee applied to the payments for a merchant. /// @param _merchant Address of the merchant. /// @return isSet True if a custom fee is set for the merchant. @@ -191,7 +204,7 @@ interface IGrateful { ) external view returns (bool isUsed); /*/////////////////////////////////////////////////////////////// - LOGIC + LOGIC //////////////////////////////////////////////////////////////*/ /** @@ -205,9 +218,9 @@ interface IGrateful { /** * @notice Adds a vault for a specific token. * @param _token Address of the token. - * @param _usdcVault Address of the vault contract. + * @param _vault Address of the vault contract. */ - function addVault(address _token, address _usdcVault) external; + function addVault(address _token, address _vault) external; /** * @notice Removes a token from the whitelist. @@ -231,18 +244,18 @@ interface IGrateful { * @param _token Address of the token used for payment. * @param _amount Amount of the token to be paid. * @param _id ID of the payment. - * @param _yieldFunds Whether to yield funds or not + * @param _yieldFunds Whether to yield funds or not. */ function pay(address _merchant, address _token, uint256 _amount, uint256 _id, bool _yieldFunds) external; /** - * @notice Creates a one-time payment without payment splitting. + * @notice Creates a one-time payment. * @param _merchant Address of the merchant. * @param _tokens Array of token addresses. * @param _amount Amount of the token. * @param _salt Salt used for address computation. * @param _paymentId ID of the payment. - * @param _yieldFunds Whether to yield funds or not + * @param _yieldFunds Whether to yield funds or not. * @param precomputed Precomputed address of the OneTime contract. * @return oneTime Address of the created OneTime contract. */ @@ -257,12 +270,12 @@ interface IGrateful { ) external returns (OneTime oneTime); /** - * @notice Receives a one-time payment without payment splitting. + * @notice Receives a one-time payment. * @param _merchant Address of the merchant. * @param _token Token address. * @param _paymentId ID of the payment. - * @param _yieldFunds Whether to yield funds or not * @param _amount Amount of the token. + * @param _yieldFunds Whether to yield funds or not. */ function receiveOneTimePayment( address _merchant, @@ -273,13 +286,13 @@ interface IGrateful { ) external; /** - * @notice Computes the address of a one-time payment contract without payment splitting. + * @notice Computes the address of a one-time payment contract. * @param _merchant Address of the merchant. * @param _tokens Array of token addresses. * @param _amount Amount of the token. * @param _salt Salt used for address computation. * @param _paymentId ID of the payment. - * @param _yieldFunds Whether to yield funds or not + * @param _yieldFunds Whether to yield funds or not. * @return oneTime Address of the computed OneTime contract. */ function computeOneTimeAddress( @@ -352,6 +365,23 @@ interface IGrateful { */ function applyFee(address _merchant, uint256 _amount) external view returns (uint256 amountWithFee); + /** + * @notice Calculates the profit (yield) earned by a user on a specific token. + * @param _user The address of the user. + * @param _token The address of the token. + * @return profit The profit amount. + */ + function calculateProfit(address _user, address _token) external view returns (uint256 profit); + + /** + * @notice Calculates the performance fee on a given profit amount. + * @param _profit The profit amount. + * @return feeAmount The performance fee amount. + */ + function calculatePerformanceFee( + uint256 _profit + ) external view returns (uint256 feeAmount); + /** * @notice Sets a new fee. * @param _newFee New fee to be applied (in basis points, 10000 = 100%). @@ -360,6 +390,14 @@ interface IGrateful { uint256 _newFee ) external; + /** + * @notice Sets the performance fee rate. + * @param _newPerformanceFeeRate The new performance fee rate in basis points. + */ + function setPerformanceFeeRate( + uint256 _newPerformanceFeeRate + ) external; + /** * @notice Sets a new custom fee for a certain merchant. * @param _newFee New fee to be applied (in basis points, 10000 = 100%). diff --git a/test/integration/Grateful.t.sol b/test/integration/Grateful.t.sol index bfd9633..dd49e53 100644 --- a/test/integration/Grateful.t.sol +++ b/test/integration/Grateful.t.sol @@ -2,194 +2,187 @@ pragma solidity 0.8.26; import {OneTime} from "contracts/OneTime.sol"; -import {AaveV3Vault, IntegrationBase} from "test/integration/IntegrationBase.sol"; +import {IntegrationBase} from "test/integration/IntegrationBase.sol"; contract IntegrationGreeter is IntegrationBase { - function _approveAndPay(address payer, address merchant, uint256 amount, bool _yieldFunds) internal { - uint256 paymentId = _grateful.calculateId(payer, merchant, address(_usdc), amount); - vm.startPrank(payer); - _usdc.approve(address(_grateful), amount); - _grateful.pay(merchant, address(_usdc), amount, paymentId, _yieldFunds); - vm.stopPrank(); - } + /*////////////////////////////////////////////////////////////// + TESTS + //////////////////////////////////////////////////////////////*/ + // Tests for Standard Payments function test_Payment() public { - _approveAndPay(_payer, _merchant, _AMOUNT_USDC, false); + _approveAndPay(_payer, _merchant, _AMOUNT_USDC, _NOT_YIELDING_FUNDS); - assertEq(_usdc.balanceOf(_merchant), _grateful.applyFee(_merchant, _AMOUNT_USDC)); + assertEq( + _usdc.balanceOf(_merchant), _grateful.applyFee(_merchant, _AMOUNT_USDC), "Merchant balance mismatch after payment" + ); } function test_PaymentYieldingFunds() public { - _approveAndPay(_payer, _merchant, _AMOUNT_USDC, true); + // Capture owner's initial balance before payment + uint256 ownerInitialBalance = _usdc.balanceOf(_owner); + _approveAndPay(_payer, _merchant, _AMOUNT_USDC, _YIELDING_FUNDS); + + // Advance time to accrue yield vm.warp(block.timestamp + 1 days); - // Get total assets - uint256 _assets = _grateful.calculateAssets(_merchant, address(_usdc)); + // Calculate profit before withdrawal + uint256 profit = _grateful.calculateProfit(_merchant, address(_usdc)); + // Merchant withdraws funds vm.prank(_merchant); _grateful.withdraw(address(_usdc)); - assertEq(_assets, _usdc.balanceOf(_merchant)); - assertGt(_usdc.balanceOf(_merchant), _grateful.applyFee(_merchant, _AMOUNT_USDC)); - } + // Calculate performance fee after withdrawal + uint256 performanceFee = _grateful.calculatePerformanceFee(profit); - function test_OneTimePayment() public { - // 1. Calculate payment id - uint256 paymentId = _grateful.calculateId(_payer, _merchant, address(_usdc), _AMOUNT_USDC); + uint256 initialDeposit = _grateful.applyFee(_merchant, _AMOUNT_USDC); + uint256 expectedMerchantBalance = initialDeposit + profit - performanceFee; - // 2. Precompute address - address precomputed = - address(_grateful.computeOneTimeAddress(_merchant, _tokens, _AMOUNT_USDC, 4, paymentId, false)); - - // 3. Once the payment address is precomputed, the client sends the payment - vm.prank(_payer); - _usdc.transfer(precomputed, _AMOUNT_USDC); // Only tx sent by the client, doesn't need contract interaction + // Verify merchant's balance + assertEq(_usdc.balanceOf(_merchant), expectedMerchantBalance, "Merchant balance mismatch after withdrawal"); - // 4. Merchant calls api to make one time payment to his address - vm.prank(_gratefulAutomation); - _grateful.createOneTimePayment(_merchant, _tokens, _AMOUNT_USDC, 4, paymentId, false, precomputed); + // Verify owner's balance + uint256 ownerFinalBalance = _usdc.balanceOf(_owner); + uint256 initialFee = _AMOUNT_USDC - initialDeposit; + uint256 ownerExpectedBalanceIncrease = initialFee + performanceFee; - // Merchant receives the payment - assertEq(_usdc.balanceOf(_merchant), _grateful.applyFee(_merchant, _AMOUNT_USDC)); + assertEq( + ownerFinalBalance - ownerInitialBalance, + ownerExpectedBalanceIncrease, + "Owner did not receive correct performance fee" + ); } - function test_OverpaidOneTimePayment() public { - // 1. Calculate payment id - uint256 paymentId = _grateful.calculateId(_payer, _merchant, address(_usdc), _AMOUNT_USDC); - - // 2. Precompute address - address precomputed = - address(_grateful.computeOneTimeAddress(_merchant, _tokens, _AMOUNT_USDC, 4, paymentId, false)); - - // 3. Once the payment address is precomputed, the client sends the payment - vm.prank(_payer); - _usdc.transfer(precomputed, _AMOUNT_USDC * 2); // Only tx sent by the client, doesn't need contract interaction - - // 4. Merchant calls api to make one time payment to his address - vm.prank(_gratefulAutomation); - OneTime _oneTime = - _grateful.createOneTimePayment(_merchant, _tokens, _AMOUNT_USDC, 4, paymentId, false, precomputed); + // Tests for One-Time Payments + function test_OneTimePayment() public { + _setupAndExecuteOneTimePayment(_payer, _merchant, _AMOUNT_USDC, _PAYMENT_SALT, _NOT_YIELDING_FUNDS); // Merchant receives the payment - assertEq(_usdc.balanceOf(_merchant), _grateful.applyFee(_merchant, _AMOUNT_USDC)); - - // There are funds in the onetime contract stucked - assertEq(_usdc.balanceOf(address(_oneTime)), _AMOUNT_USDC); - - uint256 prevWhaleBalance = _usdc.balanceOf(_payer); - - // Rescue funds - vm.prank(_owner); - _oneTime.rescueFunds(_usdc, _payer, _AMOUNT_USDC); - - // Client has received his funds - assertEq(_usdc.balanceOf(address(_payer)), prevWhaleBalance + _AMOUNT_USDC); + assertEq( + _usdc.balanceOf(_merchant), + _grateful.applyFee(_merchant, _AMOUNT_USDC), + "Merchant balance mismatch after one-time payment" + ); } - function test_PaymentWithCustomFee() public { - // ------------------------------ - // 1. Set custom fee of 2% (200 basis points) for the merchant - // ------------------------------ - vm.prank(_owner); - _grateful.setCustomFee(200, _merchant); - - // Process payment with custom fee of 2% - _approveAndPay(_payer, _merchant, _AMOUNT_USDC, false); - - // Expected amounts - uint256 expectedCustomFee = (_AMOUNT_USDC * 200) / 10_000; // 2% fee - uint256 expectedMerchantAmount = _AMOUNT_USDC - expectedCustomFee; - - // Verify balances after first payment - assertEq(_usdc.balanceOf(_merchant), expectedMerchantAmount, "Merchant balance mismatch after first payment"); - assertEq(_usdc.balanceOf(_owner), expectedCustomFee, "Owner balance mismatch after first payment"); - - // ------------------------------ - // 2. Set custom fee of 0% (no fee) for the _merchant - // ------------------------------ - vm.prank(_owner); - _grateful.setCustomFee(0, _merchant); + function test_OneTimePaymentYieldingFunds() public { + // Capture owner's initial balance before payment + uint256 ownerInitialBalance = _usdc.balanceOf(_owner); - // Advance time so calculated paymentId doesn't collide - vm.warp(block.timestamp + 1); + // Setup one-time payment with yielding funds + _setupAndExecuteOneTimePayment(_payer, _merchant, _AMOUNT_USDC, _PAYMENT_SALT, _YIELDING_FUNDS); - // Process payment with custom fee of 0% - _approveAndPay(_payer, _merchant, _AMOUNT_USDC, false); + // Advance time to accrue yield + vm.warp(block.timestamp + 1 days); - // Expected amounts - uint256 expectedZeroFee = 0; // 0% fee - uint256 expectedMerchantAmount2 = _AMOUNT_USDC; + // Calculate profit before withdrawal + uint256 profit = _grateful.calculateProfit(_merchant, address(_usdc)); - // Verify balances after second payment - assertEq( - _usdc.balanceOf(_merchant), - expectedMerchantAmount + expectedMerchantAmount2, - "Merchant balance mismatch after second payment" - ); - assertEq( - _usdc.balanceOf(_owner), expectedCustomFee + expectedZeroFee, "Owner balance mismatch after second payment" - ); + // Merchant withdraws funds + vm.prank(_merchant); + _grateful.withdraw(address(_usdc)); - // ------------------------------ - // 3. Unset custom fee for the _merchant (should revert to default fee) - // ------------------------------ - vm.prank(_owner); - _grateful.unsetCustomFee(_merchant); + // Calculate performance fee after withdrawal + uint256 performanceFee = _grateful.calculatePerformanceFee(profit); - // Advance time so calculated paymentId doesn't collide - vm.warp(block.timestamp + 1); + uint256 initialDeposit = _grateful.applyFee(_merchant, _AMOUNT_USDC); + uint256 expectedMerchantBalance = initialDeposit + profit - performanceFee; - // Process payment after unsetting custom fee - _approveAndPay(_payer, _merchant, _AMOUNT_USDC, false); + // Verify merchant's balance + assertEq(_usdc.balanceOf(_merchant), expectedMerchantBalance, "Merchant balance mismatch after withdrawal"); - // Expected amounts - uint256 expectedFeeAfterUnset = (_AMOUNT_USDC * 100) / 10_000; // 1% fee - uint256 expectedMerchantAmount3 = _AMOUNT_USDC - expectedFeeAfterUnset; + // Verify owner's balance + uint256 ownerFinalBalance = _usdc.balanceOf(_owner); + uint256 initialFee = _AMOUNT_USDC - initialDeposit; + uint256 ownerExpectedBalanceIncrease = initialFee + performanceFee; - // Verify balances after fourth payment assertEq( - _usdc.balanceOf(_merchant), - expectedMerchantAmount + expectedMerchantAmount2 + expectedMerchantAmount3, - "Merchant balance mismatch after third payment" - ); - assertEq( - _usdc.balanceOf(_owner), - expectedCustomFee + expectedZeroFee + expectedFeeAfterUnset, - "Owner balance mismatch after fourth payment" + ownerFinalBalance - ownerInitialBalance, + ownerExpectedBalanceIncrease, + "Owner did not receive correct performance fee" ); } - function test_OneTimePaymentYieldingFunds() public { - address[] memory _tokens2 = new address[](1); - _tokens2[0] = _tokens[0]; - - // 1. Calculate payment id + function test_OverpaidOneTimePayment() public { uint256 paymentId = _grateful.calculateId(_payer, _merchant, address(_usdc), _AMOUNT_USDC); + address precomputed = address( + _grateful.computeOneTimeAddress(_merchant, _tokens, _AMOUNT_USDC, _PAYMENT_SALT, paymentId, _NOT_YIELDING_FUNDS) + ); - // 2. Precompute address - address precomputed = address(_grateful.computeOneTimeAddress(_merchant, _tokens, _AMOUNT_USDC, 4, paymentId, true)); - - // 3. Once the payment address is precomputed, the client sends the payment + // Payer sends double the amount vm.prank(_payer); - _usdc.transfer(precomputed, _AMOUNT_USDC); + _usdc.transfer(precomputed, _AMOUNT_USDC * 2); - // 4. Grateful automation calls api to make one time payment to his address vm.prank(_gratefulAutomation); - _grateful.createOneTimePayment(_merchant, _tokens, _AMOUNT_USDC, 4, paymentId, true, precomputed); + OneTime _oneTime = _grateful.createOneTimePayment( + _merchant, _tokens, _AMOUNT_USDC, _PAYMENT_SALT, paymentId, _NOT_YIELDING_FUNDS, precomputed + ); - // 5. Advance time so yield is generated - vm.warp(block.timestamp + 1 days); + // Merchant receives the correct amount + assertEq( + _usdc.balanceOf(_merchant), + _grateful.applyFee(_merchant, _AMOUNT_USDC), + "Merchant balance mismatch after overpaid one-time payment" + ); - // 6. Merchant withdraws funds - vm.prank(_merchant); - _grateful.withdraw(address(_usdc)); + // Verify excess funds are in the OneTime contract + assertEq( + _usdc.balanceOf(address(_oneTime)), _AMOUNT_USDC, "Unexpected balance in OneTime contract after overpayment" + ); + + // Rescue funds + uint256 prevPayerBalance = _usdc.balanceOf(_payer); + vm.prank(_owner); + _oneTime.rescueFunds(_usdc, _payer, _AMOUNT_USDC); - // 7. Check if merchant's balance is greater than the amount with fee applied - assertGt(_usdc.balanceOf(_merchant), _grateful.applyFee(_merchant, _AMOUNT_USDC)); + // Verify payer's balance after rescuing funds + assertEq(_usdc.balanceOf(_payer), prevPayerBalance + _AMOUNT_USDC, "Payer balance mismatch after rescuing funds"); + } - // 8. Check that owner holds the fee amount - uint256 feeAmount = _AMOUNT_USDC - _grateful.applyFee(_merchant, _AMOUNT_USDC); - assertEq(_usdc.balanceOf(_owner), feeAmount); + function test_PaymentWithCustomFee() public { + uint256[] memory customFees = new uint256[](3); + customFees[0] = 200; // 2% + customFees[1] = 0; // 0% + customFees[2] = _FEE; // Default fee after unsetting custom fee + + uint256 expectedOwnerBalance = 0; + uint256 expectedMerchantBalance = 0; + + for (uint256 i = 0; i < customFees.length; i++) { + // Set custom fee + vm.prank(_owner); + if (i < 2) { + _grateful.setCustomFee(customFees[i], _merchant); + } else { + _grateful.unsetCustomFee(_merchant); + } + + // Advance time to prevent payment ID collision + vm.warp(block.timestamp + 1); + + // Process payment + _approveAndPay(_payer, _merchant, _AMOUNT_USDC, _NOT_YIELDING_FUNDS); + + // Calculate expected amounts + uint256 feeAmount = (_AMOUNT_USDC * customFees[i]) / 10_000; + uint256 merchantAmount = _AMOUNT_USDC - feeAmount; + + expectedOwnerBalance += feeAmount; + expectedMerchantBalance += merchantAmount; + + // Verify balances + assertEq( + _usdc.balanceOf(_merchant), + expectedMerchantBalance, + string(abi.encodePacked("Merchant balance mismatch at iteration ", i)) + ); + assertEq( + _usdc.balanceOf(_owner), + expectedOwnerBalance, + string(abi.encodePacked("Owner balance mismatch at iteration ", i)) + ); + } } } diff --git a/test/integration/IntegrationBase.sol b/test/integration/IntegrationBase.sol index fa2f352..fc7c1df 100644 --- a/test/integration/IntegrationBase.sol +++ b/test/integration/IntegrationBase.sol @@ -11,12 +11,22 @@ import {ERC20} from "solmate/tokens/ERC20.sol"; import {IPool, IRewardsController} from "yield-daddy/aave-v3/AaveV3ERC4626.sol"; contract IntegrationBase is Test { - // Constants + /*////////////////////////////////////////////////////////////// + CONSTANTS + //////////////////////////////////////////////////////////////*/ + uint256 internal constant _FORK_BLOCK = 18_920_905; uint256 internal constant _AMOUNT_USDC = 10 * 10 ** 6; // 10 USDC uint256 internal constant _AMOUNT_DAI = 10 * 10 ** 18; // 10 DAI uint256 internal constant _SUBSCRIPTION_PLAN_ID = 0; - uint256 internal constant _FEE = 100; + uint256 internal constant _FEE = 100; // 1% fee + uint256 internal constant _PAYMENT_SALT = 4; // Salt for computing payment addresses + bool internal constant _YIELDING_FUNDS = true; + bool internal constant _NOT_YIELDING_FUNDS = false; + + /*////////////////////////////////////////////////////////////// + ADDRESSES + //////////////////////////////////////////////////////////////*/ // EOAs address internal _user = makeAddr("user"); @@ -44,15 +54,21 @@ contract IntegrationBase is Test { AaveV3Vault internal _usdtVault; AaveV3Vault internal _daiVault; + /*////////////////////////////////////////////////////////////// + SETUP FUNCTION + //////////////////////////////////////////////////////////////*/ + function setUp() public { vm.startPrank(_owner); vm.createSelectFork(vm.rpcUrl("mainnet"), _FORK_BLOCK); - vm.label(address(_usdcVault), "Vault"); + _tokens = new address[](3); _tokens[0] = address(_usdc); _tokens[1] = address(_usdt); _tokens[2] = address(_dai); + _grateful = new Grateful(_tokens, _aavePool, _FEE); + _usdcVault = new AaveV3Vault( ERC20(address(_usdc)), ERC20(_aUsdc), @@ -77,10 +93,61 @@ contract IntegrationBase is Test { IRewardsController(_rewardsController), address(_grateful) ); + vm.label(address(_grateful), "Grateful"); + vm.label(address(_usdcVault), "USDC Vault"); + vm.label(address(_daiVault), "DAI Vault"); + vm.label(address(_usdtVault), "USDT Vault"); + _grateful.addVault(address(_usdc), address(_usdcVault)); _grateful.addVault(address(_usdt), address(_usdtVault)); _grateful.addVault(address(_dai), address(_daiVault)); vm.stopPrank(); } + + /*////////////////////////////////////////////////////////////// + HELPER FUNCTIONS + //////////////////////////////////////////////////////////////*/ + + function _approveAndPay(address payer, address merchant, uint256 amount, bool yieldFunds) internal { + uint256 paymentId = _grateful.calculateId(payer, merchant, address(_usdc), amount); + vm.startPrank(payer); + _usdc.approve(address(_grateful), amount); + _grateful.pay(merchant, address(_usdc), amount, paymentId, yieldFunds); + vm.stopPrank(); + } + + function _calculateExpectedMerchantBalance( + uint256 initialDeposit, + uint256 profit, + uint256 performanceFeeRate + ) internal pure returns (uint256) { + uint256 performanceFee = (profit * performanceFeeRate) / 10_000; + uint256 totalAssets = initialDeposit + profit; + return totalAssets - performanceFee; + } + + function _getOwnerBalanceIncrease(uint256 initialFee, uint256 performanceFee) internal pure returns (uint256) { + return initialFee + performanceFee; + } + + function _captureBalances() internal view returns (uint256 ownerBalance, uint256 merchantBalance) { + ownerBalance = _usdc.balanceOf(_owner); + merchantBalance = _usdc.balanceOf(_merchant); + } + + function _setupAndExecuteOneTimePayment( + address payer, + address merchant, + uint256 amount, + uint256 salt, + bool yieldFunds + ) internal returns (uint256 paymentId, address precomputed) { + paymentId = _grateful.calculateId(payer, merchant, address(_usdc), amount); + precomputed = address(_grateful.computeOneTimeAddress(merchant, _tokens, amount, salt, paymentId, yieldFunds)); + vm.prank(payer); + _usdc.transfer(precomputed, amount); + vm.prank(_gratefulAutomation); + _grateful.createOneTimePayment(merchant, _tokens, amount, salt, paymentId, yieldFunds, precomputed); + } }