From f0f08f4087f2585463a40f5893ca849dbdd58311 Mon Sep 17 00:00:00 2001 From: andreivladbrg Date: Tue, 10 Oct 2023 17:57:19 +0300 Subject: [PATCH] test: use solady to sort merkle tree --- .../merkle-streamer/MerkleStreamerLL.t.sol | 19 +++-- test/utils/Defaults.sol | 48 ++++++------ test/utils/MerkleBuilder.sol | 77 ++++++------------- test/utils/MerkleBuilder.t.sol | 11 +-- 4 files changed, 69 insertions(+), 86 deletions(-) diff --git a/test/fork/merkle-streamer/MerkleStreamerLL.t.sol b/test/fork/merkle-streamer/MerkleStreamerLL.t.sol index 69f28a25..88504342 100644 --- a/test/fork/merkle-streamer/MerkleStreamerLL.t.sol +++ b/test/fork/merkle-streamer/MerkleStreamerLL.t.sol @@ -1,6 +1,7 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity >=0.8.19 <0.9.0; +import { Arrays } from "@openzeppelin/contracts/utils/Arrays.sol"; import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; import { Lockup, LockupLinear } from "@sablier/v2-core/src/types/DataTypes.sol"; @@ -10,6 +11,8 @@ import { MerkleBuilder } from "../../utils/MerkleBuilder.sol"; import { Fork_Test } from "../Fork.t.sol"; abstract contract MerkleStreamerLL_Fork_Test is Fork_Test { + using MerkleBuilder for uint256[]; + constructor(IERC20 asset_) Fork_Test(asset_) { } function setUp() public virtual override { @@ -41,14 +44,16 @@ abstract contract MerkleStreamerLL_Fork_Test is Fork_Test { uint256 expectedStreamId; uint256[] indexes; uint256 leafPos; - bytes32 leafToClaim; - bytes32[] leaves; + uint256 leafToClaim; ISablierV2MerkleStreamerLL merkleStreamerLL; bytes32 merkleRoot; address[] recipients; uint256 recipientsCount; } + // We need the leaves as storage variable so that we can use Arrays.findUpperBound function. + uint256[] public leaves; + function testForkFuzz_MerkleStreamerLL(Params memory params) external { vm.assume(params.admin != address(0) && params.admin != users.admin.addr); vm.assume(params.expiration == 0 || params.expiration > block.timestamp); @@ -77,8 +82,10 @@ abstract contract MerkleStreamerLL_Fork_Test is Fork_Test { vars.recipients[i] = address(uint160(boundedRecipientSeed)); } - vars.leaves = MerkleBuilder.sort(MerkleBuilder.computeLeaves(vars.indexes, vars.recipients, vars.amounts)); - vars.merkleRoot = getRoot(vars.leaves); + leaves = new uint256[](vars.recipientsCount); + leaves = MerkleBuilder.computeLeaves(vars.indexes, vars.recipients, vars.amounts); + MerkleBuilder.sortLeaves(leaves); + vars.merkleRoot = getRoot(leaves.toBytes32()); vars.expectedStreamerLL = computeMerkleStreamerLLAddress(params.admin, vars.merkleRoot, params.expiration); vm.expectEmit({ emitter: address(merkleStreamerFactory) }); @@ -132,7 +139,7 @@ abstract contract MerkleStreamerLL_Fork_Test is Fork_Test { vars.recipients[params.beforeSortPos], vars.amounts[params.beforeSortPos] ); - vars.leafPos = MerkleBuilder.binarySearch(vars.leaves, vars.leafToClaim); + vars.leafPos = Arrays.findUpperBound(leaves, vars.leafToClaim); vars.expectedStreamId = lockupLinear.nextStreamId(); emit Claim( @@ -145,7 +152,7 @@ abstract contract MerkleStreamerLL_Fork_Test is Fork_Test { index: vars.indexes[params.beforeSortPos], recipient: vars.recipients[params.beforeSortPos], amount: vars.amounts[params.beforeSortPos], - merkleProof: getProof(vars.leaves, vars.leafPos) + merkleProof: getProof(leaves.toBytes32(), vars.leafPos) }); vars.actualStream = lockupLinear.getStream(vars.actualStreamId); diff --git a/test/utils/Defaults.sol b/test/utils/Defaults.sol index 048abd04..24838493 100644 --- a/test/utils/Defaults.sol +++ b/test/utils/Defaults.sol @@ -1,6 +1,7 @@ // SPDX-License-Identifier: GPL-3.0-or-later pragma solidity >=0.8.19 <0.9.0; +import { Arrays } from "@openzeppelin/contracts/utils/Arrays.sol"; import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; import { IPRBProxy } from "@prb/proxy/src/interfaces/IPRBProxy.sol"; import { ud2x18, UD60x18 } from "@sablier/v2-core/src/types/Math.sol"; @@ -19,6 +20,8 @@ import { Users } from "./Types.sol"; /// @notice Contract with default values for testing. contract Defaults is Merkle, PermitSignature { + using MerkleBuilder for uint256[]; + /*////////////////////////////////////////////////////////////////////////// GENERIC CONSTANTS //////////////////////////////////////////////////////////////////////////*/ @@ -51,44 +54,43 @@ contract Defaults is Merkle, PermitSignature { uint256 public constant INDEX3 = 3; uint256 public constant INDEX4 = 4; string public IPFS_CID = "QmbWqxBEKC3P8tqsKc98xmWNzrzDtRLMiMPL8wBuTGsMnR"; + uint256[] public leaves = new uint256[](RECIPIENTS_COUNT); + bytes32 public merkleRoot; uint256 public constant RECIPIENTS_COUNT = 4; bool public constant TRANSFERABLE = false; function index1Proof() public view returns (bytes32[] memory) { - bytes32 leaf = MerkleBuilder.computeLeaf(INDEX1, users.recipient1.addr, CLAIM_AMOUNT); - uint256 pos = MerkleBuilder.binarySearch(leaves(), leaf); - return getProof(leaves(), pos); + uint256 leaf = MerkleBuilder.computeLeaf(INDEX1, users.recipient1.addr, CLAIM_AMOUNT); + uint256 pos = Arrays.findUpperBound(leaves, leaf); + return getProof(leaves.toBytes32(), pos); } function index2Proof() public view returns (bytes32[] memory) { - bytes32 leaf = MerkleBuilder.computeLeaf(INDEX2, users.recipient2.addr, CLAIM_AMOUNT); - uint256 pos = MerkleBuilder.binarySearch(leaves(), leaf); - return getProof(leaves(), pos); + uint256 leaf = MerkleBuilder.computeLeaf(INDEX2, users.recipient2.addr, CLAIM_AMOUNT); + uint256 pos = Arrays.findUpperBound(leaves, leaf); + return getProof(leaves.toBytes32(), pos); } function index3Proof() public view returns (bytes32[] memory) { - bytes32 leaf = MerkleBuilder.computeLeaf(INDEX3, users.recipient3.addr, CLAIM_AMOUNT); - uint256 pos = MerkleBuilder.binarySearch(leaves(), leaf); - return getProof(leaves(), pos); + uint256 leaf = MerkleBuilder.computeLeaf(INDEX3, users.recipient3.addr, CLAIM_AMOUNT); + uint256 pos = Arrays.findUpperBound(leaves, leaf); + return getProof(leaves.toBytes32(), pos); } function index4Proof() public view returns (bytes32[] memory) { - bytes32 leaf = MerkleBuilder.computeLeaf(INDEX4, users.recipient4.addr, CLAIM_AMOUNT); - uint256 pos = MerkleBuilder.binarySearch(leaves(), leaf); - return getProof(leaves(), pos); + uint256 leaf = MerkleBuilder.computeLeaf(INDEX4, users.recipient4.addr, CLAIM_AMOUNT); + uint256 pos = Arrays.findUpperBound(leaves, leaf); + return getProof(leaves.toBytes32(), pos); } - function leaves() public view returns (bytes32[] memory) { - bytes32[] memory leaves_ = new bytes32[](RECIPIENTS_COUNT); - leaves_[0] = MerkleBuilder.computeLeaf(INDEX1, users.recipient1.addr, CLAIM_AMOUNT); - leaves_[1] = MerkleBuilder.computeLeaf(INDEX2, users.recipient2.addr, CLAIM_AMOUNT); - leaves_[2] = MerkleBuilder.computeLeaf(INDEX3, users.recipient3.addr, CLAIM_AMOUNT); - leaves_[3] = MerkleBuilder.computeLeaf(INDEX4, users.recipient4.addr, CLAIM_AMOUNT); - return MerkleBuilder.sort(leaves_); - } + function _initMerkleTree() internal { + leaves[0] = MerkleBuilder.computeLeaf(INDEX1, users.recipient1.addr, CLAIM_AMOUNT); + leaves[1] = MerkleBuilder.computeLeaf(INDEX2, users.recipient2.addr, CLAIM_AMOUNT); + leaves[2] = MerkleBuilder.computeLeaf(INDEX3, users.recipient3.addr, CLAIM_AMOUNT); + leaves[3] = MerkleBuilder.computeLeaf(INDEX4, users.recipient4.addr, CLAIM_AMOUNT); - function merkleRoot() public view returns (bytes32) { - return getRoot(leaves()); + MerkleBuilder.sortLeaves(leaves); + merkleRoot = getRoot(leaves.toBytes32()); } /*////////////////////////////////////////////////////////////////////////// @@ -158,6 +160,8 @@ contract Defaults is Merkle, PermitSignature { CLIFF_TIME = START_TIME + CLIFF_DURATION; END_TIME = START_TIME + TOTAL_DURATION; EXPIRATION = uint40(block.timestamp) + 12 weeks; + + _initMerkleTree(); } /*////////////////////////////////////////////////////////////////////////// diff --git a/test/utils/MerkleBuilder.sol b/test/utils/MerkleBuilder.sol index f9b08844..eb766229 100644 --- a/test/utils/MerkleBuilder.sol +++ b/test/utils/MerkleBuilder.sol @@ -2,11 +2,13 @@ // solhint-disable reason-string pragma solidity >=0.8.19; +import { LibSort } from "solady/utils/LibSort.sol"; + /// @dev A helper library for building Merkle leaves, roots, and proofs. library MerkleBuilder { /// @dev Function that hashes together the data needed for a Merkle tree leaf. - function computeLeaf(uint256 index, address recipient, uint128 amount) internal pure returns (bytes32 leaf) { - leaf = keccak256(bytes.concat(keccak256(abi.encode(index, recipient, amount)))); + function computeLeaf(uint256 index, address recipient, uint128 amount) internal pure returns (uint256 leaf) { + leaf = uint256(keccak256(bytes.concat(keccak256(abi.encode(index, recipient, amount))))); } /// @dev A batch function for `computeLeaf`. @@ -17,72 +19,41 @@ library MerkleBuilder { ) internal pure - returns (bytes32[] memory leaves) + returns (uint256[] memory leaves) { uint256 count = indexes.length; require(count == recipient.length && count == amount.length, "Merkle leaves arrays must have the same length"); - leaves = new bytes32[](count); + leaves = new uint256[](count); for (uint256 i = 0; i < count; ++i) { leaves[i] = computeLeaf(indexes[i], recipient[i], amount[i]); } } - /// @dev Function that binary searchs the position of a specific leaf in a sorted `bytes32` array. - function binarySearch(bytes32[] memory arr, bytes32 val) internal pure returns (uint256) { - uint256 low = 0; - uint256 high = arr.length - 1; - uint256 mid; - - while (low <= high) { - mid = low + (high - low) / 2; + /// @dev Function that sorts a storage array of `uint256` in ascending order. We need this function because + /// `LibSort` does not support storage arrays. + function sortLeaves(uint256[] storage leaves) internal { + uint256 leavesCount = leaves.length; - if (arr[mid] == val) { - return mid; - } - - if (arr[mid] < val) { - low = mid + 1; - } else { - if (mid == 0) { - break; - } - high = mid - 1; - } + // Declare the memory array. + uint256[] memory _leaves = new uint256[](leavesCount); + for (uint256 i = 0; i < leavesCount; ++i) { + _leaves[i] = leaves[i]; } - return mid; - } - - /// @dev Function that sorts an array of `bytes32` in ascending order. - function sort(bytes32[] memory arr) internal pure returns (bytes32[] memory) { - _quickSort(arr, 0, arr.length - 1); - return arr; - } + // Sort the memory array. + LibSort.sort(_leaves); - function _quickSort(bytes32[] memory arr, uint256 i, uint256 j) private pure { - if (i < j) { - uint256 p = _partition(arr, i, j); - if (p > 0) { - _quickSort(arr, i, p - 1); - } - _quickSort(arr, p + 1, j); + // Copy the memory array back to storage. + for (uint256 i = 0; i < leavesCount; ++i) { + leaves[i] = _leaves[i]; } } - function _partition(bytes32[] memory arr, uint256 i, uint256 j) private pure returns (uint256) { - bytes32 pivot = arr[j]; - uint256 low = i; - for (uint256 k = i; k < j; ++k) { - if (arr[k] < pivot) { - _swap(arr, low, k); - ++low; - } + /// @dev Function that converts an array of `uint256` to an array of `bytes32`. + function toBytes32(uint256[] storage _arr) internal view returns (bytes32[] memory arr) { + arr = new bytes32[](_arr.length); + for (uint256 i = 0; i < _arr.length; ++i) { + arr[i] = bytes32(_arr[i]); } - _swap(arr, low, j); - return low; - } - - function _swap(bytes32[] memory arr, uint256 i, uint256 j) private pure { - (arr[i], arr[j]) = (arr[j], arr[i]); } } diff --git a/test/utils/MerkleBuilder.t.sol b/test/utils/MerkleBuilder.t.sol index cfbd5e90..1bfd67ca 100644 --- a/test/utils/MerkleBuilder.t.sol +++ b/test/utils/MerkleBuilder.t.sol @@ -8,8 +8,8 @@ import { MerkleBuilder } from "./MerkleBuilder.sol"; contract MerkleBuilder_Test is PRBTest, StdUtils { function testFuzz_ComputeLeaf(uint256 index, address recipient, uint128 amount) external { - bytes32 actualLeaf = MerkleBuilder.computeLeaf(index, recipient, amount); - bytes32 expectedLeaf = keccak256(bytes.concat(keccak256(abi.encode(index, recipient, amount)))); + uint256 actualLeaf = MerkleBuilder.computeLeaf(index, recipient, amount); + uint256 expectedLeaf = uint256(keccak256(bytes.concat(keccak256(abi.encode(index, recipient, amount))))); assertEq(actualLeaf, expectedLeaf, "computeLeaf"); } @@ -32,12 +32,13 @@ contract MerkleBuilder_Test is PRBTest, StdUtils { amounts[i] = params[i].amounts; } - bytes32[] memory actualLeaves = new bytes32[](count); + uint256[] memory actualLeaves = new uint256[](count); actualLeaves = MerkleBuilder.computeLeaves(indexes, recipients, amounts); - bytes32[] memory expectedLeaves = new bytes32[](count); + uint256[] memory expectedLeaves = new uint256[](count); for (uint256 i = 0; i < count; ++i) { - expectedLeaves[i] = keccak256(bytes.concat(keccak256(abi.encode(indexes[i], recipients[i], amounts[i])))); + expectedLeaves[i] = + uint256(keccak256(bytes.concat(keccak256(abi.encode(indexes[i], recipients[i], amounts[i]))))); } assertEq(actualLeaves, expectedLeaves, "computeLeaves");