From 2a1b64475bbe749b89d4db47a867c56a38c7414b Mon Sep 17 00:00:00 2001 From: 0xchin Date: Fri, 6 Dec 2024 22:29:19 -0300 Subject: [PATCH] perf(onetime): check for precomputed address mismatch before deploying OneTime contract --- src/contracts/Grateful.sol | 13 ++- test/unit/OneTimePayment.t.sol | 184 +++++++++++++++++++++++++++++++++ test/unit/Withdrawal.t.sol | 46 +++++---- test/unit/helpers/Base.t.sol | 1 + 4 files changed, 220 insertions(+), 24 deletions(-) create mode 100644 test/unit/OneTimePayment.t.sol diff --git a/src/contracts/Grateful.sol b/src/contracts/Grateful.sol index 03042d0..7d2bcab 100644 --- a/src/contracts/Grateful.sol +++ b/src/contracts/Grateful.sol @@ -263,14 +263,17 @@ contract Grateful is IGrateful, Ownable2Step, ReentrancyGuard { if (_merchant == address(0)) { revert Grateful_InvalidAddress(); } - oneTimePayments[_precomputed] = true; - oneTime = - new OneTime{salt: bytes32(_salt)}(IGrateful(address(this)), _tokens, _merchant, _amount, _paymentId, _yieldFunds); - if (address(oneTime) != _precomputed) { + address precomputed = address(computeOneTimeAddress(_merchant, _tokens, _amount, _salt, _paymentId, _yieldFunds)); + + if (precomputed != _precomputed) { revert Grateful_PrecomputedAddressMismatch(); } + oneTimePayments[_precomputed] = true; + oneTime = + new OneTime{salt: bytes32(_salt)}(IGrateful(address(this)), _tokens, _merchant, _amount, _paymentId, _yieldFunds); + emit OneTimePaymentCreated(_merchant, _tokens, _amount); } @@ -299,7 +302,7 @@ contract Grateful is IGrateful, Ownable2Step, ReentrancyGuard { uint256 _salt, uint256 _paymentId, bool _yieldFunds - ) external view returns (OneTime oneTime) { + ) public view returns (OneTime oneTime) { bytes memory bytecode = abi.encodePacked( type(OneTime).creationCode, abi.encode(address(this), _tokens, _merchant, _amount, _paymentId, _yieldFunds) ); diff --git a/test/unit/OneTimePayment.t.sol b/test/unit/OneTimePayment.t.sol new file mode 100644 index 0000000..cd34b59 --- /dev/null +++ b/test/unit/OneTimePayment.t.sol @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: MIT +pragma solidity 0.8.26; + +// Base test contract +import {UnitBase} from "./helpers/Base.t.sol"; + +// Contracts and interfaces + +import {Ownable} from "@openzeppelin/contracts/access/Ownable.sol"; +import {IERC20} from "@openzeppelin/contracts/interfaces/IERC20.sol"; +import {Grateful, IGrateful} from "contracts/Grateful.sol"; +import {OneTime} from "contracts/OneTime.sol"; +import {ERC20Mock} from "test/mocks/ERC20Mock.sol"; + +contract UnitOneTimePayment is UnitBase { + function test_createOneTimePaymentSuccessNonYield(uint128 amount, uint128 paymentId, uint128 salt) public { + vm.assume(amount > 0); + + // Compute the precomputed address + address precomputed = address(grateful.computeOneTimeAddress(merchant, tokens, amount, salt, paymentId, false)); + + // User prepares funds + vm.prank(user); + token.mint(user, amount); + vm.prank(user); + token.transfer(precomputed, amount); + + // Simulate the Grateful Automation (just use owner or any address for testing) + vm.prank(gratefulAutomation); + // Expect event for creation + vm.expectEmit(true, true, true, true); + emit IGrateful.OneTimePaymentCreated(merchant, tokens, amount); + OneTime oneTime = grateful.createOneTimePayment(merchant, tokens, amount, salt, paymentId, false, precomputed); + + uint256 feeAmount = (amount * grateful.fee()) / 1e18; + uint256 expectedMerchantAmount = amount - feeAmount; + + assertEq(token.balanceOf(owner), feeAmount, "Owner fee mismatch"); + assertEq(token.balanceOf(merchant), expectedMerchantAmount, "Merchant amount mismatch"); + // Since yieldFunds = false and tokens are transferred directly + assertEq(grateful.shares(merchant, address(token)), 0, "No shares should be credited"); + assertEq(grateful.userDeposits(merchant, address(token)), 0, "No deposits should be recorded"); + } + + function test_createOneTimePaymentSuccessYield(uint128 amount, uint128 paymentId, uint128 salt) public { + vm.assume(amount > 0); + + address precomputed = address(grateful.computeOneTimeAddress(merchant, tokens, amount, salt, paymentId, true)); + + // User prepares funds + vm.prank(user); + token.mint(user, amount); + vm.prank(user); + token.transfer(precomputed, amount); + + vm.prank(gratefulAutomation); + vm.expectEmit(true, true, true, true); + emit IGrateful.OneTimePaymentCreated(merchant, tokens, amount); + OneTime oneTime = grateful.createOneTimePayment(merchant, tokens, amount, salt, paymentId, true, precomputed); + + uint256 feeAmount = (amount * grateful.fee()) / 1e18; + uint256 afterFee = amount - feeAmount; + + assertEq(token.balanceOf(owner), feeAmount, "Owner fee mismatch"); + assertEq(token.balanceOf(merchant), 0, "Merchant should receive no direct tokens"); + assertGt(grateful.shares(merchant, address(token)), 0, "Merchant should get shares"); + assertEq(grateful.userDeposits(merchant, address(token)), afterFee, "Merchant deposit mismatch"); + } + + function test_revertIfCreateOneTimePaymentInvalidMerchant(uint128 amount, uint128 paymentId, uint128 salt) public { + vm.assume(amount > 0); + + address precomputed = address(grateful.computeOneTimeAddress(address(0), tokens, amount, salt, paymentId, false)); + + vm.expectRevert(IGrateful.Grateful_InvalidAddress.selector); + vm.prank(gratefulAutomation); + grateful.createOneTimePayment(address(0), tokens, amount, salt, paymentId, false, precomputed); + } + + function test_revertIfCreateOneTimePaymentNonWhitelistedToken(uint128 amount, uint128 paymentId, uint128 salt) public { + vm.assume(amount > 0); + ERC20Mock nonWhitelisted = new ERC20Mock(); + address[] memory nonWhitelistedTokens = new address[](1); + nonWhitelistedTokens[0] = address(nonWhitelisted); + + address precomputed = + address(grateful.computeOneTimeAddress(merchant, nonWhitelistedTokens, amount, salt, paymentId, false)); + + // User sends funds to precomputed address + vm.prank(user); + nonWhitelisted.mint(user, amount); + vm.prank(user); + nonWhitelisted.transfer(precomputed, amount); + + vm.prank(gratefulAutomation); + vm.expectRevert(IGrateful.Grateful_TokenNotWhitelisted.selector); + grateful.createOneTimePayment(merchant, nonWhitelistedTokens, amount, salt, paymentId, false, precomputed); + } + + function test_revertIfCreateOneTimePaymentNoFundsSent(uint128 amount, uint128 paymentId, uint128 salt) public { + vm.assume(amount > 0); + + address precomputed = address(grateful.computeOneTimeAddress(merchant, tokens, amount, salt, paymentId, false)); + + // No funds are sent to precomputed, so OneTime contract won't trigger receiveOneTimePayment + vm.prank(gratefulAutomation); + vm.expectEmit(true, true, true, true); + emit IGrateful.OneTimePaymentCreated(merchant, tokens, amount); + OneTime oneTime = grateful.createOneTimePayment(merchant, tokens, amount, salt, paymentId, false, precomputed); + + // Without funds, no payment occurs + assertEq(token.balanceOf(merchant), 0, "Merchant should have no funds"); + assertEq(token.balanceOf(owner), 0, "Owner should have no fees"); + // Also, no shares and no deposits since no payment + assertEq(grateful.shares(merchant, address(token)), 0, "No shares should be credited"); + assertEq(grateful.userDeposits(merchant, address(token)), 0, "No deposit should be recorded"); + } + + function test_revertIfPrecomputedAddressMismatch(uint128 amount, uint128 paymentId, uint128 salt) public { + vm.assume(amount > 0); + + // Correct precomputed + address correctPrecomputed = + address(grateful.computeOneTimeAddress(merchant, tokens, amount, salt, paymentId, false)); + // Provide a wrong precomputed + address wrongPrecomputed = address(0x1234); + + vm.prank(user); + token.mint(user, amount); + vm.prank(user); + token.transfer(correctPrecomputed, amount); + + vm.prank(gratefulAutomation); + vm.expectRevert(IGrateful.Grateful_PrecomputedAddressMismatch.selector); + grateful.createOneTimePayment(merchant, tokens, amount, salt, paymentId, false, wrongPrecomputed); + } + + function test_OverpaidOneTimePayment(uint128 amount, uint128 paymentId, uint128 salt) public { + vm.assume(amount > 0); + vm.assume(amount < 100 ether); + + address precomputed = address(grateful.computeOneTimeAddress(merchant, tokens, amount, salt, paymentId, false)); + + // User sends double amount + vm.startPrank(user); + token.mint(user, amount * 2); + token.transfer(precomputed, amount * 2); + vm.stopPrank(); + + vm.prank(gratefulAutomation); + vm.expectEmit(true, true, true, true); + emit IGrateful.OneTimePaymentCreated(merchant, tokens, amount); + OneTime oneTime = grateful.createOneTimePayment(merchant, tokens, amount, salt, paymentId, false, precomputed); + + // Merchant receives correct amount + uint256 feeAmount = (amount * grateful.fee()) / 1e18; + uint256 expectedMerchantAmount = amount - feeAmount; + + assertEq(token.balanceOf(merchant), expectedMerchantAmount, "Merchant didn't receive correct amount"); + assertEq(token.balanceOf(owner), feeAmount, "Owner didn't receive correct fee"); + + // Excess remains in OneTime contract + assertEq(token.balanceOf(address(oneTime)), amount, "Excess funds not in OneTime contract"); + + // Rescue excess + uint256 prevUserBalance = token.balanceOf(user); + vm.prank(owner); + oneTime.rescueFunds(IERC20(address(token)), user, amount); + assertEq(token.balanceOf(user), prevUserBalance + amount, "User didn't get rescued funds back"); + } + + function test_revertIfRescueFundsNotOwner(uint128 amount, uint128 paymentId, uint128 salt) public { + vm.assume(amount > 0); + + address precomputed = address(grateful.computeOneTimeAddress(merchant, tokens, amount, salt, paymentId, false)); + + vm.prank(gratefulAutomation); + OneTime oneTime = grateful.createOneTimePayment(merchant, tokens, amount, salt, paymentId, false, precomputed); + + vm.prank(user); + vm.expectRevert(abi.encodeWithSelector(Ownable.OwnableUnauthorizedAccount.selector, user)); + oneTime.rescueFunds(IERC20(address(token)), user, amount); + } +} diff --git a/test/unit/Withdrawal.t.sol b/test/unit/Withdrawal.t.sol index 1ed213e..e5965eb 100644 --- a/test/unit/Withdrawal.t.sol +++ b/test/unit/Withdrawal.t.sol @@ -20,7 +20,7 @@ contract UnitWithdrawal is UnitBase { vm.assume(amount > 1e8); vm.assume(amount <= 10 ether); vm.assume(amount <= type(uint256).max / grateful.fee()); - uint256 paymentId = 1; + uint256 paymentId = grateful.calculateId(user, merchant, address(token), amount); vm.prank(user); token.mint(user, amount); @@ -30,10 +30,12 @@ contract UnitWithdrawal is UnitBase { grateful.pay(merchant, address(token), amount, paymentId, true); uint256 initialDeposit = grateful.userDeposits(merchant, address(token)); + uint256 assetsToWithdraw = grateful.calculateAssets(merchant, address(token)); + uint256 profit = grateful.calculateProfit(merchant, address(token)); vm.prank(merchant); vm.expectEmit(true, true, true, true); - emit IGrateful.Withdrawal(merchant, address(token), initialDeposit, 0); + emit IGrateful.Withdrawal(merchant, address(token), assetsToWithdraw, profit); grateful.withdraw(address(token)); uint256 finalMerchantBalance = token.balanceOf(merchant); @@ -42,7 +44,7 @@ contract UnitWithdrawal is UnitBase { assertEq(finalShares, 0); assertEq(finalDeposit, 0); - assertEq(finalMerchantBalance, initialDeposit); + assertEq(finalMerchantBalance, assetsToWithdraw); } function test_withdrawPartialSuccess(uint128 amount, uint128 withdrawAmount) public { @@ -53,7 +55,7 @@ contract UnitWithdrawal is UnitBase { vm.assume(withdrawAmount <= grateful.applyFee(merchant, amount)); vm.assume(withdrawAmount >= 100_000); // Ensure withdrawAmount is large enough for meaningful tolerance - uint256 paymentId = 1; + uint256 paymentId = grateful.calculateId(user, merchant, address(token), amount); uint256 tolerance = withdrawAmount / 10_000; // 0.01% precision loss tolerance vm.prank(user); @@ -65,10 +67,11 @@ contract UnitWithdrawal is UnitBase { uint256 initialShares = grateful.shares(merchant, address(token)); uint256 initialDeposit = grateful.userDeposits(merchant, address(token)); + uint256 profit = grateful.calculateProfit(merchant, address(token)); vm.prank(merchant); vm.expectEmit(true, true, true, true); - emit IGrateful.Withdrawal(merchant, address(token), withdrawAmount, 0); + emit IGrateful.Withdrawal(merchant, address(token), withdrawAmount, profit); grateful.withdraw(address(token), withdrawAmount); uint256 finalMerchantBalance = token.balanceOf(merchant); @@ -91,8 +94,8 @@ contract UnitWithdrawal is UnitBase { (address token2, AaveV3Vault vault2) = _deployNewTokenAndVault(); - uint256 paymentId1 = 1; - uint256 paymentId2 = 2; + uint256 paymentId1 = grateful.calculateId(user, merchant, address(token), amount); + uint256 paymentId2 = grateful.calculateId(user, merchant, token2, amount); vm.startPrank(user); token.mint(user, amount); @@ -109,21 +112,23 @@ contract UnitWithdrawal is UnitBase { tokens[0] = address(token); tokens[1] = token2; - uint256 expectedMerchantBalanceToken1 = grateful.userDeposits(merchant, address(token)); - uint256 expectedMerchantBalanceToken2 = grateful.userDeposits(merchant, token2); + uint256 assetsToken1 = grateful.calculateAssets(merchant, address(token)); + uint256 assetsToken2 = grateful.calculateAssets(merchant, token2); + uint256 profitToken1 = grateful.calculateProfit(merchant, address(token)); + uint256 profitToken2 = grateful.calculateProfit(merchant, token2); vm.prank(merchant); vm.expectEmit(true, true, true, true); - emit IGrateful.Withdrawal(merchant, address(token), expectedMerchantBalanceToken1, 0); + emit IGrateful.Withdrawal(merchant, address(token), assetsToken1, profitToken1); vm.expectEmit(true, true, true, true); - emit IGrateful.Withdrawal(merchant, token2, expectedMerchantBalanceToken2, 0); + emit IGrateful.Withdrawal(merchant, token2, assetsToken2, profitToken2); grateful.withdrawMultiple(tokens); uint256 finalMerchantBalanceToken1 = token.balanceOf(merchant); uint256 finalMerchantBalanceToken2 = ERC20Mock(token2).balanceOf(merchant); - assertEq(finalMerchantBalanceToken1, expectedMerchantBalanceToken1); - assertEq(finalMerchantBalanceToken2, expectedMerchantBalanceToken2); + assertEq(finalMerchantBalanceToken1, assetsToken1); + assertEq(finalMerchantBalanceToken2, assetsToken2); } function test_withdrawMultiplePartialSuccess(uint128 amount, uint128 withdrawAmount) public { @@ -136,8 +141,8 @@ contract UnitWithdrawal is UnitBase { (address token2, AaveV3Vault vault2) = _deployNewTokenAndVault(); - uint256 paymentId1 = 1; - uint256 paymentId2 = 2; + uint256 paymentId1 = grateful.calculateId(user, merchant, address(token), amount); + uint256 paymentId2 = grateful.calculateId(user, merchant, token2, amount); vm.startPrank(user); token.mint(user, amount); @@ -157,11 +162,14 @@ contract UnitWithdrawal is UnitBase { assets[0] = withdrawAmount; assets[1] = withdrawAmount; + uint256 profitToken1 = grateful.calculateProfit(merchant, address(token)); + uint256 profitToken2 = grateful.calculateProfit(merchant, token2); + vm.prank(merchant); vm.expectEmit(true, true, true, true); - emit IGrateful.Withdrawal(merchant, address(token), assets[0], 0); + emit IGrateful.Withdrawal(merchant, address(token), assets[0], profitToken1); vm.expectEmit(true, true, true, true); - emit IGrateful.Withdrawal(merchant, token2, assets[1], 0); + emit IGrateful.Withdrawal(merchant, token2, assets[1], profitToken2); grateful.withdrawMultiple(tokens, assets); uint256 finalMerchantBalanceToken1 = token.balanceOf(merchant); @@ -202,7 +210,7 @@ contract UnitWithdrawal is UnitBase { vm.assume(amount > 0); vm.assume(amount <= 10 ether); vm.assume(amount <= type(uint256).max / grateful.fee()); - uint256 paymentId = 1; + uint256 paymentId = grateful.calculateId(user, merchant, address(token), amount); vm.prank(user); token.mint(user, amount); @@ -222,7 +230,7 @@ contract UnitWithdrawal is UnitBase { vm.assume(amount > 0); vm.assume(amount <= 10 ether); vm.assume(amount <= type(uint256).max / grateful.fee()); - uint256 paymentId = 1; + uint256 paymentId = grateful.calculateId(user, merchant, address(token), amount); vm.prank(user); token.mint(user, amount); diff --git a/test/unit/helpers/Base.t.sol b/test/unit/helpers/Base.t.sol index 1037cff..755003d 100644 --- a/test/unit/helpers/Base.t.sol +++ b/test/unit/helpers/Base.t.sol @@ -44,6 +44,7 @@ contract UnitBase is Test { address internal owner = address(0x1); address internal merchant = address(0x2); address internal user = address(0x3); + address internal gratefulAutomation = address(0x4); // Token and fee parameters address[] internal tokens;