Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: refactor batch to revert and return values #1126

Merged
merged 7 commits into from
Jan 5, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions src/abstracts/Batch.sol
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// SPDX-License-Identifier: GPL-3.0-or-later
// solhint-disable no-inline-assembly
pragma solidity >=0.8.22;

import { IBatch } from "../interfaces/IBatch.sol";
import { Errors } from "../libraries/Errors.sol";

/// @title Batch
/// @notice See the documentation in {IBatch}.
Expand All @@ -13,15 +13,28 @@ abstract contract Batch is IBatch {
//////////////////////////////////////////////////////////////////////////*/

/// @inheritdoc IBatch
/// @dev The `msg.value` should not be used on any method called in the batch.
function batch(bytes[] calldata calls) external payable override {
/// @dev Since `msg.value` can be reused across calls, be VERY CAREFUL when using it. Refer to
/// https://paradigm.xyz/2021/08/two-rights-might-make-a-wrong for more information.
smol-ninja marked this conversation as resolved.
Show resolved Hide resolved
function batch(bytes[] calldata calls) external payable override returns (bytes[] memory results) {
uint256 count = calls.length;
results = new bytes[](count);

for (uint256 i = 0; i < count; ++i) {
(bool success, bytes memory result) = address(this).delegatecall(calls[i]);

// Check: If the delegate call failed, load and bubble up the revert data.
smol-ninja marked this conversation as resolved.
Show resolved Hide resolved
if (!success) {
revert Errors.BatchError(result);
assembly {
// Get the length of the result stored in the first 32 bytes.
let resultSize := mload(result)

// Forward the pointer by 32 bytes to skip the length argument, and revert with the result.
revert(add(32, result), resultSize)
}
}

// Push the result into the results array.
results[i] = result;
}
}
}
7 changes: 5 additions & 2 deletions src/interfaces/IBatch.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ pragma solidity >=0.8.22;

/// @notice This contract implements logic to batch call any function.
interface IBatch {
/// @notice Allows batched call to self, `this` contract.
/// @notice Allows batched calls to self, i.e., `this` contract.
/// @dev Since `msg.value` can be reused across calls, be VERY CAREFUL when using it. Refer to
/// https://paradigm.xyz/2021/08/two-rights-might-make-a-wrong for more information.
/// @param calls An array of inputs for each call.
function batch(bytes[] calldata calls) external payable;
/// @return results An array of results from each call. Empty when the calls do not return anything.
function batch(bytes[] calldata calls) external payable returns (bytes[] memory results);
}
58 changes: 36 additions & 22 deletions tests/integration/Integration.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,8 @@ abstract contract Integration_Test is Base_Test {
function setUp() public virtual override {
Base_Test.setUp();

recipientInterfaceIDIncorrect = new RecipientInterfaceIDIncorrect();
recipientInterfaceIDMissing = new RecipientInterfaceIDMissing();
recipientInvalidSelector = new RecipientInvalidSelector();
recipientReentrant = new RecipientReentrant();
recipientReverting = new RecipientReverting();
vm.label({ account: address(recipientInterfaceIDIncorrect), newLabel: "Recipient Interface ID Incorrect" });
vm.label({ account: address(recipientInterfaceIDMissing), newLabel: "Recipient Interface ID Missing" });
vm.label({ account: address(recipientInvalidSelector), newLabel: "Recipient Invalid Selector" });
vm.label({ account: address(recipientReentrant), newLabel: "Recipient Reentrant" });
vm.label({ account: address(recipientReverting), newLabel: "Recipient Reverting" });
// Initialize the recipients with Hook implementations.
initializeRecipientsWithHooks();

_defaultParams.createWithTimestamps = defaults.createWithTimestamps();
_defaultParams.createWithDurations = defaults.createWithDurations();
Expand All @@ -104,8 +96,40 @@ abstract contract Integration_Test is Base_Test {
// Set the default Lockup model as Dynamic, we will override the default stream IDs where necessary.
lockupModel = Lockup.Model.LOCKUP_DYNAMIC;

// Initialize default streams IDs.
initializeDefaultStreamIds();
// Initialize default streams.
initializeDefaultStreams();
}

/*//////////////////////////////////////////////////////////////////////////
INITIALIZE-FUNCTIONS
//////////////////////////////////////////////////////////////////////////*/

function initializeDefaultStreams() internal {
defaultStreamId = createDefaultStream();
notCancelableStreamId = createDefaultStreamNonCancelable();
notTransferableStreamId = createDefaultStreamNonTransferable();
recipientGoodStreamId = createDefaultStreamWithRecipient(address(recipientGood));
recipientInvalidSelectorStreamId = createDefaultStreamWithRecipient(address(recipientInvalidSelector));
recipientReentrantStreamId = createDefaultStreamWithRecipient(address(recipientReentrant));
recipientRevertStreamId = createDefaultStreamWithRecipient(address(recipientReverting));
}

function initializeRecipientsWithHooks() internal {
recipientInterfaceIDIncorrect = new RecipientInterfaceIDIncorrect();
recipientInterfaceIDMissing = new RecipientInterfaceIDMissing();
recipientInvalidSelector = new RecipientInvalidSelector();
recipientReentrant = new RecipientReentrant();
recipientReverting = new RecipientReverting();
vm.label({ account: address(recipientInterfaceIDIncorrect), newLabel: "Recipient Interface ID Incorrect" });
vm.label({ account: address(recipientInterfaceIDMissing), newLabel: "Recipient Interface ID Missing" });
vm.label({ account: address(recipientInvalidSelector), newLabel: "Recipient Invalid Selector" });
vm.label({ account: address(recipientReentrant), newLabel: "Recipient Reentrant" });
vm.label({ account: address(recipientReverting), newLabel: "Recipient Reverting" });

// Allow the recipients to Hook.
resetPrank({ msgSender: users.admin });
lockup.allowToHook(address(recipientReverting));
resetPrank({ msgSender: users.sender });
smol-ninja marked this conversation as resolved.
Show resolved Hide resolved
}

/*//////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -179,16 +203,6 @@ abstract contract Integration_Test is Base_Test {
streamId = createDefaultStream(params);
}

function initializeDefaultStreamIds() internal {
defaultStreamId = createDefaultStream();
notCancelableStreamId = createDefaultStreamNonCancelable();
notTransferableStreamId = createDefaultStreamNonTransferable();
recipientGoodStreamId = createDefaultStreamWithRecipient(address(recipientGood));
recipientInvalidSelectorStreamId = createDefaultStreamWithRecipient(address(recipientInvalidSelector));
recipientReentrantStreamId = createDefaultStreamWithRecipient(address(recipientReentrant));
recipientRevertStreamId = createDefaultStreamWithRecipient(address(recipientReverting));
}

/*//////////////////////////////////////////////////////////////////////////
COMMON-REVERT-TESTS
//////////////////////////////////////////////////////////////////////////*/
Expand Down
165 changes: 165 additions & 0 deletions tests/integration/concrete/batch/batch.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity >=0.8.22;

import { Solarray } from "solarray/src/Solarray.sol";
import { Errors } from "src/libraries/Errors.sol";

import { Integration_Test } from "../../Integration.t.sol";

contract Batch_Integration_Concrete_Test is Integration_Test {
/*//////////////////////////////////////////////////////////////////////////
BATCH + LOCKUP
//////////////////////////////////////////////////////////////////////////*/

/// @dev The batch call cancels a non-cancelable stream.
function test_RevertWhen_LockupThrows() external {
bytes[] memory calls = new bytes[](2);
calls[0] = abi.encodeCall(lockup.cancel, (defaultStreamId));
calls[1] = abi.encodeCall(lockup.cancel, (notCancelableStreamId));

// Expect revert on notCancelableStreamId.
vm.expectRevert(
abi.encodeWithSelector(Errors.SablierLockupBase_StreamNotCancelable.selector, notCancelableStreamId)
);
lockup.batch(calls);
}

/// @dev The batch call includes:
/// - Returning state changing functions
/// - Non-returning state changing functions
/// - View only functions
function test_Batch_StateChangingAndViewFunctions() external {
uint256 expectedNextStreamId = lockup.nextStreamId();
vm.warp(defaults.WARP_26_PERCENT());

bytes[] memory calls = new bytes[](6);
// It should return True.
calls[0] = abi.encodeCall(lockup.isCancelable, (defaultStreamId));
// It should return the withdrawn amount.
calls[1] = abi.encodeCall(lockup.withdrawMax, (notCancelableStreamId, users.recipient));
// It should return nothing.
calls[2] = abi.encodeCall(lockup.cancel, (defaultStreamId));
// It should return the next stream ID.
calls[3] = abi.encodeCall(lockup.nextStreamId, ());
// It should return the stream ID created.
calls[4] = abi.encodeCall(
lockup.createWithTimestampsLL,
(defaults.createWithTimestamps(), defaults.unlockAmounts(), defaults.CLIFF_TIME())
);
// It should return nothing.
calls[5] = abi.encodeCall(lockup.renounce, (notTransferableStreamId));

bytes[] memory results = lockup.batch(calls);
assertEq(results.length, 6, "batch results length");
assertTrue(abi.decode(results[0], (bool)), "batch results[0]");
assertEq(abi.decode(results[1], (uint128)), defaults.WITHDRAW_AMOUNT(), "batch results[1]");
assertEq(results[2], hex"", "batch results[2]");
assertEq(abi.decode(results[3], (uint256)), expectedNextStreamId, "batch results[3]");
assertEq(abi.decode(results[4], (uint256)), expectedNextStreamId, "batch results[4]");
assertEq(results[5], hex"", "batch results[5]");
}

/// @dev The batch call includes:
/// - Payable functions
/// - All create stream functions that return a value
function test_BatchPayable_CreateStreams() external {
uint256 expectedNextStreamId = lockup.nextStreamId();
uint256 initialEthBalance = address(lockup).balance;

bytes[] memory calls = new bytes[](6);
calls[0] = abi.encodeCall(
lockup.createWithDurationsLD, (defaults.createWithDurations(), defaults.segmentsWithDurations())
);
calls[1] = abi.encodeCall(
lockup.createWithDurationsLL,
(defaults.createWithDurations(), defaults.unlockAmounts(), defaults.durations())
);
calls[2] = abi.encodeCall(
lockup.createWithDurationsLT, (defaults.createWithDurations(), defaults.tranchesWithDurations())
);
calls[3] = abi.encodeCall(lockup.createWithTimestampsLD, (defaults.createWithTimestamps(), defaults.segments()));
calls[4] = abi.encodeCall(
lockup.createWithTimestampsLL,
(defaults.createWithTimestamps(), defaults.unlockAmounts(), defaults.CLIFF_TIME())
);
calls[5] = abi.encodeCall(lockup.createWithTimestampsLT, (defaults.createWithTimestamps(), defaults.tranches()));

// It should return the stream IDs created.
bytes[] memory results = lockup.batch{ value: 1 wei }(calls);
assertEq(results.length, 6, "batch results length");
assertEq(abi.decode(results[0], (uint256)), expectedNextStreamId, "batch results[0]");
assertEq(abi.decode(results[1], (uint256)), expectedNextStreamId + 1, "batch results[1]");
assertEq(abi.decode(results[2], (uint256)), expectedNextStreamId + 2, "batch results[2]");
assertEq(abi.decode(results[3], (uint256)), expectedNextStreamId + 3, "batch results[3]");
assertEq(abi.decode(results[4], (uint256)), expectedNextStreamId + 4, "batch results[4]");
assertEq(abi.decode(results[5], (uint256)), expectedNextStreamId + 5, "batch results[5]");
assertEq(address(lockup).balance, initialEthBalance + 1 wei, "batch contract balance");
}

/// @dev The batch call includes:
/// - Payable functions
smol-ninja marked this conversation as resolved.
Show resolved Hide resolved
/// - All recipient related functions with both returns and non-returns
function test_BatchPayable_RecipientFunctions() external {
uint256 initialEthBalance = address(lockup).balance;
vm.warp(defaults.WARP_26_PERCENT());

bytes[] memory calls = new bytes[](4);
calls[0] = abi.encodeCall(lockup.cancel, (defaultStreamId));

uint256[] memory streamIds = new uint256[](2);
streamIds[0] = recipientGoodStreamId;
streamIds[1] = recipientInvalidSelectorStreamId;
calls[1] = abi.encodeCall(lockup.cancelMultiple, (streamIds));

calls[2] = abi.encodeCall(lockup.renounce, (recipientReentrantStreamId));

streamIds = new uint256[](1);
streamIds[0] = recipientRevertStreamId;
calls[3] = abi.encodeCall(lockup.renounceMultiple, (streamIds));

bytes[] memory results = lockup.batch{ value: 1 wei }(calls);

assertEq(results.length, 4);
assertEq(results[0], hex"");
assertEq(results[1], hex"");
assertEq(results[2], hex"");
assertEq(results[3], hex"");
smol-ninja marked this conversation as resolved.
Show resolved Hide resolved
assertEq(address(lockup).balance, initialEthBalance + 1 wei);
}

/// @dev The batch call includes:
/// - Payable functions
/// - All sender related functions with both returns and non-returns
function test_BatchPayable_SenderFunctions() external {
uint256 initialEthBalance = address(lockup).balance;
// Warp to the end time so that `burn` can be added to the call list.
vm.warp(defaults.END_TIME());

bytes[] memory calls = new bytes[](5);
// It should return nothing.
calls[0] = abi.encodeCall(lockup.withdraw, (defaultStreamId, users.recipient, 1));
// It should return the withdrawn amount.
calls[1] = abi.encodeCall(lockup.withdrawMax, (defaultStreamId, users.recipient));

uint256[] memory streamIds = Solarray.uint256s(notCancelableStreamId, notCancelableStreamId);
uint128[] memory amounts = Solarray.uint128s(1, 1);

// It should return nothing.
calls[2] = abi.encodeCall(lockup.withdrawMultiple, (streamIds, amounts));
// It should return the withdrawn amount.
calls[3] = abi.encodeCall(lockup.withdrawMaxAndTransfer, (notCancelableStreamId, users.recipient));
// It should return nothing.
calls[4] = abi.encodeCall(lockup.burn, (defaultStreamId));

resetPrank({ msgSender: users.recipient });
bytes[] memory results = lockup.batch{ value: 1 wei }(calls);

assertEq(results.length, 5, "batch results length");
assertEq(results[0], hex"", "batch results[0]");
assertEq(abi.decode(results[1], (uint128)), defaults.DEPOSIT_AMOUNT() - 1, "batch results[1]");
assertEq(results[2], hex"", "batch results[2]");
assertEq(abi.decode(results[3], (uint128)), defaults.DEPOSIT_AMOUNT() - 2, "batch results[3]");
assertEq(results[4], hex"", "batch results[4]");
assertEq(address(lockup).balance, initialEthBalance + 1 wei, "batch contract balance");
}
}
41 changes: 0 additions & 41 deletions tests/integration/concrete/lockup-base/batch/batch.t.sol

This file was deleted.

10 changes: 0 additions & 10 deletions tests/integration/concrete/lockup-base/batch/batch.tree

This file was deleted.

5 changes: 0 additions & 5 deletions tests/integration/concrete/lockup-base/cancel/cancel.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,6 @@ abstract contract Cancel_Integration_Concrete_Test is Integration_Test {
givenSTREAMINGStatus
givenRecipientAllowedToHook
{
// Allow the recipient to hook.
resetPrank({ msgSender: users.admin });
lockup.allowToHook(address(recipientReverting));
resetPrank({ msgSender: users.sender });

// It should revert.
vm.expectRevert("You shall not pass");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ contract RenounceMultiple_Integration_Concrete_Test is Integration_Test {
lockup.renounceMultiple(nullStreamIds);
}

function test_RevertGiven_AtleastOneNullStream() external whenNoDelegateCall whenNonZeroArrayLength {
function test_RevertGiven_AtLeastOneNullStream() external whenNoDelegateCall whenNonZeroArrayLength {
expectRevert_Null({
callData: abi.encodeCall(lockup.renounceMultiple, Solarray.uint256s(streamIds[0], nullStreamId))
});
}

function test_RevertGiven_AtleastOneColdStream()
function test_RevertGiven_AtLeastOneColdStream()
smol-ninja marked this conversation as resolved.
Show resolved Hide resolved
external
whenNoDelegateCall
whenNonZeroArrayLength
Expand Down Expand Up @@ -68,7 +68,7 @@ contract RenounceMultiple_Integration_Concrete_Test is Integration_Test {
lockup.renounceMultiple(streamIds);
}

function test_RevertGiven_AtleastOneNonCancelableStream()
function test_RevertGiven_AtLeastOneNonCancelableStream()
external
whenNoDelegateCall
whenNonZeroArrayLength
Expand Down
Loading
Loading