Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use solady to sort merkle tree #207

Merged
merged 1 commit into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions test/fork/merkle-streamer/MerkleStreamerLL.t.sol
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity >=0.8.19 <0.9.0;

import { Arrays } from "@openzeppelin/contracts/utils/Arrays.sol";
import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import { Lockup, LockupLinear } from "@sablier/v2-core/src/types/DataTypes.sol";

Expand All @@ -10,6 +11,8 @@ import { MerkleBuilder } from "../../utils/MerkleBuilder.sol";
import { Fork_Test } from "../Fork.t.sol";

abstract contract MerkleStreamerLL_Fork_Test is Fork_Test {
using MerkleBuilder for uint256[];

constructor(IERC20 asset_) Fork_Test(asset_) { }

function setUp() public virtual override {
Expand Down Expand Up @@ -41,14 +44,16 @@ abstract contract MerkleStreamerLL_Fork_Test is Fork_Test {
uint256 expectedStreamId;
uint256[] indexes;
uint256 leafPos;
bytes32 leafToClaim;
bytes32[] leaves;
uint256 leafToClaim;
ISablierV2MerkleStreamerLL merkleStreamerLL;
bytes32 merkleRoot;
address[] recipients;
uint256 recipientsCount;
}

// We need the leaves as storage variable so that we can use Arrays.findUpperBound function.
uint256[] public leaves;

function testForkFuzz_MerkleStreamerLL(Params memory params) external {
vm.assume(params.admin != address(0) && params.admin != users.admin.addr);
vm.assume(params.expiration == 0 || params.expiration > block.timestamp);
Expand Down Expand Up @@ -77,8 +82,10 @@ abstract contract MerkleStreamerLL_Fork_Test is Fork_Test {
vars.recipients[i] = address(uint160(boundedRecipientSeed));
}

vars.leaves = MerkleBuilder.sort(MerkleBuilder.computeLeaves(vars.indexes, vars.recipients, vars.amounts));
vars.merkleRoot = getRoot(vars.leaves);
leaves = new uint256[](vars.recipientsCount);
leaves = MerkleBuilder.computeLeaves(vars.indexes, vars.recipients, vars.amounts);
MerkleBuilder.sortLeaves(leaves);
vars.merkleRoot = getRoot(leaves.toBytes32());

vars.expectedStreamerLL = computeMerkleStreamerLLAddress(params.admin, vars.merkleRoot, params.expiration);
vm.expectEmit({ emitter: address(merkleStreamerFactory) });
Expand Down Expand Up @@ -132,7 +139,7 @@ abstract contract MerkleStreamerLL_Fork_Test is Fork_Test {
vars.recipients[params.beforeSortPos],
vars.amounts[params.beforeSortPos]
);
vars.leafPos = MerkleBuilder.binarySearch(vars.leaves, vars.leafToClaim);
vars.leafPos = Arrays.findUpperBound(leaves, vars.leafToClaim);

vars.expectedStreamId = lockupLinear.nextStreamId();
emit Claim(
Expand All @@ -145,7 +152,7 @@ abstract contract MerkleStreamerLL_Fork_Test is Fork_Test {
index: vars.indexes[params.beforeSortPos],
recipient: vars.recipients[params.beforeSortPos],
amount: vars.amounts[params.beforeSortPos],
merkleProof: getProof(vars.leaves, vars.leafPos)
merkleProof: getProof(leaves.toBytes32(), vars.leafPos)
});

vars.actualStream = lockupLinear.getStream(vars.actualStreamId);
Expand Down
48 changes: 26 additions & 22 deletions test/utils/Defaults.sol
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// SPDX-License-Identifier: GPL-3.0-or-later
pragma solidity >=0.8.19 <0.9.0;

import { Arrays } from "@openzeppelin/contracts/utils/Arrays.sol";
import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import { IPRBProxy } from "@prb/proxy/src/interfaces/IPRBProxy.sol";
import { ud2x18, UD60x18 } from "@sablier/v2-core/src/types/Math.sol";
Expand All @@ -19,6 +20,8 @@ import { Users } from "./Types.sol";

/// @notice Contract with default values for testing.
contract Defaults is Merkle, PermitSignature {
using MerkleBuilder for uint256[];

/*//////////////////////////////////////////////////////////////////////////
GENERIC CONSTANTS
//////////////////////////////////////////////////////////////////////////*/
Expand Down Expand Up @@ -51,44 +54,43 @@ contract Defaults is Merkle, PermitSignature {
uint256 public constant INDEX3 = 3;
uint256 public constant INDEX4 = 4;
string public IPFS_CID = "QmbWqxBEKC3P8tqsKc98xmWNzrzDtRLMiMPL8wBuTGsMnR";
uint256[] public leaves = new uint256[](RECIPIENTS_COUNT);
bytes32 public merkleRoot;
uint256 public constant RECIPIENTS_COUNT = 4;
bool public constant TRANSFERABLE = false;

function index1Proof() public view returns (bytes32[] memory) {
bytes32 leaf = MerkleBuilder.computeLeaf(INDEX1, users.recipient1.addr, CLAIM_AMOUNT);
uint256 pos = MerkleBuilder.binarySearch(leaves(), leaf);
return getProof(leaves(), pos);
uint256 leaf = MerkleBuilder.computeLeaf(INDEX1, users.recipient1.addr, CLAIM_AMOUNT);
uint256 pos = Arrays.findUpperBound(leaves, leaf);
return getProof(leaves.toBytes32(), pos);
}

function index2Proof() public view returns (bytes32[] memory) {
bytes32 leaf = MerkleBuilder.computeLeaf(INDEX2, users.recipient2.addr, CLAIM_AMOUNT);
uint256 pos = MerkleBuilder.binarySearch(leaves(), leaf);
return getProof(leaves(), pos);
uint256 leaf = MerkleBuilder.computeLeaf(INDEX2, users.recipient2.addr, CLAIM_AMOUNT);
uint256 pos = Arrays.findUpperBound(leaves, leaf);
return getProof(leaves.toBytes32(), pos);
}

function index3Proof() public view returns (bytes32[] memory) {
bytes32 leaf = MerkleBuilder.computeLeaf(INDEX3, users.recipient3.addr, CLAIM_AMOUNT);
uint256 pos = MerkleBuilder.binarySearch(leaves(), leaf);
return getProof(leaves(), pos);
uint256 leaf = MerkleBuilder.computeLeaf(INDEX3, users.recipient3.addr, CLAIM_AMOUNT);
uint256 pos = Arrays.findUpperBound(leaves, leaf);
return getProof(leaves.toBytes32(), pos);
}

function index4Proof() public view returns (bytes32[] memory) {
bytes32 leaf = MerkleBuilder.computeLeaf(INDEX4, users.recipient4.addr, CLAIM_AMOUNT);
uint256 pos = MerkleBuilder.binarySearch(leaves(), leaf);
return getProof(leaves(), pos);
uint256 leaf = MerkleBuilder.computeLeaf(INDEX4, users.recipient4.addr, CLAIM_AMOUNT);
uint256 pos = Arrays.findUpperBound(leaves, leaf);
return getProof(leaves.toBytes32(), pos);
}

function leaves() public view returns (bytes32[] memory) {
bytes32[] memory leaves_ = new bytes32[](RECIPIENTS_COUNT);
leaves_[0] = MerkleBuilder.computeLeaf(INDEX1, users.recipient1.addr, CLAIM_AMOUNT);
leaves_[1] = MerkleBuilder.computeLeaf(INDEX2, users.recipient2.addr, CLAIM_AMOUNT);
leaves_[2] = MerkleBuilder.computeLeaf(INDEX3, users.recipient3.addr, CLAIM_AMOUNT);
leaves_[3] = MerkleBuilder.computeLeaf(INDEX4, users.recipient4.addr, CLAIM_AMOUNT);
return MerkleBuilder.sort(leaves_);
}
function _initMerkleTree() internal {
leaves[0] = MerkleBuilder.computeLeaf(INDEX1, users.recipient1.addr, CLAIM_AMOUNT);
leaves[1] = MerkleBuilder.computeLeaf(INDEX2, users.recipient2.addr, CLAIM_AMOUNT);
leaves[2] = MerkleBuilder.computeLeaf(INDEX3, users.recipient3.addr, CLAIM_AMOUNT);
leaves[3] = MerkleBuilder.computeLeaf(INDEX4, users.recipient4.addr, CLAIM_AMOUNT);

function merkleRoot() public view returns (bytes32) {
return getRoot(leaves());
MerkleBuilder.sortLeaves(leaves);
merkleRoot = getRoot(leaves.toBytes32());
}

/*//////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -158,6 +160,8 @@ contract Defaults is Merkle, PermitSignature {
CLIFF_TIME = START_TIME + CLIFF_DURATION;
END_TIME = START_TIME + TOTAL_DURATION;
EXPIRATION = uint40(block.timestamp) + 12 weeks;

_initMerkleTree();
}

/*//////////////////////////////////////////////////////////////////////////
Expand Down
77 changes: 24 additions & 53 deletions test/utils/MerkleBuilder.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
// solhint-disable reason-string
pragma solidity >=0.8.19;

import { LibSort } from "solady/utils/LibSort.sol";

/// @dev A helper library for building Merkle leaves, roots, and proofs.
library MerkleBuilder {
/// @dev Function that hashes together the data needed for a Merkle tree leaf.
function computeLeaf(uint256 index, address recipient, uint128 amount) internal pure returns (bytes32 leaf) {
leaf = keccak256(bytes.concat(keccak256(abi.encode(index, recipient, amount))));
function computeLeaf(uint256 index, address recipient, uint128 amount) internal pure returns (uint256 leaf) {
leaf = uint256(keccak256(bytes.concat(keccak256(abi.encode(index, recipient, amount)))));
}

/// @dev A batch function for `computeLeaf`.
Expand All @@ -17,72 +19,41 @@ library MerkleBuilder {
)
internal
pure
returns (bytes32[] memory leaves)
returns (uint256[] memory leaves)
{
uint256 count = indexes.length;
require(count == recipient.length && count == amount.length, "Merkle leaves arrays must have the same length");
leaves = new bytes32[](count);
leaves = new uint256[](count);
for (uint256 i = 0; i < count; ++i) {
leaves[i] = computeLeaf(indexes[i], recipient[i], amount[i]);
}
}

/// @dev Function that binary searchs the position of a specific leaf in a sorted `bytes32` array.
function binarySearch(bytes32[] memory arr, bytes32 val) internal pure returns (uint256) {
uint256 low = 0;
uint256 high = arr.length - 1;
uint256 mid;

while (low <= high) {
mid = low + (high - low) / 2;
/// @dev Function that sorts a storage array of `uint256` in ascending order. We need this function because
/// `LibSort` does not support storage arrays.
function sortLeaves(uint256[] storage leaves) internal {
uint256 leavesCount = leaves.length;

if (arr[mid] == val) {
return mid;
}

if (arr[mid] < val) {
low = mid + 1;
} else {
if (mid == 0) {
break;
}
high = mid - 1;
}
// Declare the memory array.
uint256[] memory _leaves = new uint256[](leavesCount);
for (uint256 i = 0; i < leavesCount; ++i) {
_leaves[i] = leaves[i];
}

return mid;
}

/// @dev Function that sorts an array of `bytes32` in ascending order.
function sort(bytes32[] memory arr) internal pure returns (bytes32[] memory) {
_quickSort(arr, 0, arr.length - 1);
return arr;
}
// Sort the memory array.
LibSort.sort(_leaves);

function _quickSort(bytes32[] memory arr, uint256 i, uint256 j) private pure {
if (i < j) {
uint256 p = _partition(arr, i, j);
if (p > 0) {
_quickSort(arr, i, p - 1);
}
_quickSort(arr, p + 1, j);
// Copy the memory array back to storage.
for (uint256 i = 0; i < leavesCount; ++i) {
leaves[i] = _leaves[i];
}
}

function _partition(bytes32[] memory arr, uint256 i, uint256 j) private pure returns (uint256) {
bytes32 pivot = arr[j];
uint256 low = i;
for (uint256 k = i; k < j; ++k) {
if (arr[k] < pivot) {
_swap(arr, low, k);
++low;
}
/// @dev Function that converts an array of `uint256` to an array of `bytes32`.
function toBytes32(uint256[] storage _arr) internal view returns (bytes32[] memory arr) {
arr = new bytes32[](_arr.length);
for (uint256 i = 0; i < _arr.length; ++i) {
arr[i] = bytes32(_arr[i]);
}
_swap(arr, low, j);
return low;
}

function _swap(bytes32[] memory arr, uint256 i, uint256 j) private pure {
(arr[i], arr[j]) = (arr[j], arr[i]);
}
}
11 changes: 6 additions & 5 deletions test/utils/MerkleBuilder.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import { MerkleBuilder } from "./MerkleBuilder.sol";

contract MerkleBuilder_Test is PRBTest, StdUtils {
function testFuzz_ComputeLeaf(uint256 index, address recipient, uint128 amount) external {
bytes32 actualLeaf = MerkleBuilder.computeLeaf(index, recipient, amount);
bytes32 expectedLeaf = keccak256(bytes.concat(keccak256(abi.encode(index, recipient, amount))));
uint256 actualLeaf = MerkleBuilder.computeLeaf(index, recipient, amount);
uint256 expectedLeaf = uint256(keccak256(bytes.concat(keccak256(abi.encode(index, recipient, amount)))));
assertEq(actualLeaf, expectedLeaf, "computeLeaf");
}

Expand All @@ -32,12 +32,13 @@ contract MerkleBuilder_Test is PRBTest, StdUtils {
amounts[i] = params[i].amounts;
}

bytes32[] memory actualLeaves = new bytes32[](count);
uint256[] memory actualLeaves = new uint256[](count);
actualLeaves = MerkleBuilder.computeLeaves(indexes, recipients, amounts);

bytes32[] memory expectedLeaves = new bytes32[](count);
uint256[] memory expectedLeaves = new uint256[](count);
for (uint256 i = 0; i < count; ++i) {
expectedLeaves[i] = keccak256(bytes.concat(keccak256(abi.encode(indexes[i], recipients[i], amounts[i]))));
expectedLeaves[i] =
uint256(keccak256(bytes.concat(keccak256(abi.encode(indexes[i], recipients[i], amounts[i])))));
}

assertEq(actualLeaves, expectedLeaves, "computeLeaves");
Expand Down