diff --git a/script/Deploy.s.sol b/script/Deploy.s.sol index 63f7fef..8d2191e 100644 --- a/script/Deploy.s.sol +++ b/script/Deploy.s.sol @@ -3,6 +3,7 @@ pragma solidity 0.8.26; import "@forge-std/Script.sol"; import {TransparentUpgradeableProxy} from "@openzeppelin/contracts/proxy/transparent/TransparentUpgradeableProxy.sol"; +import {IERC20Metadata} from "@openzeppelin/contracts/token/ERC20/extensions/IERC20Metadata.sol"; import {RewardsDistributor} from "src/RewardsDistributor.sol"; import {DelegateStaking} from "src/DelegateStaking.sol"; import {Staking} from "src/Staking.sol"; @@ -42,6 +43,8 @@ contract Deploy is Script { MIN_STAKE ); + IERC20Metadata(STAKING_TOKEN).approve(address(stakingProxy), 1000e18); + DelegateStaking delegate = new DelegateStaking(); delegateProxy = DelegateStaking( address( @@ -53,6 +56,8 @@ contract Deploy is Script { ) ); + IERC20Metadata(STAKING_TOKEN).approve(address(delegateProxy), 1000e18); + delegateProxy.initialize( CONTRACT_OWNER, STAKING_TOKEN, diff --git a/src/BaseStaking.sol b/src/BaseStaking.sol index a7e2b6a..2a5f443 100644 --- a/src/BaseStaking.sol +++ b/src/BaseStaking.sol @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MIT pragma solidity 0.8.26; +import {console} from "@forge-std/console.sol"; import {OwnableUpgradeable} from "@openzeppelin-upgradeable/contracts/access/OwnableUpgradeable.sol"; import {ERC20VotesUpgradeable} from "@openzeppelin-upgradeable/contracts/token/ERC20/extensions/ERC20VotesUpgradeable.sol"; import {EnumerableSetLib} from "@solady/utils/EnumerableSetLib.sol"; @@ -180,7 +181,12 @@ abstract contract BaseStaking is OwnableUpgradeable, ERC20VotesUpgradeable { function convertToShares( uint256 assets ) public view virtual returns (uint256) { - return assets.mulDivDown(totalSupply(), _totalAssets()); + console.log("totoal supply", totalSupply()); + console.log("total assets", _totalAssets()); + console.log("assets", assets); + uint256 supply = totalSupply(); // Saves an extra SLOAD if totalSupply is non-zero. + + return supply == 0 ? assets : assets.mulDivDown(supply, _totalAssets()); } /// @notice Get the total amount of assets the shares are worth @@ -188,7 +194,9 @@ abstract contract BaseStaking is OwnableUpgradeable, ERC20VotesUpgradeable { function convertToAssets( uint256 shares ) public view virtual returns (uint256) { - return shares.mulDivDown(_totalAssets(), totalSupply()); + uint256 supply = totalSupply(); // Saves an extra SLOAD if totalSupply is non-zero. + + return supply == 0 ? shares : shares.mulDivDown(_totalAssets(), supply); } /// @notice Get the stake ids belonging to a user @@ -208,20 +216,20 @@ abstract contract BaseStaking is OwnableUpgradeable, ERC20VotesUpgradeable { /// @notice Deposit SHU into the contract /// @param amount The amount of SHU to deposit - function _deposit(uint256 amount) internal { + function _deposit(address to, uint256 amount) internal { // Calculate the amount of shares to mint uint256 shares = convertToShares(amount); // Update the total locked amount unchecked { - totalLocked[msg.sender] += amount; + totalLocked[to] += amount; } // Mint the shares - _mint(msg.sender, shares); + _mint(to, shares); // Lock the SHU in the contract - stakingToken.safeTransferFrom(msg.sender, address(this), amount); + stakingToken.safeTransferFrom(to, address(this), amount); } /// @notice Withdraw SHU from the contract @@ -251,7 +259,9 @@ abstract contract BaseStaking is OwnableUpgradeable, ERC20VotesUpgradeable { /// @notice Get the amount of shares that will be burned /// @param assets The amount of assets function _previewWithdraw(uint256 assets) internal view returns (uint256) { - return assets.mulDivUp(totalSupply(), _totalAssets()); + uint256 supply = totalSupply(); // Saves an extra SLOAD if totalSupply is non-zero. + + return supply == 0 ? assets : assets.mulDivUp(supply, _totalAssets()); } /// @notice Calculates the amount to withdraw diff --git a/src/DelegateStaking.sol b/src/DelegateStaking.sol index a5d309d..9d3f35f 100644 --- a/src/DelegateStaking.sol +++ b/src/DelegateStaking.sol @@ -141,6 +141,18 @@ contract DelegateStaking is BaseStaking { lockPeriod = _lockPeriod; nextStakeId = 1; + + // mint dead shares to avoid inflation attack + uint256 amount = 1000e18; + + // Calculate the amount of shares to mint + uint256 shares = convertToShares(amount); + + // Mint the shares to the vault + _mint(address(this), shares); + + // Transfer the SHU to the vault + stakingToken.safeTransferFrom(msg.sender, address(this), amount); } /// @notice Stake SHU @@ -158,12 +170,10 @@ contract DelegateStaking is BaseStaking { require(staking.keypers(keyper), AddressIsNotAKeyper()); - address user = msg.sender; - stakeId = nextStakeId++; // Add the stake id to the user stakes - userStakes[user].add(stakeId); + userStakes[msg.sender].add(stakeId); // Add the stake to the stakes mapping stakes[stakeId].keyper = keyper; @@ -176,9 +186,9 @@ contract DelegateStaking is BaseStaking { totalDelegated[keyper] += amount; } - _deposit(amount); + _deposit(msg.sender, amount); - emit Staked(user, keyper, amount, lockPeriod); + emit Staked(msg.sender, keyper, amount, lockPeriod); } /// @notice Unstake SHU diff --git a/src/Staking.sol b/src/Staking.sol index e51888d..35e43af 100644 --- a/src/Staking.sol +++ b/src/Staking.sol @@ -141,8 +141,16 @@ contract Staking is BaseStaking { nextStakeId = 1; - // TODO find the correct value here - _mint(address(0), 1e18); + // mint dead shares to avoid inflation attack + uint256 amount = 1000e18; + // Calculate the amount of shares to mint + uint256 shares = convertToShares(amount); + + // Mint the shares to the vault + _mint(address(this), shares); + + // Lock the SHU in the contract + stakingToken.safeTransferFrom(msg.sender, address(this), amount); } /// @notice Stake SHU @@ -179,7 +187,7 @@ contract Staking is BaseStaking { stakes[stakeId].timestamp = block.timestamp; stakes[stakeId].lockPeriod = lockPeriod; - _deposit(amount); + _deposit(msg.sender, amount); emit Staked(msg.sender, amount, lockPeriod); } diff --git a/test/DelegateStaking.t.sol b/test/DelegateStaking.t.sol index 44a5705..ae4d949 100644 --- a/test/DelegateStaking.t.sol +++ b/test/DelegateStaking.t.sol @@ -35,7 +35,6 @@ contract DelegateStakingTest is Test { _jumpAhead(1234); govToken = new MockGovToken(); - _mintGovToken(address(this), 100_000_000e18); vm.label(address(govToken), "govToken"); // deploy rewards distributor @@ -54,6 +53,8 @@ contract DelegateStakingTest is Test { ); vm.label(address(staking), "staking"); + _mintGovToken(address(this), 2000e18); + govToken.approve(address(staking), 1000e18); staking.initialize( address(this), // owner address(govToken), @@ -72,6 +73,8 @@ contract DelegateStakingTest is Test { ); vm.label(address(delegate), "delegate"); + govToken.approve(address(delegate), 1000e18); + delegate.initialize( address(this), // owner address(govToken), @@ -86,7 +89,7 @@ contract DelegateStakingTest is Test { ); // fund reward distribution - govToken.transfer(address(rewardsDistributor), 100_000_000e18); + _mintGovToken(address(rewardsDistributor), 100_000_000e18); } function _setKeyper(address _keyper, bool _isKeyper) internal { diff --git a/test/Staking.t.sol b/test/Staking.t.sol index 70cbf54..b9b4369 100644 --- a/test/Staking.t.sol +++ b/test/Staking.t.sol @@ -7,6 +7,8 @@ import {IERC20Metadata} from "@openzeppelin/contracts/token/ERC20/extensions/IER import {TransparentUpgradeableProxy, ITransparentUpgradeableProxy} from "@openzeppelin/contracts/proxy/transparent/TransparentUpgradeableProxy.sol"; import {Ownable} from "@openzeppelin/contracts/access/Ownable.sol"; +import {Create2} from "@openzeppelin/contracts/utils/Create2.sol"; + import {FixedPointMathLib} from "src/libraries/FixedPointMathLib.sol"; import {Staking} from "src/Staking.sol"; import {BaseStaking} from "src/BaseStaking.sol"; @@ -33,7 +35,6 @@ contract StakingTest is Test { _jumpAhead(1234); govToken = new MockGovToken(); - _mintGovToken(address(this), 100_000_000e18); vm.label(address(govToken), "govToken"); // deploy rewards distributor @@ -51,6 +52,10 @@ contract StakingTest is Test { ); vm.label(address(staking), "staking"); + _mintGovToken(address(this), 1000e18); + + govToken.approve(address(staking), 1000e18); + staking.initialize( address(this), // owner address(govToken), @@ -65,7 +70,8 @@ contract StakingTest is Test { ); // fund reward distribution - govToken.transfer(address(rewardsDistributor), 100_000_000e18); + + _mintGovToken(address(rewardsDistributor), 100_000_000e18); } function _jumpAhead(uint256 _seconds) public {