diff --git a/src/modules/invoice-module/InvoiceModule.sol b/src/modules/invoice-module/InvoiceModule.sol index b9bc6973..14a0f3cc 100644 --- a/src/modules/invoice-module/InvoiceModule.sol +++ b/src/modules/invoice-module/InvoiceModule.sol @@ -98,21 +98,26 @@ contract InvoiceModule is IInvoiceModule, StreamManager { } } - // Gets the number of payments for the invoice based on the payment method, interval and recurrence type + // Validates the invoice interval (endTime - startTime) and returns the number of payments of the invoice + // based on the payment method, interval and recurrence type // // Notes: + // - The number of payments is taken into account only for transfer-based invoices // - There should be only one payment when dealing with a one-off transfer-based invoice - // - When dealing with a recurring transfer or tranched stream, the number of payments must be calculated based + // - When dealing with a recurring transfer, the number of payments must be calculated based // on the payment interval (endTime - startTime) and recurrence type uint40 numberOfPayments; if (invoice.payment.method == Types.Method.Transfer && invoice.payment.recurrence == Types.Recurrence.OneOff) { numberOfPayments = 1; } else if (invoice.payment.method != Types.Method.LinearStream) { - numberOfPayments = _checkAndComputeNumberOfPayments({ + numberOfPayments = _checkIntervalPayments({ recurrence: invoice.payment.recurrence, startTime: invoice.startTime, endTime: invoice.endTime }); + + // Set the number of payments to zero if dealing with a tranched-based invoice + if (invoice.payment.method == Types.Method.TranchedStream) numberOfPayments = 0; } // Checks: the asset is different than the native token if dealing with either a linear or tranched stream-based invoice @@ -166,6 +171,11 @@ contract InvoiceModule is IInvoiceModule, StreamManager { // Load the invoice from storage Types.Invoice memory invoice = _invoices[id]; + // Checks: the invoice is not null + if (invoice.recipient == address(0)) { + revert Errors.InvoiceNull(); + } + // Checks: the invoice is not already paid or canceled if (invoice.status == Types.Status.Paid) { revert Errors.InvoiceAlreadyPaid(); @@ -178,14 +188,55 @@ contract InvoiceModule is IInvoiceModule, StreamManager { _payByTransfer(id, invoice); } else { uint256 streamId; - // Check to see wether to pay by creating a linear or tranched stream + // Check to see whether the invoice must be paid through a linear or tranched stream if (invoice.payment.method == Types.Method.LinearStream) { streamId = _payByLinearStream(invoice); } else streamId = _payByTranchedStream(invoice); + + // Effects: update the status of the invoice and stream ID + _invoices[id].status = Types.Status.Paid; + _invoices[id].payment.streamId = streamId; } // Log the payment transaction - emit InvoicePaid({ id: id, payer: msg.sender, status: invoice.status, payment: invoice.payment }); + emit InvoicePaid({ id: id, payer: msg.sender, status: _invoices[id].status, payment: _invoices[id].payment }); + } + + /// @inheritdoc IInvoiceModule + function cancelInvoice(uint256 id) external { + // Load the invoice from storage + Types.Invoice memory invoice = _invoices[id]; + + // Checks: the invoice is paid or already canceled + if (invoice.status == Types.Status.Paid) { + revert Errors.CannotCancelPaidInvoice(); + } else if (invoice.status == Types.Status.Canceled) { + revert Errors.CannotCancelCanceledInvoice(); + } + + // Checks: the `msg.sender` is the creator if dealing with a transfer-based invoice + // + // Notes: + // - for a linear or tranched stream-based invoice, the `msg.sender` is checked in the + // {SablierV2Lockup} `cancel` method + if (invoice.payment.method == Types.Method.Transfer) { + if (invoice.recipient != msg.sender) { + revert Errors.InvoiceOwnerUnauthorized(); + } + } + + // Effects: cancel the stream accordingly depending on its type + if (invoice.payment.method == Types.Method.LinearStream) { + cancelLinearStream({ streamId: invoice.payment.streamId }); + } else if (invoice.payment.method == Types.Method.TranchedStream) { + cancelTranchedStream({ streamId: invoice.payment.streamId }); + } + + // Effects: mark the invoice as canceled + _invoices[id].status = Types.Status.Canceled; + + // Log the invoice cancelation + emit InvoiceCanceled(id); } /*////////////////////////////////////////////////////////////////////////// @@ -219,7 +270,8 @@ contract InvoiceModule is IInvoiceModule, StreamManager { if (!success) revert Errors.NativeTokenPaymentFailed(); } else { // Interactions: pay the recipient with the ERC-20 token - IERC20(invoice.payment.asset).safeTransfer({ + IERC20(invoice.payment.asset).safeTransferFrom({ + from: msg.sender, to: address(invoice.recipient), value: invoice.payment.amount }); @@ -239,19 +291,24 @@ contract InvoiceModule is IInvoiceModule, StreamManager { /// @dev Create the tranched stream payment function _payByTranchedStream(Types.Invoice memory invoice) internal returns (uint256 streamId) { + uint40 numberOfTranches = Helpers.computeNumberOfPayments( + invoice.payment.recurrence, + invoice.endTime - invoice.startTime + ); + streamId = StreamManager.createTranchedStream({ asset: IERC20(invoice.payment.asset), totalAmount: invoice.payment.amount, startTime: invoice.startTime, recipient: invoice.recipient, - numberOfTranches: invoice.payment.paymentsLeft, + numberOfTranches: numberOfTranches, recurrence: invoice.payment.recurrence }); } /// @notice Calculates the number of payments to be made for a recurring transfer and tranched stream-based invoice /// @dev Reverts if the number of payments is zero, indicating that either the interval or recurrence type was set incorrectly - function _checkAndComputeNumberOfPayments( + function _checkIntervalPayments( Types.Recurrence recurrence, uint40 startTime, uint40 endTime diff --git a/src/modules/invoice-module/interfaces/IInvoiceModule.sol b/src/modules/invoice-module/interfaces/IInvoiceModule.sol index da661771..00e6c181 100644 --- a/src/modules/invoice-module/interfaces/IInvoiceModule.sol +++ b/src/modules/invoice-module/interfaces/IInvoiceModule.sol @@ -10,7 +10,7 @@ interface IInvoiceModule { EVENTS //////////////////////////////////////////////////////////////////////////*/ - /// @notice Emitted when a regular or recurring invoice is created + /// @notice Emitted when an invoice is created /// @param id The ID of the invoice /// @param recipient The address receiving the payment /// @param status The status of the invoice @@ -26,13 +26,17 @@ interface IInvoiceModule { Types.Payment payment ); - /// @notice Emitted when a regular or recurring invoice is paid + /// @notice Emitted when an invoice is paid /// @param id The ID of the invoice /// @param payer The address of the payer /// @param status The status of the invoice /// @param payment Struct representing the payment details associated with the invoice event InvoicePaid(uint256 indexed id, address indexed payer, Types.Status status, Types.Payment payment); + /// @notice Emitted when an invoice is canceled + /// @param id The ID of the invoice + event InvoiceCanceled(uint256 indexed id); + /*////////////////////////////////////////////////////////////////////////// CONSTANT FUNCTIONS //////////////////////////////////////////////////////////////////////////*/ @@ -64,4 +68,17 @@ interface IInvoiceModule { /// /// @param id The ID of the invoice to pay function payInvoice(uint256 id) external payable; + + /// @notice Cancels the `id` invoice + /// + /// Notes: + /// - if the invoice has a linear or tranched stream payment method, the streaming flow will be + /// stopped and the remaining funds will be refunded to the stream payer + /// + /// Important: + /// - if the invoice has a linear or tranched stream payment method, the portion that has already + /// been streamed is NOT automatically transferred + /// + /// @param id The ID of the invoice + function cancelInvoice(uint256 id) external; } diff --git a/src/modules/invoice-module/libraries/Errors.sol b/src/modules/invoice-module/libraries/Errors.sol index 3cefaa23..2275bf78 100644 --- a/src/modules/invoice-module/libraries/Errors.sol +++ b/src/modules/invoice-module/libraries/Errors.sol @@ -38,6 +38,12 @@ library Errors { /// @notice Thrown when a payer attempts to pay a canceled invoice error InvoiceCanceled(); + /// @notice Thrown when the invoice ID references a null invoice + error InvoiceNull(); + + /// @notice Thrown when `msg.sender` is not the creator (recipient) of the invoice + error InvoiceOwnerUnauthorized(); + /// @notice Thrown when the payment interval (endTime - startTime) is too short for the selected recurrence /// i.e. recurrence is set to weekly but interval is shorter than 1 week error PaymentIntervalTooShortForSelectedRecurrence(); @@ -45,6 +51,12 @@ library Errors { /// @notice Thrown when a tranched stream has a one-off recurrence type error TranchedStreamInvalidOneOffRecurence(); + /// @notice Thrown when an attempt is made to cancel an already paid invoice + error CannotCancelPaidInvoice(); + + /// @notice Thrown when an attempt is made to cancel an already canceled invoice + error CannotCancelCanceledInvoice(); + /*////////////////////////////////////////////////////////////////////////// STREAM-MANAGER //////////////////////////////////////////////////////////////////////////*/ diff --git a/src/modules/invoice-module/sablier-v2/StreamManager.sol b/src/modules/invoice-module/sablier-v2/StreamManager.sol index af8709d0..7dc908b7 100644 --- a/src/modules/invoice-module/sablier-v2/StreamManager.sol +++ b/src/modules/invoice-module/sablier-v2/StreamManager.sol @@ -7,6 +7,7 @@ import { ISablierV2Lockup } from "@sablier/v2-core/src/interfaces/ISablierV2Lock import { LockupLinear, LockupTranched } from "@sablier/v2-core/src/types/DataTypes.sol"; import { Broker, LockupLinear } from "@sablier/v2-core/src/types/DataTypes.sol"; import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import { SafeERC20 } from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; import { ud60x18, UD60x18 } from "@prb/math/src/UD60x18.sol"; import { IStreamManager } from "./interfaces/IStreamManager.sol"; @@ -16,6 +17,8 @@ import { Types } from "./../libraries/Types.sol"; /// @title StreamManager /// @dev See the documentation in {IStreamManager} contract StreamManager is IStreamManager { + using SafeERC20 for IERC20; + /*////////////////////////////////////////////////////////////////////////// PUBLIC STORAGE //////////////////////////////////////////////////////////////////////////*/ @@ -71,7 +74,7 @@ contract StreamManager is IStreamManager { address recipient ) public returns (uint256 streamId) { // Transfer the provided amount of ERC-20 tokens to this contract and approve the Sablier contract to spend it - _transferFromAndApprove({ asset: asset, spender: address(LOCKUP_LINEAR), amount: totalAmount }); + _transferFromAndApprove({ asset: asset, amount: totalAmount, spender: address(LOCKUP_LINEAR) }); // Create the Lockup Linear stream streamId = _createLinearStream(asset, totalAmount, startTime, endTime, recipient); @@ -87,14 +90,14 @@ contract StreamManager is IStreamManager { Types.Recurrence recurrence ) public returns (uint256 streamId) { // Transfer the provided amount of ERC-20 tokens to this contract and approve the Sablier contract to spend it - _transferFromAndApprove({ asset: asset, spender: address(LOCKUP_TRANCHED), amount: totalAmount }); + _transferFromAndApprove({ asset: asset, amount: totalAmount, spender: address(LOCKUP_TRANCHED) }); // Create the Lockup Linear stream streamId = _createTranchedStream(asset, totalAmount, startTime, recipient, numberOfTranches, recurrence); } /// @inheritdoc IStreamManager - function updateBrokerFee(UD60x18 newBrokerFee) public onlyBrokerAdmin { + function updateStreamBrokerFee(UD60x18 newBrokerFee) public onlyBrokerAdmin { // Log the broker fee update emit BrokerFeeUpdated({ oldFee: brokerFee, newFee: newBrokerFee }); @@ -107,34 +110,13 @@ contract StreamManager is IStreamManager { //////////////////////////////////////////////////////////////////////////*/ /// @inheritdoc IStreamManager - function withdraw(ISablierV2Lockup sablier, uint256 streamId, address to, uint128 amount) external { - sablier.withdraw(streamId, to, amount); - } - - /// @inheritdoc IStreamManager - function withdrawableAmountOf( - ISablierV2Lockup sablier, - uint256 streamId - ) external view returns (uint128 withdrawableAmount) { - withdrawableAmount = sablier.withdrawableAmountOf(streamId); + function withdrawLinearStream(uint256 streamId, address to, uint128 amount) public { + _withdrawStream({ sablier: LOCKUP_LINEAR, streamId: streamId, to: to, amount: amount }); } /// @inheritdoc IStreamManager - function withdrawMax( - ISablierV2Lockup sablier, - uint256 streamId, - address to - ) external returns (uint128 withdrawnAmount) { - withdrawnAmount = sablier.withdrawMax(streamId, to); - } - - /// @inheritdoc IStreamManager - function withdrawMultiple( - ISablierV2Lockup sablier, - uint256[] calldata streamIds, - uint128[] calldata amounts - ) external { - sablier.withdrawMultiple(streamIds, amounts); + function withdrawTranchedStream(uint256 streamId, address to, uint128 amount) public { + _withdrawStream({ sablier: LOCKUP_TRANCHED, streamId: streamId, to: to, amount: amount }); } /*////////////////////////////////////////////////////////////////////////// @@ -142,35 +124,27 @@ contract StreamManager is IStreamManager { //////////////////////////////////////////////////////////////////////////*/ /// @inheritdoc IStreamManager - function cancel(ISablierV2Lockup sablier, uint256 streamId) external { - sablier.cancel(streamId); + function cancelLinearStream(uint256 streamId) public { + _cancelStream({ sablier: LOCKUP_LINEAR, streamId: streamId }); } /// @inheritdoc IStreamManager - function cancelMultiple(ISablierV2Lockup sablier, uint256[] calldata streamIds) external { - sablier.cancelMultiple(streamIds); + function cancelTranchedStream(uint256 streamId) public { + _cancelStream({ sablier: LOCKUP_TRANCHED, streamId: streamId }); } /*////////////////////////////////////////////////////////////////////////// - RENOUNCE FUNCTIONS + CONSTANT FUNCTIONS //////////////////////////////////////////////////////////////////////////*/ /// @inheritdoc IStreamManager - function renounce(ISablierV2Lockup sablier, uint256 streamId) external { - sablier.renounce(streamId); + function getLinearStream(uint256 streamId) public view returns (LockupLinear.StreamLL memory stream) { + stream = LOCKUP_LINEAR.getStream(streamId); } - /*////////////////////////////////////////////////////////////////////////// - TRANSFER FUNCTIONS - //////////////////////////////////////////////////////////////////////////*/ - /// @inheritdoc IStreamManager - function withdrawMaxAndTransfer( - ISablierV2Lockup sablier, - uint256 streamId, - address newRecipient - ) external returns (uint128 withdrawnAmount) { - withdrawnAmount = sablier.withdrawMaxAndTransfer(streamId, newRecipient); + function getTranchedStream(uint256 streamId) public view returns (LockupTranched.StreamLT memory stream) { + stream = LOCKUP_TRANCHED.getStream(streamId); } /*////////////////////////////////////////////////////////////////////////// @@ -225,6 +199,7 @@ contract StreamManager is IStreamManager { params.asset = asset; // The streaming asset params.cancelable = true; // Whether the stream will be cancelable or not params.transferable = true; // Whether the stream will be transferable or not + params.startTime = startTime; // The timestamp when to start streaming // Calculate the duration of each tranche based on the payment recurrence uint40 durationPerTranche = _computeDurationPerTrache(recurrence); @@ -251,14 +226,27 @@ contract StreamManager is IStreamManager { streamId = LOCKUP_TRANCHED.createWithTimestamps(params); } + /// @dev Withdraws from either a linear or tranched stream + function _withdrawStream(ISablierV2Lockup sablier, uint256 streamId, address to, uint128 amount) internal { + sablier.withdraw(streamId, to, amount); + } + + /// @dev Cancels the `streamId` stream + function _cancelStream(ISablierV2Lockup sablier, uint256 streamId) internal { + sablier.cancel(streamId); + } + + /// @dev Transfers the `amount` of `asset` tokens to this address (or the contract inherting from) + /// and approves either the `SablierV2LockupLinear` or `SablierV2LockupTranched` to spend the amount function _transferFromAndApprove(IERC20 asset, uint128 amount, address spender) internal { // Transfer the provided amount of ERC-20 tokens to this contract - asset.transferFrom(msg.sender, address(this), amount); + IERC20(asset).safeTransferFrom(msg.sender, address(this), amount); // Approve the Sablier contract to spend the ERC-20 tokens asset.approve(spender, amount); } + /// @dev Calculates the duration of each tranches from a tranched stream based on a recurrence function _computeDurationPerTrache(Types.Recurrence recurrence) internal pure returns (uint40 duration) { if (recurrence == Types.Recurrence.Weekly) duration = 1 weeks; else if (recurrence == Types.Recurrence.Monthly) duration = 4 weeks; diff --git a/src/modules/invoice-module/sablier-v2/interfaces/IStreamManager.sol b/src/modules/invoice-module/sablier-v2/interfaces/IStreamManager.sol index ffa3079d..92022638 100644 --- a/src/modules/invoice-module/sablier-v2/interfaces/IStreamManager.sol +++ b/src/modules/invoice-module/sablier-v2/interfaces/IStreamManager.sol @@ -3,6 +3,7 @@ pragma solidity >=0.8.22; import { ISablierV2LockupLinear } from "@sablier/v2-core/src/interfaces/ISablierV2LockupLinear.sol"; import { ISablierV2LockupTranched } from "@sablier/v2-core/src/interfaces/ISablierV2LockupTranched.sol"; +import { LockupLinear, LockupTranched } from "@sablier/v2-core/src/types/DataTypes.sol"; import { ISablierV2Lockup } from "@sablier/v2-core/src/interfaces/ISablierV2Lockup.sol"; import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; import { UD60x18 } from "@prb/math/src/UD60x18.sol"; @@ -42,6 +43,14 @@ interface IStreamManager { /// @dev See the `UD60x18` type definition in the `@prb/math/src/ud60x18/ValueType.sol file` function brokerFee() external view returns (UD60x18); + /// @notice Retrieves a linear stream details according to the {LockupLinear.StreamLL} struct + /// @param streamId The ID of the stream to be retrieved + function getLinearStream(uint256 streamId) external view returns (LockupLinear.StreamLL memory stream); + + /// @notice Retrieves a tranched stream details according to the {LockupTranched.StreamLT} struct + /// @param streamId The ID of the stream to be retrieved + function getTranchedStream(uint256 streamId) external view returns (LockupTranched.StreamLT memory stream); + /*////////////////////////////////////////////////////////////////////////// NON-CONSTANT FUNCTIONS //////////////////////////////////////////////////////////////////////////*/ @@ -82,44 +91,17 @@ interface IStreamManager { /// - The new fee will be applied only to the new streams hence it can't be retrospectively updated /// /// @param newBrokerFee The new broker fee - function updateBrokerFee(UD60x18 newBrokerFee) external; - - /// @notice See the documentation in {ISablierV2Lockup} - function withdraw(ISablierV2Lockup sablier, uint256 streamId, address to, uint128 amount) external; - - /// @notice See the documentation in {ISablierV2Lockup} - function withdrawableAmountOf( - ISablierV2Lockup sablier, - uint256 streamId - ) external view returns (uint128 withdrawableAmount); - - /// @notice See the documentation in {ISablierV2Lockup} - function withdrawMax( - ISablierV2Lockup sablier, - uint256 streamId, - address to - ) external returns (uint128 withdrawnAmount); - - /// @notice See the documentation in {ISablierV2Lockup} - function withdrawMultiple( - ISablierV2Lockup sablier, - uint256[] calldata streamIds, - uint128[] calldata amounts - ) external; - - /// @notice See the documentation in {ISablierV2Lockup} - function withdrawMaxAndTransfer( - ISablierV2Lockup sablier, - uint256 streamId, - address newRecipient - ) external returns (uint128 withdrawnAmount); - - /// @notice See the documentation in {ISablierV2Lockup} - function cancel(ISablierV2Lockup sablier, uint256 streamId) external; - - /// @notice See the documentation in {ISablierV2Lockup} - function cancelMultiple(ISablierV2Lockup sablier, uint256[] calldata streamIds) external; - - /// @notice See the documentation in {ISablierV2Lockup} - function renounce(ISablierV2Lockup sablier, uint256 streamId) external; + function updateStreamBrokerFee(UD60x18 newBrokerFee) external; + + /// @notice See the documentation in {ISablierV2Lockup-withdraw} + function withdrawLinearStream(uint256 streamId, address to, uint128 amount) external; + + /// @notice See the documentation in {ISablierV2Lockup-withdraw} + function withdrawTranchedStream(uint256 streamId, address to, uint128 amount) external; + + /// @notice See the documentation in {ISablierV2Lockup-cancel} + function cancelLinearStream(uint256 streamId) external; + + /// @notice See the documentation in {ISablierV2Lockup-cancel} + function cancelTranchedStream(uint256 streamId) external; } diff --git a/test/Base.t.sol b/test/Base.t.sol index a0731a41..6d1fe05e 100644 --- a/test/Base.t.sol +++ b/test/Base.t.sol @@ -7,6 +7,7 @@ import { Test } from "forge-std/Test.sol"; import { MockERC20NoReturn } from "./mocks/MockERC20NoReturn.sol"; import { MockNonCompliantContainer } from "./mocks/MockNonCompliantContainer.sol"; import { MockModule } from "./mocks/MockModule.sol"; +import { MockBadReceiver } from "./mocks/MockBadReceiver.sol"; import { Container } from "./../src/Container.sol"; abstract contract Base_Test is Test, Events { @@ -24,6 +25,7 @@ abstract contract Base_Test is Test, Events { MockERC20NoReturn internal usdt; MockModule internal mockModule; MockNonCompliantContainer internal mockNonCompliantContainer; + MockBadReceiver internal mockBadReceiver; /*////////////////////////////////////////////////////////////////////////// SET-UP FUNCTION @@ -39,6 +41,7 @@ abstract contract Base_Test is Test, Events { // Deploy test contracts mockModule = new MockModule(); mockNonCompliantContainer = new MockNonCompliantContainer({ _owner: users.admin }); + mockBadReceiver = new MockBadReceiver(); // Label the test contracts so we can easily track them vm.label({ account: address(usdt), newLabel: "USDT" }); @@ -63,7 +66,7 @@ abstract contract Base_Test is Test, Events { function createUser(string memory name) internal returns (address payable) { address payable user = payable(makeAddr(name)); vm.deal({ account: user, newBalance: 100 ether }); - deal({ token: address(usdt), to: user, give: 1000000e6 }); + deal({ token: address(usdt), to: user, give: 10000000e18 }); return user; } diff --git a/test/integration/concrete/invoice-module/create-invoice/createInvoice.t.sol b/test/integration/concrete/invoice-module/create-invoice/createInvoice.t.sol index 01a102c2..7c803321 100644 --- a/test/integration/concrete/invoice-module/create-invoice/createInvoice.t.sol +++ b/test/integration/concrete/invoice-module/create-invoice/createInvoice.t.sol @@ -1,14 +1,16 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.26; -import { InvoiceModule_Integration_Shared_Test } from "../../../shared/InvoiceModule.t.sol"; +import { CreateInvoice_Integration_Shared_Test } from "../../../shared/createInvoice.t.sol"; import { Types } from "./../../../../../src/modules/invoice-module/libraries/Types.sol"; import { Errors } from "../../../../utils/Errors.sol"; import { Events } from "../../../../utils/Events.sol"; -contract CreateInvoice_Integration_Concret_Test is InvoiceModule_Integration_Shared_Test { +contract CreateInvoice_Integration_Concret_Test is CreateInvoice_Integration_Shared_Test { + Types.Invoice invoice; + function setUp() public virtual override { - InvoiceModule_Integration_Shared_Test.setUp(); + CreateInvoice_Integration_Shared_Test.setUp(); } function test_RevertWhen_CallerNotContract() external { @@ -19,7 +21,7 @@ contract CreateInvoice_Integration_Concret_Test is InvoiceModule_Integration_Sha vm.expectRevert(Errors.ContainerZeroCodeSize.selector); // Create an one-off transfer invoice - createInvoiceWithOneOffTransfer(); + invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt), recipient: users.eve }); // Run the test invoiceModule.createInvoice(invoice); @@ -30,7 +32,7 @@ contract CreateInvoice_Integration_Concret_Test is InvoiceModule_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create an one-off transfer invoice - createInvoiceWithOneOffTransfer(); + invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt), recipient: users.eve }); // Create the calldata for the Invoice Module execution bytes memory data = abi.encodeWithSignature( @@ -50,7 +52,7 @@ contract CreateInvoice_Integration_Concret_Test is InvoiceModule_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create an one-off transfer invoice - createInvoiceWithOneOffTransfer(); + invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt), recipient: users.eve }); // Set the payment amount to zero to simulate the error invoice.payment.amount = 0; @@ -78,7 +80,7 @@ contract CreateInvoice_Integration_Concret_Test is InvoiceModule_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create an one-off transfer invoice - createInvoiceWithOneOffTransfer(); + invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt), recipient: users.eve }); // Set the start time to be the current timestamp and the end time one second earlier invoice.startTime = uint40(block.timestamp); @@ -108,7 +110,7 @@ contract CreateInvoice_Integration_Concret_Test is InvoiceModule_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create an one-off transfer invoice - createInvoiceWithOneOffTransfer(); + invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt), recipient: users.eve }); // Set the block.timestamp to 1641070800 vm.warp(1641070800); @@ -145,7 +147,7 @@ contract CreateInvoice_Integration_Concret_Test is InvoiceModule_Integration_Sha // Create a recurring transfer invoice that must be paid on a monthly basis // Hence, the interval between the start and end time must be at least 1 month - createInvoiceWithOneOffTransfer(); + invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt), recipient: users.eve }); // Create the calldata for the Invoice Module execution bytes memory data = abi.encodeWithSignature( @@ -199,7 +201,7 @@ contract CreateInvoice_Integration_Concret_Test is InvoiceModule_Integration_Sha // Create a recurring transfer invoice that must be paid on a monthly basis // Hence, the interval between the start and end time must be at least 1 month - createInvoiceWithRecurringTransfer({ recurrence: Types.Recurrence.Monthly }); + invoice = createInvoiceWithRecurringTransfer({ recurrence: Types.Recurrence.Monthly, recipient: users.eve }); // Alter the end time to be 3 weeks from now invoice.endTime = uint40(block.timestamp) + 3 weeks; @@ -231,7 +233,7 @@ contract CreateInvoice_Integration_Concret_Test is InvoiceModule_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create a recurring transfer invoice that must be paid on weekly basis - createInvoiceWithRecurringTransfer({ recurrence: Types.Recurrence.Weekly }); + invoice = createInvoiceWithRecurringTransfer({ recurrence: Types.Recurrence.Weekly, recipient: users.eve }); // Create the calldata for the Invoice Module execution bytes memory data = abi.encodeWithSignature( @@ -284,7 +286,7 @@ contract CreateInvoice_Integration_Concret_Test is InvoiceModule_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create a new invoice with a tranched stream payment - createInvoiceWithTranchedStream({ recurrence: Types.Recurrence.Weekly }); + invoice = createInvoiceWithTranchedStream({ recurrence: Types.Recurrence.Weekly, recipient: users.eve }); // Alter the payment recurrence by setting it to one-off invoice.payment.recurrence = Types.Recurrence.OneOff; @@ -316,7 +318,7 @@ contract CreateInvoice_Integration_Concret_Test is InvoiceModule_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create a new invoice with a tranched stream payment - createInvoiceWithTranchedStream({ recurrence: Types.Recurrence.Monthly }); + invoice = createInvoiceWithTranchedStream({ recurrence: Types.Recurrence.Monthly, recipient: users.eve }); // Alter the end time to be 3 weeks from now invoice.endTime = uint40(block.timestamp) + 3 weeks; @@ -349,7 +351,7 @@ contract CreateInvoice_Integration_Concret_Test is InvoiceModule_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create a new invoice with a linear stream payment - createInvoiceWithTranchedStream({ recurrence: Types.Recurrence.Weekly }); + invoice = createInvoiceWithTranchedStream({ recurrence: Types.Recurrence.Weekly, recipient: users.eve }); // Alter the payment asset by setting it to invoice.payment.asset = address(0); @@ -381,7 +383,7 @@ contract CreateInvoice_Integration_Concret_Test is InvoiceModule_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create a new invoice with a tranched stream payment - createInvoiceWithTranchedStream({ recurrence: Types.Recurrence.Weekly }); + invoice = createInvoiceWithTranchedStream({ recurrence: Types.Recurrence.Weekly, recipient: users.eve }); // Create the calldata for the Invoice Module execution bytes memory data = abi.encodeWithSignature( @@ -415,7 +417,7 @@ contract CreateInvoice_Integration_Concret_Test is InvoiceModule_Integration_Sha assertEq(actualInvoice.endTime, invoice.endTime); assertEq(uint8(actualInvoice.payment.method), uint8(Types.Method.TranchedStream)); assertEq(uint8(actualInvoice.payment.recurrence), uint8(Types.Recurrence.Weekly)); - assertEq(actualInvoice.payment.paymentsLeft, 4); + assertEq(actualInvoice.payment.paymentsLeft, 0); assertEq(actualInvoice.payment.asset, invoice.payment.asset); assertEq(actualInvoice.payment.amount, invoice.payment.amount); assertEq(actualInvoice.payment.streamId, 0); @@ -434,7 +436,7 @@ contract CreateInvoice_Integration_Concret_Test is InvoiceModule_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create a new invoice with a linear stream payment - createInvoiceWithLinearStream(); + invoice = createInvoiceWithLinearStream({ recipient: users.eve }); // Alter the payment asset by setting it to invoice.payment.asset = address(0); @@ -466,7 +468,7 @@ contract CreateInvoice_Integration_Concret_Test is InvoiceModule_Integration_Sha vm.startPrank({ msgSender: users.eve }); // Create a new invoice with a linear stream payment - createInvoiceWithLinearStream(); + invoice = createInvoiceWithLinearStream({ recipient: users.eve }); // Create the calldata for the Invoice Module execution bytes memory data = abi.encodeWithSignature( diff --git a/test/integration/concrete/invoice-module/pay-invoice/payInvoice.t.sol b/test/integration/concrete/invoice-module/pay-invoice/payInvoice.t.sol new file mode 100644 index 00000000..c672278c --- /dev/null +++ b/test/integration/concrete/invoice-module/pay-invoice/payInvoice.t.sol @@ -0,0 +1,351 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.26; + +import { PayInvoice_Integration_Shared_Test } from "../../../shared/payInvoice.t.sol"; +import { Types } from "./../../../../../src/modules/invoice-module/libraries/Types.sol"; +import { Events } from "../../../../utils/Events.sol"; +import { Errors } from "../../../../utils/Errors.sol"; + +import { LockupLinear, LockupTranched } from "@sablier/v2-core/src/types/DataTypes.sol"; + +contract PayInvoice_Integration_Concret_Test is PayInvoice_Integration_Shared_Test { + function setUp() public virtual override { + PayInvoice_Integration_Shared_Test.setUp(); + + // Create a mock invoice with a one-off USDT transfer + Types.Invoice memory invoice = createInvoiceWithOneOffTransfer({ asset: address(usdt), recipient: users.eve }); + invoices[0] = invoice; + executeCreateInvoice({ invoice: invoice, user: users.eve }); + + // Create a mock invoice with a one-off ETH transfer + invoice = createInvoiceWithOneOffTransfer({ asset: address(0), recipient: users.eve }); + invoices[1] = invoice; + executeCreateInvoice({ invoice: invoice, user: users.eve }); + + // Create a mock invoice with a recurring USDT transfer + invoice = createInvoiceWithRecurringTransfer({ recurrence: Types.Recurrence.Weekly, recipient: users.eve }); + invoices[2] = invoice; + executeCreateInvoice({ invoice: invoice, user: users.eve }); + + // Create a mock invoice with a linear stream payment + invoice = createInvoiceWithLinearStream({ recipient: users.eve }); + invoices[3] = invoice; + executeCreateInvoice({ invoice: invoice, user: users.eve }); + + // Create a mock invoice with a tranched stream payment + invoice = createInvoiceWithTranchedStream({ recurrence: Types.Recurrence.Weekly, recipient: users.eve }); + invoices[4] = invoice; + executeCreateInvoice({ invoice: invoice, user: users.eve }); + } + + function test_RevertWhen_InvoiceNull() external { + // Expect the call to revert with the {InvoiceNull} error + vm.expectRevert(Errors.InvoiceNull.selector); + + // Run the test + invoiceModule.payInvoice({ id: 99 }); + } + + function test_RevertWhen_InvoiceAlreadyPaid() external whenInvoiceNotNull { + // Set the one-off USDT transfer invoice as current one + uint256 invoiceId = 0; + + // Make Bob the payer for the default invoice + vm.startPrank({ msgSender: users.bob }); + + // Approve the {InvoiceModule} to transfer the ERC-20 token on Bob's behalf + usdt.approve({ spender: address(invoiceModule), amount: invoices[invoiceId].payment.amount }); + + // Pay first the invoice + invoiceModule.payInvoice({ id: invoiceId }); + + // Expect the call to be reverted with the {InvoiceAlreadyPaid} error + vm.expectRevert(Errors.InvoiceAlreadyPaid.selector); + + // Run the test + invoiceModule.payInvoice({ id: invoiceId }); + } + + function test_RevertWhen_InvoiceCanceled() external whenInvoiceNotNull whenInvoiceNotAlreadyPaid { + // Set the one-off USDT transfer invoice as current one + uint256 invoiceId = 0; + + // Make Eve the caller in this test suite as she's the owner of the {Container} contract + vm.startPrank({ msgSender: users.eve }); + + // Cancel the invoice first + invoiceModule.cancelInvoice({ id: invoiceId }); + + // Make Bob the payer of this invoice + vm.startPrank({ msgSender: users.bob }); + + // Expect the call to be reverted with the {InvoiceCanceled} error + vm.expectRevert(Errors.InvoiceCanceled.selector); + + // Run the test + invoiceModule.payInvoice({ id: invoiceId }); + } + + function test_RevertWhen_PaymentMethodTransfer_PaymentAmountLessThanInvoiceValue() + external + whenInvoiceNotNull + whenInvoiceNotAlreadyPaid + whenInvoiceNotCanceled + givenPaymentMethodTransfer + givenPaymentAmountInNativeToken + { + // Set the one-off ETH transfer invoice as current one + uint256 invoiceId = 1; + + // Make Bob the payer for the default invoice + vm.startPrank({ msgSender: users.bob }); + + // Expect the call to be reverted with the {PaymentAmountLessThanInvoiceValue} error + vm.expectRevert( + abi.encodeWithSelector( + Errors.PaymentAmountLessThanInvoiceValue.selector, + invoices[invoiceId].payment.amount + ) + ); + + // Run the test + invoiceModule.payInvoice{ value: invoices[invoiceId].payment.amount - 1 }({ id: invoiceId }); + } + + function test_RevertWhen_PaymentMethodTransfer_NativeTokenTransferFails() + external + whenInvoiceNotNull + whenInvoiceNotAlreadyPaid + whenInvoiceNotCanceled + givenPaymentMethodTransfer + givenPaymentAmountInNativeToken + whenPaymentAmountEqualToInvoiceValue + { + // Create a mock invoice with a one-off ETH transfer and set {MockBadReceiver} as the recipient + Types.Invoice memory invoice = createInvoiceWithOneOffTransfer({ + asset: address(0), + recipient: address(mockBadReceiver) + }); + executeCreateInvoice({ invoice: invoice, user: users.eve }); + + // Make {MockBadReceiver} the payer for this invoice + vm.startPrank({ msgSender: users.bob }); + + // Expect the call to be reverted with the {NativeTokenPaymentFailed} error + vm.expectRevert(Errors.NativeTokenPaymentFailed.selector); + + // Run the test + invoiceModule.payInvoice{ value: invoice.payment.amount }({ id: 5 }); + } + + function test_PayInvoice_PaymentMethodTransfer_NativeToken_OneOff() + external + whenInvoiceNotNull + whenInvoiceNotAlreadyPaid + whenInvoiceNotCanceled + givenPaymentMethodTransfer + givenPaymentAmountInNativeToken + whenPaymentAmountEqualToInvoiceValue + whenNativeTokenPaymentSucceeds + { + // Set the one-off ETH transfer invoice as current one + uint256 invoiceId = 1; + + // Make Bob the payer for the default invoice + vm.startPrank({ msgSender: users.bob }); + + // Store the ETH balances of Bob and recipient before paying the invoice + uint256 balanceOfBobBefore = address(users.bob).balance; + uint256 balanceOfRecipientBefore = address(invoices[invoiceId].recipient).balance; + + // Expect the {InvoicePaid} event to be emitted + vm.expectEmit(); + emit Events.InvoicePaid({ + id: invoiceId, + payer: users.bob, + status: Types.Status.Paid, + payment: Types.Payment({ + method: invoices[invoiceId].payment.method, + recurrence: invoices[invoiceId].payment.recurrence, + paymentsLeft: 0, + asset: invoices[invoiceId].payment.asset, + amount: invoices[invoiceId].payment.amount, + streamId: 0 + }) + }); + + // Run the test + invoiceModule.payInvoice{ value: invoices[invoiceId].payment.amount }({ id: invoiceId }); + + // Assert the actual and the expected state of the invoice + Types.Invoice memory invoice = invoiceModule.getInvoice({ id: invoiceId }); + assertEq(uint8(invoice.status), uint8(Types.Status.Paid)); + assertEq(invoice.payment.paymentsLeft, 0); + + // Assert the balances of payer and recipient + assertEq(address(users.bob).balance, balanceOfBobBefore - invoices[invoiceId].payment.amount); + assertEq( + address(invoices[invoiceId].recipient).balance, + balanceOfRecipientBefore + invoices[invoiceId].payment.amount + ); + } + + function test_PayInvoice_PaymentMethodTransfer_ERC20Token_Recurring() + external + whenInvoiceNotNull + whenInvoiceNotAlreadyPaid + whenInvoiceNotCanceled + givenPaymentMethodTransfer + givenPaymentAmountInERC20Tokens + whenPaymentAmountEqualToInvoiceValue + { + // Set the recurring USDT transfer invoice as current one + uint256 invoiceId = 2; + + // Make Bob the payer for the default invoice + vm.startPrank({ msgSender: users.bob }); + + // Store the USDT balances of Bob and recipient before paying the invoice + uint256 balanceOfBobBefore = usdt.balanceOf(users.bob); + uint256 balanceOfRecipientBefore = usdt.balanceOf(invoices[invoiceId].recipient); + + // Approve the {InvoiceModule} to transfer the ERC-20 tokens on Bob's behalf + usdt.approve({ spender: address(invoiceModule), amount: invoices[invoiceId].payment.amount }); + + // Expect the {InvoicePaid} event to be emitted + vm.expectEmit(); + emit Events.InvoicePaid({ + id: invoiceId, + payer: users.bob, + status: Types.Status.Ongoing, + payment: Types.Payment({ + method: invoices[invoiceId].payment.method, + recurrence: invoices[invoiceId].payment.recurrence, + paymentsLeft: 3, + asset: invoices[invoiceId].payment.asset, + amount: invoices[invoiceId].payment.amount, + streamId: 0 + }) + }); + + // Run the test + invoiceModule.payInvoice{ value: invoices[invoiceId].payment.amount }({ id: invoiceId }); + + // Assert the actual and the expected state of the invoice + Types.Invoice memory invoice = invoiceModule.getInvoice({ id: invoiceId }); + assertEq(uint8(invoice.status), uint8(Types.Status.Ongoing)); + assertEq(invoice.payment.paymentsLeft, 3); + + // Assert the balances of payer and recipient + assertEq(usdt.balanceOf(users.bob), balanceOfBobBefore - invoices[invoiceId].payment.amount); + assertEq( + usdt.balanceOf(invoices[invoiceId].recipient), + balanceOfRecipientBefore + invoices[invoiceId].payment.amount + ); + } + + function test_PayInvoice_PaymentMethodLinearStream() + external + whenInvoiceNotNull + whenInvoiceNotAlreadyPaid + whenInvoiceNotCanceled + givenPaymentMethodLinearStream + givenPaymentAmountInERC20Tokens + whenPaymentAmountEqualToInvoiceValue + { + // Set the linear USDT stream-based invoice as current one + uint256 invoiceId = 3; + + // Make Bob the payer for the default invoice + vm.startPrank({ msgSender: users.bob }); + + // Approve the {InvoiceModule} to transfer the ERC-20 tokens on Bob's behalf + usdt.approve({ spender: address(invoiceModule), amount: invoices[invoiceId].payment.amount }); + + // Expect the {InvoicePaid} event to be emitted + vm.expectEmit(); + emit Events.InvoicePaid({ + id: invoiceId, + payer: users.bob, + status: Types.Status.Paid, + payment: Types.Payment({ + method: invoices[invoiceId].payment.method, + recurrence: invoices[invoiceId].payment.recurrence, + paymentsLeft: 0, + asset: invoices[invoiceId].payment.asset, + amount: invoices[invoiceId].payment.amount, + streamId: 1 + }) + }); + + // Run the test + invoiceModule.payInvoice{ value: invoices[invoiceId].payment.amount }({ id: invoiceId }); + + // Assert the actual and the expected state of the invoice + Types.Invoice memory invoice = invoiceModule.getInvoice({ id: invoiceId }); + assertEq(uint8(invoice.status), uint8(Types.Status.Paid)); + assertEq(invoice.payment.streamId, 1); + assertEq(invoice.payment.paymentsLeft, 0); + + // Assert the actual and the expected state of the Sablier v2 linear stream + LockupLinear.StreamLL memory stream = invoiceModule.getLinearStream({ streamId: 1 }); + assertEq(stream.sender, users.bob); + assertEq(stream.recipient, users.eve); + assertEq(address(stream.asset), address(usdt)); + assertEq(stream.startTime, invoice.startTime); + assertEq(stream.endTime, invoice.endTime); + } + + function test_PayInvoice_PaymentMethodTranchedStream() + external + whenInvoiceNotNull + whenInvoiceNotAlreadyPaid + whenInvoiceNotCanceled + givenPaymentMethodTranchedStream + givenPaymentAmountInERC20Tokens + whenPaymentAmountEqualToInvoiceValue + { + // Set the tranched USDT stream-based invoice as current one + uint256 invoiceId = 4; + + // Make Bob the payer for the default invoice + vm.startPrank({ msgSender: users.bob }); + + // Approve the {InvoiceModule} to transfer the ERC-20 tokens on Bob's behalf + usdt.approve({ spender: address(invoiceModule), amount: invoices[invoiceId].payment.amount }); + + // Expect the {InvoicePaid} event to be emitted + vm.expectEmit(); + emit Events.InvoicePaid({ + id: invoiceId, + payer: users.bob, + status: Types.Status.Paid, + payment: Types.Payment({ + method: invoices[invoiceId].payment.method, + recurrence: invoices[invoiceId].payment.recurrence, + paymentsLeft: 0, + asset: invoices[invoiceId].payment.asset, + amount: invoices[invoiceId].payment.amount, + streamId: 1 + }) + }); + + // Run the test + invoiceModule.payInvoice{ value: invoices[invoiceId].payment.amount }({ id: invoiceId }); + + // Assert the actual and the expected state of the invoice + Types.Invoice memory invoice = invoiceModule.getInvoice({ id: invoiceId }); + assertEq(uint8(invoice.status), uint8(Types.Status.Paid)); + assertEq(invoice.payment.streamId, 1); + assertEq(invoice.payment.paymentsLeft, 0); + + // Assert the actual and the expected state of the Sablier v2 tranched stream + LockupTranched.StreamLT memory stream = invoiceModule.getTranchedStream({ streamId: 1 }); + assertEq(stream.sender, users.bob); + assertEq(stream.recipient, users.eve); + assertEq(address(stream.asset), address(usdt)); + assertEq(stream.startTime, invoice.startTime); + assertEq(stream.endTime, invoice.endTime); + assertEq(stream.tranches.length, 4); + } +} diff --git a/test/integration/concrete/invoice-module/pay-invoice/payInvoice.tree b/test/integration/concrete/invoice-module/pay-invoice/payInvoice.tree new file mode 100644 index 00000000..c42e37e3 --- /dev/null +++ b/test/integration/concrete/invoice-module/pay-invoice/payInvoice.tree @@ -0,0 +1,38 @@ +payInvoice.t.sol +├── when the invoice IS null +│ └── it should revert with the {InvoiceNull} error +└── when the invoice IS NOT null + ├── when the invoice IS already paid + │ └── it should revert with the {InvoiceAlreadyPaid} error + └── when the invoice IS NOT already paid + ├── when the invoice IS canceled + │ └── it should revert with the {InvoiceCanceled} error + └── when the invoice IS NOT canceled + ├── given the payment method is transfer + │ ├── given the payment amount is in native token (ETH) + │ │ ├── when the payment amount is less than the invoice value + │ │ │ └── it should revert with the {PaymentAmountLessThanInvoiceValue} error + │ │ └── when the payment amount IS equal to the invoice value + │ │ ├── when the native token transfer fails + │ │ │ └── it should revert with the {NativeTokenPaymentFailed} error + │ │ └── when the native token transfer succeeds + │ │ ├── given the payment method is a one-off transfer + │ │ │ ├── it should update the invoice status to Paid + │ │ │ └── it should decrease the number of payments to zero + │ │ ├── given the payment method is a recurring transfer + │ │ │ ├── it should update the invoice status to Ongoing + │ │ │ └── it should decrease the number of payments + │ │ ├── it should transfer the payment amount to the invoice recipient + │ │ └── it should emit an {InvoicePaid} event + │ └── given the payment amount is in an ERC-20 token + │ ├── it should transfer the payment amount to the invoice recipient + │ └── it should emit an {InvoicePaid} event + ├── given the payment method is linear stream + │ ├── it should create a Sablier v2 linear stream + │ ├── it should update the invoice stream ID + │ └── it should emit an {InvoicePaid} event + └── given the payment method is tranched stream + ├── it should create a Sablier v2 tranched stream + ├── it should update the invoice stream ID + └── it should emit an {InvoicePaid} event + diff --git a/test/integration/shared/InvoiceModule.t.sol b/test/integration/shared/createInvoice.t.sol similarity index 64% rename from test/integration/shared/InvoiceModule.t.sol rename to test/integration/shared/createInvoice.t.sol index 7ffdd4ee..bfb5f7f7 100644 --- a/test/integration/shared/InvoiceModule.t.sol +++ b/test/integration/shared/createInvoice.t.sol @@ -4,14 +4,9 @@ pragma solidity ^0.8.26; import { Integration_Test } from "../Integration.t.sol"; import { Types } from "./../../../src/modules/invoice-module/libraries/Types.sol"; -abstract contract InvoiceModule_Integration_Shared_Test is Integration_Test { - Types.Invoice invoice; - +abstract contract CreateInvoice_Integration_Shared_Test is Integration_Test { function setUp() public virtual override { Integration_Test.setUp(); - - invoice.recipient = users.eve; - invoice.status = Types.Status.Pending; } modifier whenCallerContract() { @@ -63,7 +58,13 @@ abstract contract InvoiceModule_Integration_Shared_Test is Integration_Test { } /// @dev Creates an invoice with a one-off transfer payment - function createInvoiceWithOneOffTransfer() internal { + function createInvoiceWithOneOffTransfer( + address asset, + address recipient + ) internal view returns (Types.Invoice memory invoice) { + invoice.recipient = recipient; + invoice.status = Types.Status.Pending; + invoice.startTime = uint40(block.timestamp); invoice.endTime = uint40(block.timestamp) + 4 weeks; @@ -71,14 +72,20 @@ abstract contract InvoiceModule_Integration_Shared_Test is Integration_Test { method: Types.Method.Transfer, recurrence: Types.Recurrence.OneOff, paymentsLeft: 1, - asset: address(usdt), + asset: asset, amount: 100e18, streamId: 0 }); } /// @dev Creates an invoice with a recurring transfer payment - function createInvoiceWithRecurringTransfer(Types.Recurrence recurrence) internal { + function createInvoiceWithRecurringTransfer( + Types.Recurrence recurrence, + address recipient + ) internal view returns (Types.Invoice memory invoice) { + invoice.recipient = recipient; + invoice.status = Types.Status.Pending; + invoice.startTime = uint40(block.timestamp); invoice.endTime = uint40(block.timestamp) + 4 weeks; @@ -93,7 +100,10 @@ abstract contract InvoiceModule_Integration_Shared_Test is Integration_Test { } /// @dev Creates an invoice with a linear stream-based payment - function createInvoiceWithLinearStream() internal { + function createInvoiceWithLinearStream(address recipient) internal view returns (Types.Invoice memory invoice) { + invoice.recipient = recipient; + invoice.status = Types.Status.Pending; + invoice.startTime = uint40(block.timestamp); invoice.endTime = uint40(block.timestamp) + 4 weeks; @@ -108,7 +118,13 @@ abstract contract InvoiceModule_Integration_Shared_Test is Integration_Test { } /// @dev Creates an invoice with a tranched stream-based payment - function createInvoiceWithTranchedStream(Types.Recurrence recurrence) internal { + function createInvoiceWithTranchedStream( + Types.Recurrence recurrence, + address recipient + ) internal view returns (Types.Invoice memory invoice) { + invoice.recipient = recipient; + invoice.status = Types.Status.Pending; + invoice.startTime = uint40(block.timestamp); invoice.endTime = uint40(block.timestamp) + 4 weeks; @@ -121,4 +137,19 @@ abstract contract InvoiceModule_Integration_Shared_Test is Integration_Test { streamId: 0 }); } + + function executeCreateInvoice(Types.Invoice memory invoice, address user) public { + // Make the `user` account the caller who must be the owner of the {Container} contract + vm.startPrank({ msgSender: user }); + + // Create the invoice + bytes memory data = abi.encodeWithSignature( + "createInvoice((address,uint8,uint40,uint40,(uint8,uint8,uint40,address,uint128,uint256)))", + invoice + ); + container.execute({ module: address(invoiceModule), value: 0, data: data }); + + // Stop the active prank + vm.stopPrank(); + } } diff --git a/test/integration/shared/payInvoice.t.sol b/test/integration/shared/payInvoice.t.sol new file mode 100644 index 00000000..bf01a994 --- /dev/null +++ b/test/integration/shared/payInvoice.t.sol @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.26; + +import { Integration_Test } from "../Integration.t.sol"; +import { CreateInvoice_Integration_Shared_Test } from "./createInvoice.t.sol"; +import { Types } from "./../../../src/modules/invoice-module/libraries/Types.sol"; + +abstract contract PayInvoice_Integration_Shared_Test is Integration_Test, CreateInvoice_Integration_Shared_Test { + mapping(uint256 invoiceId => Types.Invoice) invoices; + + function setUp() public virtual override(Integration_Test, CreateInvoice_Integration_Shared_Test) { + CreateInvoice_Integration_Shared_Test.setUp(); + } + + modifier whenInvoiceNotNull() { + _; + } + + modifier whenInvoiceNotAlreadyPaid() { + _; + } + + modifier whenInvoiceNotCanceled() { + _; + } + + modifier givenPaymentMethodTransfer() { + _; + } + + modifier givenPaymentAmountInNativeToken() { + _; + } + + modifier givenPaymentAmountInERC20Tokens() { + _; + } + + modifier whenPaymentAmountEqualToInvoiceValue() { + _; + } + + modifier whenNativeTokenPaymentSucceeds() { + _; + } +} diff --git a/test/mocks/MockBadReceiver.sol b/test/mocks/MockBadReceiver.sol new file mode 100644 index 00000000..3c166c17 --- /dev/null +++ b/test/mocks/MockBadReceiver.sol @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.26; + +contract MockBadReceiver { + receive() external payable { + revert(); + } +} diff --git a/test/utils/Errors.sol b/test/utils/Errors.sol index 4fedf6a0..13ce956b 100644 --- a/test/utils/Errors.sol +++ b/test/utils/Errors.sol @@ -70,6 +70,12 @@ library Errors { /// @notice Thrown when a payer attempts to pay a canceled invoice error InvoiceCanceled(); + /// @notice Thrown when the invoice ID references a null invoice + error InvoiceNull(); + + /// @notice Thrown when `msg.sender` is not the creator (recipient) of the invoice + error InvoiceOwnerUnauthorized(); + /// @notice Thrown when the payment interval (endTime - startTime) is too short for the selected recurrence /// i.e. recurrence is set to weekly but interval is shorter than 1 week error PaymentIntervalTooShortForSelectedRecurrence(); diff --git a/test/utils/Events.sol b/test/utils/Events.sol index c10c6c7c..bd06a95f 100644 --- a/test/utils/Events.sol +++ b/test/utils/Events.sol @@ -58,4 +58,15 @@ abstract contract Events { uint40 endTime, Types.Payment payment ); + + /// @notice Emitted when an invoice is paid + /// @param id The ID of the invoice + /// @param payer The address of the payer + /// @param status The status of the invoice + /// @param payment Struct representing the payment details associated with the invoice + event InvoicePaid(uint256 indexed id, address indexed payer, Types.Status status, Types.Payment payment); + + /// @notice Emitted when an invoice is canceled + /// @param id The ID of the invoice + event InvoiceCanceled(uint256 indexed id); } diff --git a/test/utils/Helpers.sol b/test/utils/Helpers.sol index eb600ce5..8367a7c3 100644 --- a/test/utils/Helpers.sol +++ b/test/utils/Helpers.sol @@ -1,19 +1,19 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.26; -import { Types as InvoiceModuleTypes } from "./../../src/modules/invoice-module/libraries/Types.sol"; +import { Types } from "./../../src/modules/invoice-module/libraries/Types.sol"; library Helpers { - function createInvoiceDataType(address recipient) public view returns (InvoiceModuleTypes.Invoice memory) { + function createInvoiceDataType(address recipient) public view returns (Types.Invoice memory) { return - InvoiceModuleTypes.Invoice({ + Types.Invoice({ recipient: recipient, - status: InvoiceModuleTypes.Status.Pending, + status: Types.Status.Pending, startTime: 0, endTime: uint40(block.timestamp) + 1 weeks, - payment: InvoiceModuleTypes.Payment({ - method: InvoiceModuleTypes.Method.Transfer, - recurrence: InvoiceModuleTypes.Recurrence.OneOff, + payment: Types.Payment({ + method: Types.Method.Transfer, + recurrence: Types.Recurrence.OneOff, paymentsLeft: 1, asset: address(0), amount: uint128(1 ether), @@ -24,14 +24,14 @@ library Helpers { /// @dev Calculates the number of payments that must be done based on a Recurring invoice function computeNumberOfRecurringPayments( - InvoiceModuleTypes.Recurrence recurrence, + Types.Recurrence recurrence, uint40 interval ) internal pure returns (uint40 numberOfPayments) { - if (recurrence == InvoiceModuleTypes.Recurrence.Weekly) { + if (recurrence == Types.Recurrence.Weekly) { numberOfPayments = interval / 1 weeks; - } else if (recurrence == InvoiceModuleTypes.Recurrence.Monthly) { + } else if (recurrence == Types.Recurrence.Monthly) { numberOfPayments = interval / 4 weeks; - } else if (recurrence == InvoiceModuleTypes.Recurrence.Yearly) { + } else if (recurrence == Types.Recurrence.Yearly) { numberOfPayments = interval / 48 weeks; } }