Skip to content

Commit

Permalink
test: use solady to sort merkle tree
Browse files Browse the repository at this point in the history
  • Loading branch information
andreivladbrg committed Oct 10, 2023
1 parent aa5c84f commit f0f08f4
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 86 deletions.
19 changes: 13 additions & 6 deletions test/fork/merkle-streamer/MerkleStreamerLL.t.sol
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity >=0.8.19 <0.9.0;

import { Arrays } from "@openzeppelin/contracts/utils/Arrays.sol";
import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import { Lockup, LockupLinear } from "@sablier/v2-core/src/types/DataTypes.sol";

Expand All @@ -10,6 +11,8 @@ import { MerkleBuilder } from "../../utils/MerkleBuilder.sol";
import { Fork_Test } from "../Fork.t.sol";

abstract contract MerkleStreamerLL_Fork_Test is Fork_Test {
using MerkleBuilder for uint256[];

constructor(IERC20 asset_) Fork_Test(asset_) { }

function setUp() public virtual override {
Expand Down Expand Up @@ -41,14 +44,16 @@ abstract contract MerkleStreamerLL_Fork_Test is Fork_Test {
uint256 expectedStreamId;
uint256[] indexes;
uint256 leafPos;
bytes32 leafToClaim;
bytes32[] leaves;
uint256 leafToClaim;
ISablierV2MerkleStreamerLL merkleStreamerLL;
bytes32 merkleRoot;
address[] recipients;
uint256 recipientsCount;
}

// We need the leaves as storage variable so that we can use Arrays.findUpperBound function.
uint256[] public leaves;

function testForkFuzz_MerkleStreamerLL(Params memory params) external {
vm.assume(params.admin != address(0) && params.admin != users.admin.addr);
vm.assume(params.expiration == 0 || params.expiration > block.timestamp);
Expand Down Expand Up @@ -77,8 +82,10 @@ abstract contract MerkleStreamerLL_Fork_Test is Fork_Test {
vars.recipients[i] = address(uint160(boundedRecipientSeed));
}

vars.leaves = MerkleBuilder.sort(MerkleBuilder.computeLeaves(vars.indexes, vars.recipients, vars.amounts));
vars.merkleRoot = getRoot(vars.leaves);
leaves = new uint256[](vars.recipientsCount);
leaves = MerkleBuilder.computeLeaves(vars.indexes, vars.recipients, vars.amounts);
MerkleBuilder.sortLeaves(leaves);
vars.merkleRoot = getRoot(leaves.toBytes32());

vars.expectedStreamerLL = computeMerkleStreamerLLAddress(params.admin, vars.merkleRoot, params.expiration);
vm.expectEmit({ emitter: address(merkleStreamerFactory) });
Expand Down Expand Up @@ -132,7 +139,7 @@ abstract contract MerkleStreamerLL_Fork_Test is Fork_Test {
vars.recipients[params.beforeSortPos],
vars.amounts[params.beforeSortPos]
);
vars.leafPos = MerkleBuilder.binarySearch(vars.leaves, vars.leafToClaim);
vars.leafPos = Arrays.findUpperBound(leaves, vars.leafToClaim);

vars.expectedStreamId = lockupLinear.nextStreamId();
emit Claim(
Expand All @@ -145,7 +152,7 @@ abstract contract MerkleStreamerLL_Fork_Test is Fork_Test {
index: vars.indexes[params.beforeSortPos],
recipient: vars.recipients[params.beforeSortPos],
amount: vars.amounts[params.beforeSortPos],
merkleProof: getProof(vars.leaves, vars.leafPos)
merkleProof: getProof(leaves.toBytes32(), vars.leafPos)
});

vars.actualStream = lockupLinear.getStream(vars.actualStreamId);
Expand Down
48 changes: 26 additions & 22 deletions test/utils/Defaults.sol
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// SPDX-License-Identifier: GPL-3.0-or-later
pragma solidity >=0.8.19 <0.9.0;

import { Arrays } from "@openzeppelin/contracts/utils/Arrays.sol";
import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import { IPRBProxy } from "@prb/proxy/src/interfaces/IPRBProxy.sol";
import { ud2x18, UD60x18 } from "@sablier/v2-core/src/types/Math.sol";
Expand All @@ -19,6 +20,8 @@ import { Users } from "./Types.sol";

/// @notice Contract with default values for testing.
contract Defaults is Merkle, PermitSignature {
using MerkleBuilder for uint256[];

/*//////////////////////////////////////////////////////////////////////////
GENERIC CONSTANTS
//////////////////////////////////////////////////////////////////////////*/
Expand Down Expand Up @@ -51,44 +54,43 @@ contract Defaults is Merkle, PermitSignature {
uint256 public constant INDEX3 = 3;
uint256 public constant INDEX4 = 4;
string public IPFS_CID = "QmbWqxBEKC3P8tqsKc98xmWNzrzDtRLMiMPL8wBuTGsMnR";
uint256[] public leaves = new uint256[](RECIPIENTS_COUNT);
bytes32 public merkleRoot;
uint256 public constant RECIPIENTS_COUNT = 4;
bool public constant TRANSFERABLE = false;

function index1Proof() public view returns (bytes32[] memory) {
bytes32 leaf = MerkleBuilder.computeLeaf(INDEX1, users.recipient1.addr, CLAIM_AMOUNT);
uint256 pos = MerkleBuilder.binarySearch(leaves(), leaf);
return getProof(leaves(), pos);
uint256 leaf = MerkleBuilder.computeLeaf(INDEX1, users.recipient1.addr, CLAIM_AMOUNT);
uint256 pos = Arrays.findUpperBound(leaves, leaf);
return getProof(leaves.toBytes32(), pos);
}

function index2Proof() public view returns (bytes32[] memory) {
bytes32 leaf = MerkleBuilder.computeLeaf(INDEX2, users.recipient2.addr, CLAIM_AMOUNT);
uint256 pos = MerkleBuilder.binarySearch(leaves(), leaf);
return getProof(leaves(), pos);
uint256 leaf = MerkleBuilder.computeLeaf(INDEX2, users.recipient2.addr, CLAIM_AMOUNT);
uint256 pos = Arrays.findUpperBound(leaves, leaf);
return getProof(leaves.toBytes32(), pos);
}

function index3Proof() public view returns (bytes32[] memory) {
bytes32 leaf = MerkleBuilder.computeLeaf(INDEX3, users.recipient3.addr, CLAIM_AMOUNT);
uint256 pos = MerkleBuilder.binarySearch(leaves(), leaf);
return getProof(leaves(), pos);
uint256 leaf = MerkleBuilder.computeLeaf(INDEX3, users.recipient3.addr, CLAIM_AMOUNT);
uint256 pos = Arrays.findUpperBound(leaves, leaf);
return getProof(leaves.toBytes32(), pos);
}

function index4Proof() public view returns (bytes32[] memory) {
bytes32 leaf = MerkleBuilder.computeLeaf(INDEX4, users.recipient4.addr, CLAIM_AMOUNT);
uint256 pos = MerkleBuilder.binarySearch(leaves(), leaf);
return getProof(leaves(), pos);
uint256 leaf = MerkleBuilder.computeLeaf(INDEX4, users.recipient4.addr, CLAIM_AMOUNT);
uint256 pos = Arrays.findUpperBound(leaves, leaf);
return getProof(leaves.toBytes32(), pos);
}

function leaves() public view returns (bytes32[] memory) {
bytes32[] memory leaves_ = new bytes32[](RECIPIENTS_COUNT);
leaves_[0] = MerkleBuilder.computeLeaf(INDEX1, users.recipient1.addr, CLAIM_AMOUNT);
leaves_[1] = MerkleBuilder.computeLeaf(INDEX2, users.recipient2.addr, CLAIM_AMOUNT);
leaves_[2] = MerkleBuilder.computeLeaf(INDEX3, users.recipient3.addr, CLAIM_AMOUNT);
leaves_[3] = MerkleBuilder.computeLeaf(INDEX4, users.recipient4.addr, CLAIM_AMOUNT);
return MerkleBuilder.sort(leaves_);
}
function _initMerkleTree() internal {
leaves[0] = MerkleBuilder.computeLeaf(INDEX1, users.recipient1.addr, CLAIM_AMOUNT);
leaves[1] = MerkleBuilder.computeLeaf(INDEX2, users.recipient2.addr, CLAIM_AMOUNT);
leaves[2] = MerkleBuilder.computeLeaf(INDEX3, users.recipient3.addr, CLAIM_AMOUNT);
leaves[3] = MerkleBuilder.computeLeaf(INDEX4, users.recipient4.addr, CLAIM_AMOUNT);

function merkleRoot() public view returns (bytes32) {
return getRoot(leaves());
MerkleBuilder.sortLeaves(leaves);
merkleRoot = getRoot(leaves.toBytes32());
}

/*//////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -158,6 +160,8 @@ contract Defaults is Merkle, PermitSignature {
CLIFF_TIME = START_TIME + CLIFF_DURATION;
END_TIME = START_TIME + TOTAL_DURATION;
EXPIRATION = uint40(block.timestamp) + 12 weeks;

_initMerkleTree();
}

/*//////////////////////////////////////////////////////////////////////////
Expand Down
77 changes: 24 additions & 53 deletions test/utils/MerkleBuilder.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
// solhint-disable reason-string
pragma solidity >=0.8.19;

import { LibSort } from "solady/utils/LibSort.sol";

/// @dev A helper library for building Merkle leaves, roots, and proofs.
library MerkleBuilder {
/// @dev Function that hashes together the data needed for a Merkle tree leaf.
function computeLeaf(uint256 index, address recipient, uint128 amount) internal pure returns (bytes32 leaf) {
leaf = keccak256(bytes.concat(keccak256(abi.encode(index, recipient, amount))));
function computeLeaf(uint256 index, address recipient, uint128 amount) internal pure returns (uint256 leaf) {
leaf = uint256(keccak256(bytes.concat(keccak256(abi.encode(index, recipient, amount)))));
}

/// @dev A batch function for `computeLeaf`.
Expand All @@ -17,72 +19,41 @@ library MerkleBuilder {
)
internal
pure
returns (bytes32[] memory leaves)
returns (uint256[] memory leaves)
{
uint256 count = indexes.length;
require(count == recipient.length && count == amount.length, "Merkle leaves arrays must have the same length");
leaves = new bytes32[](count);
leaves = new uint256[](count);
for (uint256 i = 0; i < count; ++i) {
leaves[i] = computeLeaf(indexes[i], recipient[i], amount[i]);
}
}

/// @dev Function that binary searchs the position of a specific leaf in a sorted `bytes32` array.
function binarySearch(bytes32[] memory arr, bytes32 val) internal pure returns (uint256) {
uint256 low = 0;
uint256 high = arr.length - 1;
uint256 mid;

while (low <= high) {
mid = low + (high - low) / 2;
/// @dev Function that sorts a storage array of `uint256` in ascending order. We need this function because
/// `LibSort` does not support storage arrays.
function sortLeaves(uint256[] storage leaves) internal {
uint256 leavesCount = leaves.length;

if (arr[mid] == val) {
return mid;
}

if (arr[mid] < val) {
low = mid + 1;
} else {
if (mid == 0) {
break;
}
high = mid - 1;
}
// Declare the memory array.
uint256[] memory _leaves = new uint256[](leavesCount);
for (uint256 i = 0; i < leavesCount; ++i) {
_leaves[i] = leaves[i];
}

return mid;
}

/// @dev Function that sorts an array of `bytes32` in ascending order.
function sort(bytes32[] memory arr) internal pure returns (bytes32[] memory) {
_quickSort(arr, 0, arr.length - 1);
return arr;
}
// Sort the memory array.
LibSort.sort(_leaves);

function _quickSort(bytes32[] memory arr, uint256 i, uint256 j) private pure {
if (i < j) {
uint256 p = _partition(arr, i, j);
if (p > 0) {
_quickSort(arr, i, p - 1);
}
_quickSort(arr, p + 1, j);
// Copy the memory array back to storage.
for (uint256 i = 0; i < leavesCount; ++i) {
leaves[i] = _leaves[i];
}
}

function _partition(bytes32[] memory arr, uint256 i, uint256 j) private pure returns (uint256) {
bytes32 pivot = arr[j];
uint256 low = i;
for (uint256 k = i; k < j; ++k) {
if (arr[k] < pivot) {
_swap(arr, low, k);
++low;
}
/// @dev Function that converts an array of `uint256` to an array of `bytes32`.
function toBytes32(uint256[] storage _arr) internal view returns (bytes32[] memory arr) {
arr = new bytes32[](_arr.length);
for (uint256 i = 0; i < _arr.length; ++i) {
arr[i] = bytes32(_arr[i]);
}
_swap(arr, low, j);
return low;
}

function _swap(bytes32[] memory arr, uint256 i, uint256 j) private pure {
(arr[i], arr[j]) = (arr[j], arr[i]);
}
}
11 changes: 6 additions & 5 deletions test/utils/MerkleBuilder.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import { MerkleBuilder } from "./MerkleBuilder.sol";

contract MerkleBuilder_Test is PRBTest, StdUtils {
function testFuzz_ComputeLeaf(uint256 index, address recipient, uint128 amount) external {
bytes32 actualLeaf = MerkleBuilder.computeLeaf(index, recipient, amount);
bytes32 expectedLeaf = keccak256(bytes.concat(keccak256(abi.encode(index, recipient, amount))));
uint256 actualLeaf = MerkleBuilder.computeLeaf(index, recipient, amount);
uint256 expectedLeaf = uint256(keccak256(bytes.concat(keccak256(abi.encode(index, recipient, amount)))));
assertEq(actualLeaf, expectedLeaf, "computeLeaf");
}

Expand All @@ -32,12 +32,13 @@ contract MerkleBuilder_Test is PRBTest, StdUtils {
amounts[i] = params[i].amounts;
}

bytes32[] memory actualLeaves = new bytes32[](count);
uint256[] memory actualLeaves = new uint256[](count);
actualLeaves = MerkleBuilder.computeLeaves(indexes, recipients, amounts);

bytes32[] memory expectedLeaves = new bytes32[](count);
uint256[] memory expectedLeaves = new uint256[](count);
for (uint256 i = 0; i < count; ++i) {
expectedLeaves[i] = keccak256(bytes.concat(keccak256(abi.encode(indexes[i], recipients[i], amounts[i]))));
expectedLeaves[i] =
uint256(keccak256(bytes.concat(keccak256(abi.encode(indexes[i], recipients[i], amounts[i])))));
}

assertEq(actualLeaves, expectedLeaves, "computeLeaves");
Expand Down

0 comments on commit f0f08f4

Please sign in to comment.