diff --git a/solidity/contracts/lib/ERC4626Fees.sol b/solidity/contracts/lib/ERC4626Fees.sol index 5f8c2391c..7ac316a92 100644 --- a/solidity/contracts/lib/ERC4626Fees.sol +++ b/solidity/contracts/lib/ERC4626Fees.sol @@ -84,6 +84,13 @@ abstract contract ERC4626Fees is ERC4626Upgradeable { } } + /// @dev Calculate the maximum amount of assets that can be withdrawn + /// by an account including fees. See {IERC4626-maxWithdraw}. + function _maxWithdraw(address account) internal view returns (uint256) { + uint256 maxAssets = super.maxWithdraw(account); + return maxAssets - _feeOnTotal(maxAssets, _exitFeeBasisPoints()); + } + // === Fee configuration === // slither-disable-next-line dead-code diff --git a/solidity/contracts/stBTC.sol b/solidity/contracts/stBTC.sol index d0815c4be..da077b9cf 100644 --- a/solidity/contracts/stBTC.sol +++ b/solidity/contracts/stBTC.sol @@ -331,7 +331,7 @@ contract stBTC is ERC4626Fees, PausableOwnable { if (paused()) { return 0; } - return super.maxWithdraw(owner); + return _maxWithdraw(owner); } /// @dev Returns the maximum amount of Vault shares that can be redeemed from diff --git a/solidity/test/stBTC.test.ts b/solidity/test/stBTC.test.ts index 74fa0aa73..fed6991b0 100644 --- a/solidity/test/stBTC.test.ts +++ b/solidity/test/stBTC.test.ts @@ -2172,6 +2172,50 @@ describe("stBTC", () => { }) }) + describe("maxWithdraw", () => { + beforeAfterSnapshotWrapper() + const amountToDeposit = to1e18(1) + let expectedDepositedAmount: bigint + let expectedWithdrawnAmount: bigint + + before(async () => { + await tbtc + .connect(depositor1) + .approve(await stbtc.getAddress(), amountToDeposit) + await stbtc + .connect(depositor1) + .deposit(amountToDeposit, depositor1.address) + expectedDepositedAmount = + amountToDeposit - feeOnTotal(amountToDeposit, entryFeeBasisPoints) + expectedWithdrawnAmount = + expectedDepositedAmount - + feeOnTotal(expectedDepositedAmount, exitFeeBasisPoints) + }) + + it("should account for the exit fee", async () => { + const maxWithdraw = await stbtc.maxWithdraw(depositor1.address) + + expect(maxWithdraw).to.be.eq(expectedWithdrawnAmount) + }) + + it("should be equal to the actual redeemable amount", async () => { + const maxWithdraw = await stbtc.maxWithdraw(depositor1.address) + const availableShares = await stbtc.balanceOf(depositor1.address) + + const tx = await stbtc.redeem( + availableShares, + depositor1.address, + depositor1.address, + ) + + await expect(tx).to.changeTokenBalances( + tbtc, + [depositor1.address], + [maxWithdraw], + ) + }) + }) + describe("feeOnTotal - internal test helper", () => { context("when the fee's modulo remainder is greater than 0", () => { it("should add 1 to the result", () => {