Skip to content

Commit

Permalink
feat: refactor batch to revert and return values (#1126)
Browse files Browse the repository at this point in the history
* feat: bubble up revert in batch function
feat: return results in batch function
test: comprehensive testing for batch with Lockup
test: move allowToHook call to Integration

* docs: polish natspecs

* test: unit tests for Batch

* docs: polish comments

test: refactor BatchMock
test: refactor branches in batch tests
test: rename initialize function

* test: fix failing test

* test: polish comments

* chore: polish batch comments and tests

Co-authored-by: andreivladbrg <[email protected]>

---------

Co-authored-by: Paul Razvan Berg <[email protected]>
Co-authored-by: andreivladbrg <[email protected]>
  • Loading branch information
3 people authored Jan 5, 2025
1 parent b88ead3 commit 08ee6e6
Show file tree
Hide file tree
Showing 19 changed files with 446 additions and 102 deletions.
22 changes: 17 additions & 5 deletions src/abstracts/Batch.sol
Original file line number Diff line number Diff line change
@@ -1,27 +1,39 @@
// SPDX-License-Identifier: GPL-3.0-or-later
// solhint-disable no-inline-assembly
pragma solidity >=0.8.22;

import { IBatch } from "../interfaces/IBatch.sol";
import { Errors } from "../libraries/Errors.sol";

/// @title Batch
/// @notice See the documentation in {IBatch}.
/// @dev Forked from: https://github.com/boringcrypto/BoringSolidity/blob/master/contracts/BoringBatchable.sol
abstract contract Batch is IBatch {
/*//////////////////////////////////////////////////////////////////////////
USER-FACING NON-CONSTANT FUNCTIONS
//////////////////////////////////////////////////////////////////////////*/

/// @inheritdoc IBatch
/// @dev The `msg.value` should not be used on any method called in the batch.
function batch(bytes[] calldata calls) external payable override {
/// @dev Since `msg.value` can be reused across calls, be VERY CAREFUL when using it. Refer to
/// https://paradigm.xyz/2021/08/two-rights-might-make-a-wrong for more information.
function batch(bytes[] calldata calls) external payable override returns (bytes[] memory results) {
uint256 count = calls.length;
results = new bytes[](count);

for (uint256 i = 0; i < count; ++i) {
(bool success, bytes memory result) = address(this).delegatecall(calls[i]);

// Check: If the delegatecall failed, load and bubble up the revert data.
if (!success) {
revert Errors.BatchError(result);
assembly {
// Get the length of the result stored in the first 32 bytes.
let resultSize := mload(result)

// Forward the pointer by 32 bytes to skip the length argument, and revert with the result.
revert(add(32, result), resultSize)
}
}

// Push the result into the results array.
results[i] = result;
}
}
}
7 changes: 5 additions & 2 deletions src/interfaces/IBatch.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ pragma solidity >=0.8.22;

/// @notice This contract implements logic to batch call any function.
interface IBatch {
/// @notice Allows batched call to self, `this` contract.
/// @notice Allows batched calls to self, i.e., `this` contract.
/// @dev Since `msg.value` can be reused across calls, be VERY CAREFUL when using it. Refer to
/// https://paradigm.xyz/2021/08/two-rights-might-make-a-wrong for more information.
/// @param calls An array of inputs for each call.
function batch(bytes[] calldata calls) external payable;
/// @return results An array of results from each call. Empty when the calls do not return anything.
function batch(bytes[] calldata calls) external payable returns (bytes[] memory results);
}
58 changes: 36 additions & 22 deletions tests/integration/Integration.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,8 @@ abstract contract Integration_Test is Base_Test {
function setUp() public virtual override {
Base_Test.setUp();

recipientInterfaceIDIncorrect = new RecipientInterfaceIDIncorrect();
recipientInterfaceIDMissing = new RecipientInterfaceIDMissing();
recipientInvalidSelector = new RecipientInvalidSelector();
recipientReentrant = new RecipientReentrant();
recipientReverting = new RecipientReverting();
vm.label({ account: address(recipientInterfaceIDIncorrect), newLabel: "Recipient Interface ID Incorrect" });
vm.label({ account: address(recipientInterfaceIDMissing), newLabel: "Recipient Interface ID Missing" });
vm.label({ account: address(recipientInvalidSelector), newLabel: "Recipient Invalid Selector" });
vm.label({ account: address(recipientReentrant), newLabel: "Recipient Reentrant" });
vm.label({ account: address(recipientReverting), newLabel: "Recipient Reverting" });
// Initialize the recipients with Hook implementations.
initializeRecipientsWithHooks();

_defaultParams.createWithTimestamps = defaults.createWithTimestamps();
_defaultParams.createWithDurations = defaults.createWithDurations();
Expand All @@ -104,8 +96,40 @@ abstract contract Integration_Test is Base_Test {
// Set the default Lockup model as Dynamic, we will override the default stream IDs where necessary.
lockupModel = Lockup.Model.LOCKUP_DYNAMIC;

// Initialize default streams IDs.
initializeDefaultStreamIds();
// Initialize default streams.
initializeDefaultStreams();
}

/*//////////////////////////////////////////////////////////////////////////
INITIALIZE-FUNCTIONS
//////////////////////////////////////////////////////////////////////////*/

function initializeDefaultStreams() internal {
defaultStreamId = createDefaultStream();
notCancelableStreamId = createDefaultStreamNonCancelable();
notTransferableStreamId = createDefaultStreamNonTransferable();
recipientGoodStreamId = createDefaultStreamWithRecipient(address(recipientGood));
recipientInvalidSelectorStreamId = createDefaultStreamWithRecipient(address(recipientInvalidSelector));
recipientReentrantStreamId = createDefaultStreamWithRecipient(address(recipientReentrant));
recipientRevertStreamId = createDefaultStreamWithRecipient(address(recipientReverting));
}

function initializeRecipientsWithHooks() internal {
recipientInterfaceIDIncorrect = new RecipientInterfaceIDIncorrect();
recipientInterfaceIDMissing = new RecipientInterfaceIDMissing();
recipientInvalidSelector = new RecipientInvalidSelector();
recipientReentrant = new RecipientReentrant();
recipientReverting = new RecipientReverting();
vm.label({ account: address(recipientInterfaceIDIncorrect), newLabel: "Recipient Interface ID Incorrect" });
vm.label({ account: address(recipientInterfaceIDMissing), newLabel: "Recipient Interface ID Missing" });
vm.label({ account: address(recipientInvalidSelector), newLabel: "Recipient Invalid Selector" });
vm.label({ account: address(recipientReentrant), newLabel: "Recipient Reentrant" });
vm.label({ account: address(recipientReverting), newLabel: "Recipient Reverting" });

// Allow the recipients to Hook.
resetPrank({ msgSender: users.admin });
lockup.allowToHook(address(recipientReverting));
resetPrank({ msgSender: users.sender });
}

/*//////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -179,16 +203,6 @@ abstract contract Integration_Test is Base_Test {
streamId = createDefaultStream(params);
}

function initializeDefaultStreamIds() internal {
defaultStreamId = createDefaultStream();
notCancelableStreamId = createDefaultStreamNonCancelable();
notTransferableStreamId = createDefaultStreamNonTransferable();
recipientGoodStreamId = createDefaultStreamWithRecipient(address(recipientGood));
recipientInvalidSelectorStreamId = createDefaultStreamWithRecipient(address(recipientInvalidSelector));
recipientReentrantStreamId = createDefaultStreamWithRecipient(address(recipientReentrant));
recipientRevertStreamId = createDefaultStreamWithRecipient(address(recipientReverting));
}

/*//////////////////////////////////////////////////////////////////////////
COMMON-REVERT-TESTS
//////////////////////////////////////////////////////////////////////////*/
Expand Down
165 changes: 165 additions & 0 deletions tests/integration/concrete/batch/batch.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity >=0.8.22;

import { Solarray } from "solarray/src/Solarray.sol";
import { Errors } from "src/libraries/Errors.sol";

import { Integration_Test } from "../../Integration.t.sol";

contract Batch_Integration_Concrete_Test is Integration_Test {
/*//////////////////////////////////////////////////////////////////////////
BATCH + LOCKUP
//////////////////////////////////////////////////////////////////////////*/

/// @dev The batch call cancels a non-cancelable stream.
function test_RevertWhen_LockupThrows() external {
bytes[] memory calls = new bytes[](2);
calls[0] = abi.encodeCall(lockup.cancel, (defaultStreamId));
calls[1] = abi.encodeCall(lockup.cancel, (notCancelableStreamId));

// Expect revert on notCancelableStreamId.
vm.expectRevert(
abi.encodeWithSelector(Errors.SablierLockupBase_StreamNotCancelable.selector, notCancelableStreamId)
);
lockup.batch(calls);
}

/// @dev The batch call includes:
/// - Returning state changing functions
/// - Non-returning state changing functions
/// - View only functions
function test_Batch_StateChangingAndViewFunctions() external {
uint256 expectedNextStreamId = lockup.nextStreamId();
vm.warp(defaults.WARP_26_PERCENT());

bytes[] memory calls = new bytes[](6);
// It should return True.
calls[0] = abi.encodeCall(lockup.isCancelable, (defaultStreamId));
// It should return the withdrawn amount.
calls[1] = abi.encodeCall(lockup.withdrawMax, (notCancelableStreamId, users.recipient));
// It should return nothing.
calls[2] = abi.encodeCall(lockup.cancel, (defaultStreamId));
// It should return the next stream ID.
calls[3] = abi.encodeCall(lockup.nextStreamId, ());
// It should return the stream ID created.
calls[4] = abi.encodeCall(
lockup.createWithTimestampsLL,
(defaults.createWithTimestamps(), defaults.unlockAmounts(), defaults.CLIFF_TIME())
);
// It should return nothing.
calls[5] = abi.encodeCall(lockup.renounce, (notTransferableStreamId));

bytes[] memory results = lockup.batch(calls);
assertEq(results.length, 6, "batch results length");
assertTrue(abi.decode(results[0], (bool)), "batch results[0]");
assertEq(abi.decode(results[1], (uint128)), defaults.WITHDRAW_AMOUNT(), "batch results[1]");
assertEq(results[2], "", "batch results[2]");
assertEq(abi.decode(results[3], (uint256)), expectedNextStreamId, "batch results[3]");
assertEq(abi.decode(results[4], (uint256)), expectedNextStreamId, "batch results[4]");
assertEq(results[5], "", "batch results[5]");
}

/// @dev The batch call includes:
/// - ETH value
/// - All create stream functions that return a value
function test_BatchPayable_CreateStreams() external {
uint256 expectedNextStreamId = lockup.nextStreamId();
uint256 initialEthBalance = address(lockup).balance;

bytes[] memory calls = new bytes[](6);
calls[0] = abi.encodeCall(
lockup.createWithDurationsLD, (defaults.createWithDurations(), defaults.segmentsWithDurations())
);
calls[1] = abi.encodeCall(
lockup.createWithDurationsLL,
(defaults.createWithDurations(), defaults.unlockAmounts(), defaults.durations())
);
calls[2] = abi.encodeCall(
lockup.createWithDurationsLT, (defaults.createWithDurations(), defaults.tranchesWithDurations())
);
calls[3] = abi.encodeCall(lockup.createWithTimestampsLD, (defaults.createWithTimestamps(), defaults.segments()));
calls[4] = abi.encodeCall(
lockup.createWithTimestampsLL,
(defaults.createWithTimestamps(), defaults.unlockAmounts(), defaults.CLIFF_TIME())
);
calls[5] = abi.encodeCall(lockup.createWithTimestampsLT, (defaults.createWithTimestamps(), defaults.tranches()));

// It should return the stream IDs created.
bytes[] memory results = lockup.batch{ value: 1 wei }(calls);
assertEq(results.length, 6, "batch results length");
assertEq(abi.decode(results[0], (uint256)), expectedNextStreamId, "batch results[0]");
assertEq(abi.decode(results[1], (uint256)), expectedNextStreamId + 1, "batch results[1]");
assertEq(abi.decode(results[2], (uint256)), expectedNextStreamId + 2, "batch results[2]");
assertEq(abi.decode(results[3], (uint256)), expectedNextStreamId + 3, "batch results[3]");
assertEq(abi.decode(results[4], (uint256)), expectedNextStreamId + 4, "batch results[4]");
assertEq(abi.decode(results[5], (uint256)), expectedNextStreamId + 5, "batch results[5]");
assertEq(address(lockup).balance, initialEthBalance + 1 wei, "lockup contract balance");
}

/// @dev The batch call includes:
/// - ETH value
/// - All recipient related functions with both returns and non-returns
function test_BatchPayable_RecipientFunctions() external {
uint256 initialEthBalance = address(lockup).balance;
vm.warp(defaults.WARP_26_PERCENT());

bytes[] memory calls = new bytes[](4);
calls[0] = abi.encodeCall(lockup.cancel, (defaultStreamId));

uint256[] memory streamIds = new uint256[](2);
streamIds[0] = recipientGoodStreamId;
streamIds[1] = recipientInvalidSelectorStreamId;
calls[1] = abi.encodeCall(lockup.cancelMultiple, (streamIds));

calls[2] = abi.encodeCall(lockup.renounce, (recipientReentrantStreamId));

streamIds = new uint256[](1);
streamIds[0] = recipientRevertStreamId;
calls[3] = abi.encodeCall(lockup.renounceMultiple, (streamIds));

bytes[] memory results = lockup.batch{ value: 1 wei }(calls);

assertEq(results.length, 4, "batch results length");
assertEq(results[0], "", "batch results[0]");
assertEq(results[1], "", "batch results[1]");
assertEq(results[2], "", "batch results[2]");
assertEq(results[3], "", "batch results[3]");
assertEq(address(lockup).balance, initialEthBalance + 1 wei, "lockup contract balance");
}

/// @dev The batch call includes:
/// - ETH value
/// - All sender related functions with both returns and non-returns
function test_BatchPayable_SenderFunctions() external {
uint256 initialEthBalance = address(lockup).balance;
// Warp to the end time so that `burn` can be added to the call list.
vm.warp(defaults.END_TIME());

bytes[] memory calls = new bytes[](5);
// It should return nothing.
calls[0] = abi.encodeCall(lockup.withdraw, (defaultStreamId, users.recipient, 1));
// It should return the withdrawn amount.
calls[1] = abi.encodeCall(lockup.withdrawMax, (defaultStreamId, users.recipient));

uint256[] memory streamIds = Solarray.uint256s(notCancelableStreamId, notCancelableStreamId);
uint128[] memory amounts = Solarray.uint128s(1, 1);

// It should return nothing.
calls[2] = abi.encodeCall(lockup.withdrawMultiple, (streamIds, amounts));
// It should return the withdrawn amount.
calls[3] = abi.encodeCall(lockup.withdrawMaxAndTransfer, (notCancelableStreamId, users.recipient));
// It should return nothing.
calls[4] = abi.encodeCall(lockup.burn, (defaultStreamId));

resetPrank({ msgSender: users.recipient });
bytes[] memory results = lockup.batch{ value: 1 wei }(calls);

assertEq(results.length, 5, "batch results length");
assertEq(results[0], "", "batch results[0]");
assertEq(abi.decode(results[1], (uint128)), defaults.DEPOSIT_AMOUNT() - 1, "batch results[1]");
assertEq(results[2], "", "batch results[2]");
assertEq(abi.decode(results[3], (uint128)), defaults.DEPOSIT_AMOUNT() - 2, "batch results[3]");
assertEq(results[4], "", "batch results[4]");
assertEq(address(lockup).balance, initialEthBalance + 1 wei, "lockup contract balance");
}
}
41 changes: 0 additions & 41 deletions tests/integration/concrete/lockup-base/batch/batch.t.sol

This file was deleted.

10 changes: 0 additions & 10 deletions tests/integration/concrete/lockup-base/batch/batch.tree

This file was deleted.

5 changes: 0 additions & 5 deletions tests/integration/concrete/lockup-base/cancel/cancel.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,6 @@ abstract contract Cancel_Integration_Concrete_Test is Integration_Test {
givenSTREAMINGStatus
givenRecipientAllowedToHook
{
// Allow the recipient to hook.
resetPrank({ msgSender: users.admin });
lockup.allowToHook(address(recipientReverting));
resetPrank({ msgSender: users.sender });

// It should revert.
vm.expectRevert("You shall not pass");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ contract RenounceMultiple_Integration_Concrete_Test is Integration_Test {
lockup.renounceMultiple(nullStreamIds);
}

function test_RevertGiven_AtleastOneNullStream() external whenNoDelegateCall whenNonZeroArrayLength {
function test_RevertGiven_AtLeastOneNullStream() external whenNoDelegateCall whenNonZeroArrayLength {
expectRevert_Null({
callData: abi.encodeCall(lockup.renounceMultiple, Solarray.uint256s(streamIds[0], nullStreamId))
});
}

function test_RevertGiven_AtleastOneColdStream()
function test_RevertGiven_AtLeastOneColdStream()
external
whenNoDelegateCall
whenNonZeroArrayLength
Expand Down Expand Up @@ -68,7 +68,7 @@ contract RenounceMultiple_Integration_Concrete_Test is Integration_Test {
lockup.renounceMultiple(streamIds);
}

function test_RevertGiven_AtleastOneNonCancelableStream()
function test_RevertGiven_AtLeastOneNonCancelableStream()
external
whenNoDelegateCall
whenNonZeroArrayLength
Expand Down
Loading

0 comments on commit 08ee6e6

Please sign in to comment.