Skip to content

Commit

Permalink
perf(onetime): check for precomputed address mismatch before deployin…
Browse files Browse the repository at this point in the history
…g OneTime contract
  • Loading branch information
0xChin committed Dec 7, 2024
1 parent 698eaa3 commit 2a1b644
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 24 deletions.
13 changes: 8 additions & 5 deletions src/contracts/Grateful.sol
Original file line number Diff line number Diff line change
Expand Up @@ -263,14 +263,17 @@ contract Grateful is IGrateful, Ownable2Step, ReentrancyGuard {
if (_merchant == address(0)) {
revert Grateful_InvalidAddress();
}
oneTimePayments[_precomputed] = true;
oneTime =
new OneTime{salt: bytes32(_salt)}(IGrateful(address(this)), _tokens, _merchant, _amount, _paymentId, _yieldFunds);

if (address(oneTime) != _precomputed) {
address precomputed = address(computeOneTimeAddress(_merchant, _tokens, _amount, _salt, _paymentId, _yieldFunds));

if (precomputed != _precomputed) {
revert Grateful_PrecomputedAddressMismatch();
}

oneTimePayments[_precomputed] = true;
oneTime =
new OneTime{salt: bytes32(_salt)}(IGrateful(address(this)), _tokens, _merchant, _amount, _paymentId, _yieldFunds);

emit OneTimePaymentCreated(_merchant, _tokens, _amount);
}

Expand Down Expand Up @@ -299,7 +302,7 @@ contract Grateful is IGrateful, Ownable2Step, ReentrancyGuard {
uint256 _salt,
uint256 _paymentId,
bool _yieldFunds
) external view returns (OneTime oneTime) {
) public view returns (OneTime oneTime) {
bytes memory bytecode = abi.encodePacked(
type(OneTime).creationCode, abi.encode(address(this), _tokens, _merchant, _amount, _paymentId, _yieldFunds)
);
Expand Down
184 changes: 184 additions & 0 deletions test/unit/OneTimePayment.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
// SPDX-License-Identifier: MIT
pragma solidity 0.8.26;

// Base test contract
import {UnitBase} from "./helpers/Base.t.sol";

// Contracts and interfaces

import {Ownable} from "@openzeppelin/contracts/access/Ownable.sol";
import {IERC20} from "@openzeppelin/contracts/interfaces/IERC20.sol";
import {Grateful, IGrateful} from "contracts/Grateful.sol";
import {OneTime} from "contracts/OneTime.sol";
import {ERC20Mock} from "test/mocks/ERC20Mock.sol";

contract UnitOneTimePayment is UnitBase {
function test_createOneTimePaymentSuccessNonYield(uint128 amount, uint128 paymentId, uint128 salt) public {
vm.assume(amount > 0);

// Compute the precomputed address
address precomputed = address(grateful.computeOneTimeAddress(merchant, tokens, amount, salt, paymentId, false));

// User prepares funds
vm.prank(user);
token.mint(user, amount);
vm.prank(user);
token.transfer(precomputed, amount);

// Simulate the Grateful Automation (just use owner or any address for testing)
vm.prank(gratefulAutomation);
// Expect event for creation
vm.expectEmit(true, true, true, true);
emit IGrateful.OneTimePaymentCreated(merchant, tokens, amount);
OneTime oneTime = grateful.createOneTimePayment(merchant, tokens, amount, salt, paymentId, false, precomputed);

uint256 feeAmount = (amount * grateful.fee()) / 1e18;
uint256 expectedMerchantAmount = amount - feeAmount;

assertEq(token.balanceOf(owner), feeAmount, "Owner fee mismatch");
assertEq(token.balanceOf(merchant), expectedMerchantAmount, "Merchant amount mismatch");
// Since yieldFunds = false and tokens are transferred directly
assertEq(grateful.shares(merchant, address(token)), 0, "No shares should be credited");
assertEq(grateful.userDeposits(merchant, address(token)), 0, "No deposits should be recorded");
}

function test_createOneTimePaymentSuccessYield(uint128 amount, uint128 paymentId, uint128 salt) public {
vm.assume(amount > 0);

address precomputed = address(grateful.computeOneTimeAddress(merchant, tokens, amount, salt, paymentId, true));

// User prepares funds
vm.prank(user);
token.mint(user, amount);
vm.prank(user);
token.transfer(precomputed, amount);

vm.prank(gratefulAutomation);
vm.expectEmit(true, true, true, true);
emit IGrateful.OneTimePaymentCreated(merchant, tokens, amount);
OneTime oneTime = grateful.createOneTimePayment(merchant, tokens, amount, salt, paymentId, true, precomputed);

uint256 feeAmount = (amount * grateful.fee()) / 1e18;
uint256 afterFee = amount - feeAmount;

assertEq(token.balanceOf(owner), feeAmount, "Owner fee mismatch");
assertEq(token.balanceOf(merchant), 0, "Merchant should receive no direct tokens");
assertGt(grateful.shares(merchant, address(token)), 0, "Merchant should get shares");
assertEq(grateful.userDeposits(merchant, address(token)), afterFee, "Merchant deposit mismatch");
}

function test_revertIfCreateOneTimePaymentInvalidMerchant(uint128 amount, uint128 paymentId, uint128 salt) public {
vm.assume(amount > 0);

address precomputed = address(grateful.computeOneTimeAddress(address(0), tokens, amount, salt, paymentId, false));

vm.expectRevert(IGrateful.Grateful_InvalidAddress.selector);
vm.prank(gratefulAutomation);
grateful.createOneTimePayment(address(0), tokens, amount, salt, paymentId, false, precomputed);
}

function test_revertIfCreateOneTimePaymentNonWhitelistedToken(uint128 amount, uint128 paymentId, uint128 salt) public {
vm.assume(amount > 0);
ERC20Mock nonWhitelisted = new ERC20Mock();
address[] memory nonWhitelistedTokens = new address[](1);
nonWhitelistedTokens[0] = address(nonWhitelisted);

address precomputed =
address(grateful.computeOneTimeAddress(merchant, nonWhitelistedTokens, amount, salt, paymentId, false));

// User sends funds to precomputed address
vm.prank(user);
nonWhitelisted.mint(user, amount);
vm.prank(user);
nonWhitelisted.transfer(precomputed, amount);

vm.prank(gratefulAutomation);
vm.expectRevert(IGrateful.Grateful_TokenNotWhitelisted.selector);
grateful.createOneTimePayment(merchant, nonWhitelistedTokens, amount, salt, paymentId, false, precomputed);
}

function test_revertIfCreateOneTimePaymentNoFundsSent(uint128 amount, uint128 paymentId, uint128 salt) public {
vm.assume(amount > 0);

address precomputed = address(grateful.computeOneTimeAddress(merchant, tokens, amount, salt, paymentId, false));

// No funds are sent to precomputed, so OneTime contract won't trigger receiveOneTimePayment
vm.prank(gratefulAutomation);
vm.expectEmit(true, true, true, true);
emit IGrateful.OneTimePaymentCreated(merchant, tokens, amount);
OneTime oneTime = grateful.createOneTimePayment(merchant, tokens, amount, salt, paymentId, false, precomputed);

// Without funds, no payment occurs
assertEq(token.balanceOf(merchant), 0, "Merchant should have no funds");
assertEq(token.balanceOf(owner), 0, "Owner should have no fees");
// Also, no shares and no deposits since no payment
assertEq(grateful.shares(merchant, address(token)), 0, "No shares should be credited");
assertEq(grateful.userDeposits(merchant, address(token)), 0, "No deposit should be recorded");
}

function test_revertIfPrecomputedAddressMismatch(uint128 amount, uint128 paymentId, uint128 salt) public {
vm.assume(amount > 0);

// Correct precomputed
address correctPrecomputed =
address(grateful.computeOneTimeAddress(merchant, tokens, amount, salt, paymentId, false));
// Provide a wrong precomputed
address wrongPrecomputed = address(0x1234);

vm.prank(user);
token.mint(user, amount);
vm.prank(user);
token.transfer(correctPrecomputed, amount);

vm.prank(gratefulAutomation);
vm.expectRevert(IGrateful.Grateful_PrecomputedAddressMismatch.selector);
grateful.createOneTimePayment(merchant, tokens, amount, salt, paymentId, false, wrongPrecomputed);
}

function test_OverpaidOneTimePayment(uint128 amount, uint128 paymentId, uint128 salt) public {
vm.assume(amount > 0);
vm.assume(amount < 100 ether);

address precomputed = address(grateful.computeOneTimeAddress(merchant, tokens, amount, salt, paymentId, false));

// User sends double amount
vm.startPrank(user);
token.mint(user, amount * 2);
token.transfer(precomputed, amount * 2);
vm.stopPrank();

vm.prank(gratefulAutomation);
vm.expectEmit(true, true, true, true);
emit IGrateful.OneTimePaymentCreated(merchant, tokens, amount);
OneTime oneTime = grateful.createOneTimePayment(merchant, tokens, amount, salt, paymentId, false, precomputed);

// Merchant receives correct amount
uint256 feeAmount = (amount * grateful.fee()) / 1e18;
uint256 expectedMerchantAmount = amount - feeAmount;

assertEq(token.balanceOf(merchant), expectedMerchantAmount, "Merchant didn't receive correct amount");
assertEq(token.balanceOf(owner), feeAmount, "Owner didn't receive correct fee");

// Excess remains in OneTime contract
assertEq(token.balanceOf(address(oneTime)), amount, "Excess funds not in OneTime contract");

// Rescue excess
uint256 prevUserBalance = token.balanceOf(user);
vm.prank(owner);
oneTime.rescueFunds(IERC20(address(token)), user, amount);
assertEq(token.balanceOf(user), prevUserBalance + amount, "User didn't get rescued funds back");
}

function test_revertIfRescueFundsNotOwner(uint128 amount, uint128 paymentId, uint128 salt) public {
vm.assume(amount > 0);

address precomputed = address(grateful.computeOneTimeAddress(merchant, tokens, amount, salt, paymentId, false));

vm.prank(gratefulAutomation);
OneTime oneTime = grateful.createOneTimePayment(merchant, tokens, amount, salt, paymentId, false, precomputed);

vm.prank(user);
vm.expectRevert(abi.encodeWithSelector(Ownable.OwnableUnauthorizedAccount.selector, user));
oneTime.rescueFunds(IERC20(address(token)), user, amount);
}
}
46 changes: 27 additions & 19 deletions test/unit/Withdrawal.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ contract UnitWithdrawal is UnitBase {
vm.assume(amount > 1e8);
vm.assume(amount <= 10 ether);
vm.assume(amount <= type(uint256).max / grateful.fee());
uint256 paymentId = 1;
uint256 paymentId = grateful.calculateId(user, merchant, address(token), amount);

vm.prank(user);
token.mint(user, amount);
Expand All @@ -30,10 +30,12 @@ contract UnitWithdrawal is UnitBase {
grateful.pay(merchant, address(token), amount, paymentId, true);

uint256 initialDeposit = grateful.userDeposits(merchant, address(token));
uint256 assetsToWithdraw = grateful.calculateAssets(merchant, address(token));
uint256 profit = grateful.calculateProfit(merchant, address(token));

vm.prank(merchant);
vm.expectEmit(true, true, true, true);
emit IGrateful.Withdrawal(merchant, address(token), initialDeposit, 0);
emit IGrateful.Withdrawal(merchant, address(token), assetsToWithdraw, profit);
grateful.withdraw(address(token));

uint256 finalMerchantBalance = token.balanceOf(merchant);
Expand All @@ -42,7 +44,7 @@ contract UnitWithdrawal is UnitBase {

assertEq(finalShares, 0);
assertEq(finalDeposit, 0);
assertEq(finalMerchantBalance, initialDeposit);
assertEq(finalMerchantBalance, assetsToWithdraw);
}

function test_withdrawPartialSuccess(uint128 amount, uint128 withdrawAmount) public {
Expand All @@ -53,7 +55,7 @@ contract UnitWithdrawal is UnitBase {
vm.assume(withdrawAmount <= grateful.applyFee(merchant, amount));
vm.assume(withdrawAmount >= 100_000); // Ensure withdrawAmount is large enough for meaningful tolerance

uint256 paymentId = 1;
uint256 paymentId = grateful.calculateId(user, merchant, address(token), amount);
uint256 tolerance = withdrawAmount / 10_000; // 0.01% precision loss tolerance

vm.prank(user);
Expand All @@ -65,10 +67,11 @@ contract UnitWithdrawal is UnitBase {

uint256 initialShares = grateful.shares(merchant, address(token));
uint256 initialDeposit = grateful.userDeposits(merchant, address(token));
uint256 profit = grateful.calculateProfit(merchant, address(token));

vm.prank(merchant);
vm.expectEmit(true, true, true, true);
emit IGrateful.Withdrawal(merchant, address(token), withdrawAmount, 0);
emit IGrateful.Withdrawal(merchant, address(token), withdrawAmount, profit);
grateful.withdraw(address(token), withdrawAmount);

uint256 finalMerchantBalance = token.balanceOf(merchant);
Expand All @@ -91,8 +94,8 @@ contract UnitWithdrawal is UnitBase {

(address token2, AaveV3Vault vault2) = _deployNewTokenAndVault();

uint256 paymentId1 = 1;
uint256 paymentId2 = 2;
uint256 paymentId1 = grateful.calculateId(user, merchant, address(token), amount);
uint256 paymentId2 = grateful.calculateId(user, merchant, token2, amount);

vm.startPrank(user);
token.mint(user, amount);
Expand All @@ -109,21 +112,23 @@ contract UnitWithdrawal is UnitBase {
tokens[0] = address(token);
tokens[1] = token2;

uint256 expectedMerchantBalanceToken1 = grateful.userDeposits(merchant, address(token));
uint256 expectedMerchantBalanceToken2 = grateful.userDeposits(merchant, token2);
uint256 assetsToken1 = grateful.calculateAssets(merchant, address(token));
uint256 assetsToken2 = grateful.calculateAssets(merchant, token2);
uint256 profitToken1 = grateful.calculateProfit(merchant, address(token));
uint256 profitToken2 = grateful.calculateProfit(merchant, token2);

vm.prank(merchant);
vm.expectEmit(true, true, true, true);
emit IGrateful.Withdrawal(merchant, address(token), expectedMerchantBalanceToken1, 0);
emit IGrateful.Withdrawal(merchant, address(token), assetsToken1, profitToken1);
vm.expectEmit(true, true, true, true);
emit IGrateful.Withdrawal(merchant, token2, expectedMerchantBalanceToken2, 0);
emit IGrateful.Withdrawal(merchant, token2, assetsToken2, profitToken2);
grateful.withdrawMultiple(tokens);

uint256 finalMerchantBalanceToken1 = token.balanceOf(merchant);
uint256 finalMerchantBalanceToken2 = ERC20Mock(token2).balanceOf(merchant);

assertEq(finalMerchantBalanceToken1, expectedMerchantBalanceToken1);
assertEq(finalMerchantBalanceToken2, expectedMerchantBalanceToken2);
assertEq(finalMerchantBalanceToken1, assetsToken1);
assertEq(finalMerchantBalanceToken2, assetsToken2);
}

function test_withdrawMultiplePartialSuccess(uint128 amount, uint128 withdrawAmount) public {
Expand All @@ -136,8 +141,8 @@ contract UnitWithdrawal is UnitBase {

(address token2, AaveV3Vault vault2) = _deployNewTokenAndVault();

uint256 paymentId1 = 1;
uint256 paymentId2 = 2;
uint256 paymentId1 = grateful.calculateId(user, merchant, address(token), amount);
uint256 paymentId2 = grateful.calculateId(user, merchant, token2, amount);

vm.startPrank(user);
token.mint(user, amount);
Expand All @@ -157,11 +162,14 @@ contract UnitWithdrawal is UnitBase {
assets[0] = withdrawAmount;
assets[1] = withdrawAmount;

uint256 profitToken1 = grateful.calculateProfit(merchant, address(token));
uint256 profitToken2 = grateful.calculateProfit(merchant, token2);

vm.prank(merchant);
vm.expectEmit(true, true, true, true);
emit IGrateful.Withdrawal(merchant, address(token), assets[0], 0);
emit IGrateful.Withdrawal(merchant, address(token), assets[0], profitToken1);
vm.expectEmit(true, true, true, true);
emit IGrateful.Withdrawal(merchant, token2, assets[1], 0);
emit IGrateful.Withdrawal(merchant, token2, assets[1], profitToken2);
grateful.withdrawMultiple(tokens, assets);

uint256 finalMerchantBalanceToken1 = token.balanceOf(merchant);
Expand Down Expand Up @@ -202,7 +210,7 @@ contract UnitWithdrawal is UnitBase {
vm.assume(amount > 0);
vm.assume(amount <= 10 ether);
vm.assume(amount <= type(uint256).max / grateful.fee());
uint256 paymentId = 1;
uint256 paymentId = grateful.calculateId(user, merchant, address(token), amount);

vm.prank(user);
token.mint(user, amount);
Expand All @@ -222,7 +230,7 @@ contract UnitWithdrawal is UnitBase {
vm.assume(amount > 0);
vm.assume(amount <= 10 ether);
vm.assume(amount <= type(uint256).max / grateful.fee());
uint256 paymentId = 1;
uint256 paymentId = grateful.calculateId(user, merchant, address(token), amount);

vm.prank(user);
token.mint(user, amount);
Expand Down
1 change: 1 addition & 0 deletions test/unit/helpers/Base.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ contract UnitBase is Test {
address internal owner = address(0x1);
address internal merchant = address(0x2);
address internal user = address(0x3);
address internal gratefulAutomation = address(0x4);

// Token and fee parameters
address[] internal tokens;
Expand Down

0 comments on commit 2a1b644

Please sign in to comment.