Skip to content

Commit

Permalink
refactor: double hash (#206)
Browse files Browse the repository at this point in the history
* refactor: double hash the leaf in claim function

* test: add binarySearch function in MerkleBuilder

test: add sort function in Merkle Builder
test: sort tree leaves in MerkleStreamer tests

* test: update Precompiles bytecode

* test: use solady to sort merkle tree

* test: uppercase constants

chore: improve writing in comments
test: improve variable names
test: remove "_initMerkleTree"

* test: add comment explaining why the Merkle tree leaves are sorted

* test: reorder functions in Defaults

---------

Co-authored-by: Paul Razvan Berg <[email protected]>
  • Loading branch information
andreivladbrg and PaulRBerg authored Oct 12, 2023
1 parent db29e20 commit c58bf3e
Show file tree
Hide file tree
Showing 10 changed files with 147 additions and 84 deletions.
5 changes: 3 additions & 2 deletions src/SablierV2MerkleStreamerLL.sol
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@ contract SablierV2MerkleStreamerLL is
override
returns (uint256 streamId)
{
// Generate the Merkle tree leaf by hashing the corresponding parameters.
bytes32 leaf = keccak256(abi.encodePacked(index, recipient, amount));
// Generate the Merkle tree leaf by hashing the corresponding parameters. Hashing twice prevents second
// preimage attacks.
bytes32 leaf = keccak256(bytes.concat(keccak256(abi.encode(index, recipient, amount))));

// Checks: validate the function.
_checkClaim(index, leaf, merkleProof);
Expand Down
6 changes: 3 additions & 3 deletions test/Base.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -376,15 +376,15 @@ abstract contract Base_Test is Assertions, Events, Merkle, StdCheats, V2CoreUtil
//////////////////////////////////////////////////////////////////////////*/

function computeMerkleStreamerLLAddress() internal returns (address) {
return computeMerkleStreamerLLAddress(users.admin.addr, defaults.merkleRoot(), defaults.EXPIRATION());
return computeMerkleStreamerLLAddress(users.admin.addr, defaults.MERKLE_ROOT(), defaults.EXPIRATION());
}

function computeMerkleStreamerLLAddress(address admin) internal returns (address) {
return computeMerkleStreamerLLAddress(admin, defaults.merkleRoot(), defaults.EXPIRATION());
return computeMerkleStreamerLLAddress(admin, defaults.MERKLE_ROOT(), defaults.EXPIRATION());
}

function computeMerkleStreamerLLAddress(address admin, uint40 expiration) internal returns (address) {
return computeMerkleStreamerLLAddress(admin, defaults.merkleRoot(), expiration);
return computeMerkleStreamerLLAddress(admin, defaults.MERKLE_ROOT(), expiration);
}

function computeMerkleStreamerLLAddress(address admin, bytes32 merkleRoot) internal returns (address) {
Expand Down
54 changes: 36 additions & 18 deletions test/fork/merkle-streamer/MerkleStreamerLL.t.sol
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity >=0.8.19 <0.9.0;

import { Arrays } from "@openzeppelin/contracts/utils/Arrays.sol";
import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import { Lockup, LockupLinear } from "@sablier/v2-core/src/types/DataTypes.sol";

Expand All @@ -10,6 +11,8 @@ import { MerkleBuilder } from "../../utils/MerkleBuilder.sol";
import { Fork_Test } from "../Fork.t.sol";

abstract contract MerkleStreamerLL_Fork_Test is Fork_Test {
using MerkleBuilder for uint256[];

constructor(IERC20 asset_) Fork_Test(asset_) { }

function setUp() public virtual override {
Expand All @@ -27,31 +30,35 @@ abstract contract MerkleStreamerLL_Fork_Test is Fork_Test {
address admin;
uint40 expiration;
LeafData[] leafData;
uint256 leafPos;
uint256 posBeforeSort;
}

struct Vars {
uint256 actualStreamId;
LockupLinear.Stream actualStream;
uint128[] amounts;
ISablierV2MerkleStreamerLL merkleStreamerLL;
uint256 aggregateAmount;
uint128 clawbackAmount;
uint256 recipientsCount;
uint256 expectedStreamId;
address expectedStreamerLL;
LockupLinear.Stream expectedStream;
uint256 expectedStreamId;
uint256[] indexes;
bytes32[] leaves;
uint256 leafPos;
uint256 leafToClaim;
ISablierV2MerkleStreamerLL merkleStreamerLL;
bytes32 merkleRoot;
address[] recipients;
uint256 recipientsCount;
}

// We need the leaves as a storage variable so that we can use OpenZeppelin's {Arrays.findUpperBound}.
uint256[] public leaves;

function testForkFuzz_MerkleStreamerLL(Params memory params) external {
vm.assume(params.admin != address(0) && params.admin != users.admin.addr);
vm.assume(params.expiration == 0 || params.expiration > block.timestamp);
vm.assume(params.leafData.length > 1);
params.leafPos = _bound(params.leafPos, 0, params.leafData.length - 1);
params.posBeforeSort = _bound(params.posBeforeSort, 0, params.leafData.length - 1);
assumeNoBlacklisted({ token: address(asset), addr: params.admin });

/*//////////////////////////////////////////////////////////////////////////
Expand All @@ -75,8 +82,12 @@ abstract contract MerkleStreamerLL_Fork_Test is Fork_Test {
vars.recipients[i] = address(uint160(boundedRecipientSeed));
}

vars.leaves = MerkleBuilder.computeLeaves(vars.indexes, vars.recipients, vars.amounts);
vars.merkleRoot = getRoot(vars.leaves);
leaves = new uint256[](vars.recipientsCount);
leaves = MerkleBuilder.computeLeaves(vars.indexes, vars.recipients, vars.amounts);

// Sort the leaves in ascending order to match the production environment.
MerkleBuilder.sortLeaves(leaves);
vars.merkleRoot = getRoot(leaves.toBytes32());

vars.expectedStreamerLL = computeMerkleStreamerLLAddress(params.admin, vars.merkleRoot, params.expiration);
vm.expectEmit({ emitter: address(merkleStreamerFactory) });
Expand Down Expand Up @@ -123,25 +134,32 @@ abstract contract MerkleStreamerLL_Fork_Test is Fork_Test {
CLAIM
//////////////////////////////////////////////////////////////////////////*/

assertFalse(vars.merkleStreamerLL.hasClaimed(vars.indexes[params.leafPos]));
assertFalse(vars.merkleStreamerLL.hasClaimed(vars.indexes[params.posBeforeSort]));

vars.leafToClaim = MerkleBuilder.computeLeaf(
vars.indexes[params.posBeforeSort],
vars.recipients[params.posBeforeSort],
vars.amounts[params.posBeforeSort]
);
vars.leafPos = Arrays.findUpperBound(leaves, vars.leafToClaim);

vars.expectedStreamId = lockupLinear.nextStreamId();
emit Claim(
vars.indexes[params.leafPos],
vars.recipients[params.leafPos],
vars.amounts[params.leafPos],
vars.indexes[params.posBeforeSort],
vars.recipients[params.posBeforeSort],
vars.amounts[params.posBeforeSort],
vars.expectedStreamId
);
vars.actualStreamId = vars.merkleStreamerLL.claim({
index: vars.indexes[params.leafPos],
recipient: vars.recipients[params.leafPos],
amount: vars.amounts[params.leafPos],
merkleProof: getProof(vars.leaves, params.leafPos)
index: vars.indexes[params.posBeforeSort],
recipient: vars.recipients[params.posBeforeSort],
amount: vars.amounts[params.posBeforeSort],
merkleProof: getProof(leaves.toBytes32(), vars.leafPos)
});

vars.actualStream = lockupLinear.getStream(vars.actualStreamId);
vars.expectedStream = LockupLinear.Stream({
amounts: Lockup.Amounts({ deposited: vars.amounts[params.leafPos], refunded: 0, withdrawn: 0 }),
amounts: Lockup.Amounts({ deposited: vars.amounts[params.posBeforeSort], refunded: 0, withdrawn: 0 }),
asset: asset,
cliffTime: uint40(block.timestamp) + defaults.CLIFF_DURATION(),
endTime: uint40(block.timestamp) + defaults.TOTAL_DURATION(),
Expand All @@ -154,7 +172,7 @@ abstract contract MerkleStreamerLL_Fork_Test is Fork_Test {
wasCanceled: false
});

assertTrue(vars.merkleStreamerLL.hasClaimed(vars.indexes[params.leafPos]));
assertTrue(vars.merkleStreamerLL.hasClaimed(vars.indexes[params.posBeforeSort]));
assertEq(vars.actualStreamId, vars.expectedStreamId);
assertEq(vars.actualStream, vars.expectedStream);

Expand Down
2 changes: 1 addition & 1 deletion test/integration/merkle-streamer/MerkleStreamer.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ abstract contract MerkleStreamer_Integration_Test is Integration_Test {
initialAdmin: admin,
lockupLinear: lockupLinear,
asset: asset,
merkleRoot: defaults.merkleRoot(),
merkleRoot: defaults.MERKLE_ROOT(),
expiration: expiration,
cancelable: defaults.CANCELABLE(),
transferable: defaults.TRANSFERABLE(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ contract CreateMerkleStreamerLL_Integration_Test is MerkleStreamer_Integration_T

/// @dev This test works because a default Merkle streamer is deployed in {Integration_Test.setUp}
function test_RevertGiven_AlreadyDeployed() external {
bytes32 merkleRoot = defaults.merkleRoot();
bytes32 merkleRoot = defaults.MERKLE_ROOT();
uint40 expiration = defaults.EXPIRATION();
bool cancelable = defaults.CANCELABLE();
bool transferable = defaults.TRANSFERABLE();
Expand Down Expand Up @@ -54,7 +54,7 @@ contract CreateMerkleStreamerLL_Integration_Test is MerkleStreamer_Integration_T
admin: admin,
lockupLinear: lockupLinear,
asset: asset,
merkleRoot: defaults.merkleRoot(),
merkleRoot: defaults.MERKLE_ROOT(),
expiration: expiration,
streamDurations: defaults.durations(),
cancelable: defaults.CANCELABLE(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ contract Constructor_MerkleStreamerLL_Integration_Test is MerkleStreamer_Integra
users.admin.addr,
lockupLinear,
asset,
defaults.merkleRoot(),
defaults.MERKLE_ROOT(),
defaults.EXPIRATION(),
defaults.durations(),
defaults.CANCELABLE(),
defaults.CANCELABLE(),
defaults.TRANSFERABLE()
);

Expand All @@ -53,7 +53,7 @@ contract Constructor_MerkleStreamerLL_Integration_Test is MerkleStreamer_Integra
assertEq(vars.actualAsset, vars.expectedAsset, "asset");

vars.actualMerkleRoot = constructedStreamerLL.MERKLE_ROOT();
vars.expectedMerkleRoot = defaults.merkleRoot();
vars.expectedMerkleRoot = defaults.MERKLE_ROOT();
assertEq(vars.actualMerkleRoot, vars.expectedMerkleRoot, "merkleRoot");

vars.actualCancelable = constructedStreamerLL.CANCELABLE();
Expand Down
101 changes: 57 additions & 44 deletions test/utils/Defaults.sol
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// SPDX-License-Identifier: GPL-3.0-or-later
pragma solidity >=0.8.19 <0.9.0;

import { Arrays } from "@openzeppelin/contracts/utils/Arrays.sol";
import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import { IPRBProxy } from "@prb/proxy/src/interfaces/IPRBProxy.sol";
import { ud2x18, UD60x18 } from "@sablier/v2-core/src/types/Math.sol";
Expand All @@ -19,8 +20,10 @@ import { Users } from "./Types.sol";

/// @notice Contract with default values for testing.
contract Defaults is Merkle, PermitSignature {
using MerkleBuilder for uint256[];

/*//////////////////////////////////////////////////////////////////////////
GENERIC CONSTANTS
GENERICS
//////////////////////////////////////////////////////////////////////////*/

uint64 public constant BATCH_SIZE = 10;
Expand Down Expand Up @@ -50,36 +53,72 @@ contract Defaults is Merkle, PermitSignature {
uint256 public constant INDEX2 = 2;
uint256 public constant INDEX3 = 3;
uint256 public constant INDEX4 = 4;
string public IPFS_CID = "QmbWqxBEKC3P8tqsKc98xmWNzrzDtRLMiMPL8wBuTGsMnR";
string public constant IPFS_CID = "QmbWqxBEKC3P8tqsKc98xmWNzrzDtRLMiMPL8wBuTGsMnR";
uint256 public constant RECIPIENTS_COUNT = 4;
bool public constant TRANSFERABLE = false;
uint256[] public LEAVES = new uint256[](RECIPIENTS_COUNT);
bytes32 public immutable MERKLE_ROOT;

/*//////////////////////////////////////////////////////////////////////////
VARIABLES
//////////////////////////////////////////////////////////////////////////*/

IERC20 private asset;
IPRBProxy private proxy;
IAllowanceTransfer private permit2;
Users private users;

/*//////////////////////////////////////////////////////////////////////////
CONSTRUCTOR
//////////////////////////////////////////////////////////////////////////*/

constructor(Users memory users_, IERC20 asset_, IAllowanceTransfer permit2_, IPRBProxy proxy_) {
users = users_;
asset = asset_;
permit2 = permit2_;
proxy = proxy_;

// Initialize the immutables.
START_TIME = uint40(block.timestamp) + 100 seconds;
CLIFF_TIME = START_TIME + CLIFF_DURATION;
END_TIME = START_TIME + TOTAL_DURATION;
EXPIRATION = uint40(block.timestamp) + 12 weeks;

// Initialize the Merkle tree.
LEAVES[0] = MerkleBuilder.computeLeaf(INDEX1, users.recipient1.addr, CLAIM_AMOUNT);
LEAVES[1] = MerkleBuilder.computeLeaf(INDEX2, users.recipient2.addr, CLAIM_AMOUNT);
LEAVES[2] = MerkleBuilder.computeLeaf(INDEX3, users.recipient3.addr, CLAIM_AMOUNT);
LEAVES[3] = MerkleBuilder.computeLeaf(INDEX4, users.recipient4.addr, CLAIM_AMOUNT);
MerkleBuilder.sortLeaves(LEAVES);
MERKLE_ROOT = getRoot(LEAVES.toBytes32());
}

/*//////////////////////////////////////////////////////////////////////////
MERKLE-STREAMER
//////////////////////////////////////////////////////////////////////////*/

function index1Proof() public view returns (bytes32[] memory) {
return getProof(leaves(), 0);
uint256 leaf = MerkleBuilder.computeLeaf(INDEX1, users.recipient1.addr, CLAIM_AMOUNT);
uint256 pos = Arrays.findUpperBound(LEAVES, leaf);
return getProof(LEAVES.toBytes32(), pos);
}

function index2Proof() public view returns (bytes32[] memory) {
return getProof(leaves(), 1);
uint256 leaf = MerkleBuilder.computeLeaf(INDEX2, users.recipient2.addr, CLAIM_AMOUNT);
uint256 pos = Arrays.findUpperBound(LEAVES, leaf);
return getProof(LEAVES.toBytes32(), pos);
}

function index3Proof() public view returns (bytes32[] memory) {
return getProof(leaves(), 2);
uint256 leaf = MerkleBuilder.computeLeaf(INDEX3, users.recipient3.addr, CLAIM_AMOUNT);
uint256 pos = Arrays.findUpperBound(LEAVES, leaf);
return getProof(LEAVES.toBytes32(), pos);
}

function index4Proof() public view returns (bytes32[] memory) {
return getProof(leaves(), 3);
}

function leaves() public view returns (bytes32[] memory leaves_) {
leaves_ = new bytes32[](RECIPIENTS_COUNT);
leaves_[0] = MerkleBuilder.computeLeaf(INDEX1, users.recipient1.addr, CLAIM_AMOUNT);
leaves_[1] = MerkleBuilder.computeLeaf(INDEX2, users.recipient2.addr, CLAIM_AMOUNT);
leaves_[2] = MerkleBuilder.computeLeaf(INDEX3, users.recipient3.addr, CLAIM_AMOUNT);
leaves_[3] = MerkleBuilder.computeLeaf(INDEX4, users.recipient4.addr, CLAIM_AMOUNT);
}

function merkleRoot() public view returns (bytes32) {
return getRoot(leaves());
uint256 leaf = MerkleBuilder.computeLeaf(INDEX4, users.recipient4.addr, CLAIM_AMOUNT);
uint256 pos = Arrays.findUpperBound(LEAVES, leaf);
return getProof(LEAVES.toBytes32(), pos);
}

/*//////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -125,32 +164,6 @@ contract Defaults is Merkle, PermitSignature {
return abi.encode(permit2Params_);
}

/*//////////////////////////////////////////////////////////////////////////
VARIABLES
//////////////////////////////////////////////////////////////////////////*/

IERC20 private asset;
IPRBProxy private proxy;
IAllowanceTransfer private permit2;
Users private users;

/*//////////////////////////////////////////////////////////////////////////
CONSTRUCTOR
//////////////////////////////////////////////////////////////////////////*/

constructor(Users memory users_, IERC20 asset_, IAllowanceTransfer permit2_, IPRBProxy proxy_) {
users = users_;
asset = asset_;
permit2 = permit2_;
proxy = proxy_;

// Initialize the immutables.
START_TIME = uint40(block.timestamp) + 100 seconds;
CLIFF_TIME = START_TIME + CLIFF_DURATION;
END_TIME = START_TIME + TOTAL_DURATION;
EXPIRATION = uint40(block.timestamp) + 12 weeks;
}

/*//////////////////////////////////////////////////////////////////////////
SABLIER-V2-LOCKUP
//////////////////////////////////////////////////////////////////////////*/
Expand Down
Loading

0 comments on commit c58bf3e

Please sign in to comment.