diff --git a/script/Deploy.s.sol b/script/Deploy.s.sol index 6117d2c..63f7fef 100644 --- a/script/Deploy.s.sol +++ b/script/Deploy.s.sol @@ -4,13 +4,18 @@ pragma solidity 0.8.26; import "@forge-std/Script.sol"; import {TransparentUpgradeableProxy} from "@openzeppelin/contracts/proxy/transparent/TransparentUpgradeableProxy.sol"; import {RewardsDistributor} from "src/RewardsDistributor.sol"; +import {DelegateStaking} from "src/DelegateStaking.sol"; import {Staking} from "src/Staking.sol"; import "./Constants.sol"; contract Deploy is Script { function run() public - returns (Staking stakingProxy, RewardsDistributor rewardsDistributor) + returns ( + Staking stakingProxy, + RewardsDistributor rewardsDistributor, + DelegateStaking delegateProxy + ) { vm.startBroadcast(); @@ -37,6 +42,25 @@ contract Deploy is Script { MIN_STAKE ); + DelegateStaking delegate = new DelegateStaking(); + delegateProxy = DelegateStaking( + address( + new TransparentUpgradeableProxy( + address(delegate), + address(CONTRACT_OWNER), + "" + ) + ) + ); + + delegateProxy.initialize( + CONTRACT_OWNER, + STAKING_TOKEN, + address(rewardsDistributor), + address(stakingProxy), + LOCK_PERIOD + ); + vm.stopBroadcast(); } } diff --git a/script/testnet/DeployTestnet.s.sol b/script/testnet/DeployTestnet.s.sol index ece1158..a99b01d 100644 --- a/script/testnet/DeployTestnet.s.sol +++ b/script/testnet/DeployTestnet.s.sol @@ -5,6 +5,7 @@ import "@forge-std/Script.sol"; import {TransparentUpgradeableProxy} from "@openzeppelin/contracts/proxy/transparent/TransparentUpgradeableProxy.sol"; import {RewardsDistributor} from "src/RewardsDistributor.sol"; import {Staking} from "src/Staking.sol"; +import {DelegateStaking} from "src/DelegateStaking.sol"; import {MockGovToken} from "test/mocks/MockGovToken.sol"; import "./Constants.sol"; @@ -12,7 +13,11 @@ import "./Constants.sol"; contract DeployTestnet is Script { function run() public - returns (Staking stakingProxy, RewardsDistributor rewardsDistributor) + returns ( + Staking stakingProxy, + RewardsDistributor rewardsDistributor, + DelegateStaking delegateProxy + ) { vm.startBroadcast(); @@ -42,6 +47,25 @@ contract DeployTestnet is Script { MIN_STAKE ); + DelegateStaking delegate = new DelegateStaking(); + delegateProxy = DelegateStaking( + address( + new TransparentUpgradeableProxy( + address(delegate), + address(CONTRACT_OWNER), + "" + ) + ) + ); + + delegateProxy.initialize( + CONTRACT_OWNER, + MOCKED_SHU, + address(rewardsDistributor), + address(stakingProxy), + LOCK_PERIOD + ); + vm.stopBroadcast(); } } diff --git a/src/BaseStaking.sol b/src/BaseStaking.sol index a83cc01..32e6ffd 100644 --- a/src/BaseStaking.sol +++ b/src/BaseStaking.sol @@ -180,9 +180,7 @@ abstract contract BaseStaking is OwnableUpgradeable, ERC20VotesUpgradeable { function convertToShares( uint256 assets ) public view virtual returns (uint256) { - // sum + 1 on both sides to prevent donation attack - // this is the same as OZ ERC4626 prevetion to inflation attack with decimal offset = 0 - return assets.mulDivDown(totalSupply() + 1, _totalAssets() + 1); + return assets.mulDivDown(totalSupply(), _totalAssets()); } /// @notice Get the total amount of assets the shares are worth @@ -190,9 +188,7 @@ abstract contract BaseStaking is OwnableUpgradeable, ERC20VotesUpgradeable { function convertToAssets( uint256 shares ) public view virtual returns (uint256) { - // sum + 1 on both sides to prevent donation attack - // this is the same as OZ ERC4626 prevetion to inflation attack with decimal offset = 0 - return shares.mulDivDown(_totalAssets() + 1, totalSupply() + 1); + return shares.mulDivDown(_totalAssets(), totalSupply()); } /// @notice Get the stake ids belonging to a user @@ -213,24 +209,17 @@ abstract contract BaseStaking is OwnableUpgradeable, ERC20VotesUpgradeable { /// @notice Deposit SHU into the contract /// @param user The user address /// @param amount The amount of SHU to deposit - function _deposit(address user, uint256 amount) internal { + function _deposit(uint256 amount) internal { // Calculate the amount of shares to mint uint256 shares = convertToShares(amount); - // A first deposit donation attack may result in shares being 0 if the - // contract has very high assets balance but a very low total supply. - // Although this attack is not profitable for the attacker, as they will - // spend more tokens than they will receive, it can still be used to perform a DDOS attack - // against a specific user. The targeted user can still withdraw their SHU, - // but this is only guaranteed if someone mints to increase the total supply of shares, - // because previewWithdraw rounds up and their shares will be less than the burn amount. - require(shares > 0, SharesMustBeGreaterThanZero()); - // Update the total locked amount - totalLocked[user] += amount; + unchecked { + totalLocked[msg.sender] += amount; + } // Mint the shares - _mint(user, shares); + _mint(msg.sender, shares); // Lock the SHU in the contract stakingToken.safeTransferFrom(msg.sender, address(this), amount); @@ -263,9 +252,7 @@ abstract contract BaseStaking is OwnableUpgradeable, ERC20VotesUpgradeable { /// @notice Get the amount of shares that will be burned /// @param assets The amount of assets function _previewWithdraw(uint256 assets) internal view returns (uint256) { - // sum + 1 on both sides to prevent donation attack - // this is the same as OZ ERC4626 prevetion to inflation attack with decimal offset = 0 - return assets.mulDivUp(totalSupply() + 1, _totalAssets() + 1); + return assets.mulDivUp(totalSupply(), _totalAssets()); } /// @notice Calculates the amount to withdraw diff --git a/src/DelegateStaking.sol b/src/DelegateStaking.sol index b9e9767..a5d309d 100644 --- a/src/DelegateStaking.sol +++ b/src/DelegateStaking.sol @@ -172,9 +172,11 @@ contract DelegateStaking is BaseStaking { stakes[stakeId].lockPeriod = lockPeriod; // Increase the keyper total delegated amount - totalDelegated[keyper] += amount; + unchecked { + totalDelegated[keyper] += amount; + } - _deposit(user, amount); + _deposit(amount); emit Staked(user, keyper, amount, lockPeriod); } diff --git a/src/RewardsDistributor.sol b/src/RewardsDistributor.sol index d66db60..e33ef2a 100644 --- a/src/RewardsDistributor.sol +++ b/src/RewardsDistributor.sol @@ -75,7 +75,7 @@ contract RewardsDistributor is Ownable, IRewardsDistributor { /// @notice Distribute rewards to receiver /// Caller must be the receiver - function collectRewards() external override returns (uint256 rewards) { + function collectRewards() public override returns (uint256 rewards) { address receiver = msg.sender; RewardConfiguration storage rewardConfiguration = rewardConfigurations[ @@ -85,13 +85,11 @@ contract RewardsDistributor is Ownable, IRewardsDistributor { // difference in time since last update uint256 timeDelta = block.timestamp - rewardConfiguration.lastUpdate; - uint256 funds = rewardToken.balanceOf(address(this)); - rewards = rewardConfiguration.emissionRate * timeDelta; // the contract must have enough funds to distribute // we don't want to revert in case its zero to not block the staking contract - if (rewards == 0 || funds < rewards) { + if (rewards == 0 || rewardToken.balanceOf(address(this)) < rewards) { return 0; } @@ -108,7 +106,7 @@ contract RewardsDistributor is Ownable, IRewardsDistributor { /// @param receiver The receiver of the rewards function collectRewardsTo( address receiver - ) external override returns (uint256 rewards) { + ) public override returns (uint256 rewards) { RewardConfiguration storage rewardConfiguration = rewardConfigurations[ receiver ]; @@ -120,12 +118,14 @@ contract RewardsDistributor is Ownable, IRewardsDistributor { require(timeDelta > 0, TimeDeltaZero()); - uint256 funds = rewardToken.balanceOf(address(this)); - rewards = rewardConfiguration.emissionRate * timeDelta; // the contract must have enough funds to distribute - require(funds >= rewards, NotEnoughFunds()); + // and the rewards must be greater than zero + require( + rewards > 0 && rewardToken.balanceOf(address(this)) >= rewards, + NotEnoughFunds() + ); // update the last update timestamp rewardConfiguration.lastUpdate = block.timestamp; @@ -142,7 +142,7 @@ contract RewardsDistributor is Ownable, IRewardsDistributor { function setRewardConfiguration( address receiver, uint256 emissionRate - ) external override onlyOwner { + ) public override onlyOwner { require(receiver != address(0), ZeroAddress()); // to remove a rewards, it should call removeRewardConfiguration @@ -151,7 +151,11 @@ contract RewardsDistributor is Ownable, IRewardsDistributor { // only update last update if it's the first time if (rewardConfigurations[receiver].lastUpdate == 0) { rewardConfigurations[receiver].lastUpdate = block.timestamp; + } else { + // claim the rewards before updating the emission rate + collectRewardsTo(receiver); } + rewardConfigurations[receiver].emissionRate = emissionRate; emit RewardConfigurationSet(receiver, emissionRate); @@ -159,33 +163,40 @@ contract RewardsDistributor is Ownable, IRewardsDistributor { /// @notice Remove a reward configuration /// @param receiver The receiver of the rewards - function removeRewardConfiguration(address receiver) external onlyOwner { - delete rewardConfigurations[receiver]; + function removeRewardConfiguration(address receiver) public onlyOwner { + rewardConfigurations[receiver].lastUpdate = 0; + rewardConfigurations[receiver].emissionRate = 0; emit RewardConfigurationSet(receiver, 0); } - /// @notice Withdraw funds from the contract - /// @param to The address to withdraw to - /// @param amount The amount to withdraw - function withdrawFunds( - address to, - uint256 amount - ) public override onlyOwner { - rewardToken.safeTransfer(to, amount); - } - /// @notice Set the reward token /// @param _rewardToken The reward token - function setRewardToken(address _rewardToken) external onlyOwner { + function setRewardToken(address _rewardToken) public onlyOwner { require(_rewardToken != address(0), ZeroAddress()); // withdraw remaining old reward token - withdrawFunds(msg.sender, rewardToken.balanceOf(address(this))); + withdrawFunds( + address(rewardToken), + msg.sender, + rewardToken.balanceOf(address(this)) + ); // set the new reward token rewardToken = IERC20(_rewardToken); emit RewardTokenSet(_rewardToken); } + + /// @notice Withdraw funds from the contract + /// @param to The address to withdraw to + /// @param amount The amount to withdraw + function withdrawFunds( + address token, + address to, + uint256 amount + ) public onlyOwner { + require(to != address(0), ZeroAddress()); + IERC20(token).safeTransfer(to, amount); + } } diff --git a/src/Staking.sol b/src/Staking.sol index 3bf30e4..498ced9 100644 --- a/src/Staking.sol +++ b/src/Staking.sol @@ -1,7 +1,6 @@ // SPDX-License-Identifier: MIT pragma solidity 0.8.26; -import {ERC20VotesUpgradeable} from "@openzeppelin-upgradeable/contracts/token/ERC20/extensions/ERC20VotesUpgradeable.sol"; import {EnumerableSetLib} from "@solady/utils/EnumerableSetLib.sol"; import {BaseStaking} from "./BaseStaking.sol"; @@ -69,11 +68,15 @@ contract Staking is BaseStaking { //////////////////////////////////////////////////////////////*/ /// @notice stores the metadata associated with a given stake - mapping(uint256 id => Stake _stake) public stakes; + mapping(uint256 id => Stake stake) public stakes; /// @notice keypers mapping mapping(address keyper => bool isKeyper) public keypers; + /*////////////////////////////////////////////////////////////// + EVENTS + //////////////////////////////////////////////////////////////*/ + /// @notice Emitted when a keyper stakes SHU event Staked(address indexed user, uint256 amount, uint256 lockPeriod); @@ -137,14 +140,17 @@ contract Staking is BaseStaking { minStake = _minStake; nextStakeId = 1; + + // TODO find the correct value here + _mint(address(0), 1e18); } /// @notice Stake SHU - /// - first stake must be at least the minimum stake + /// - first stake must be at least the minimum stake amount /// - amount will be locked in the contract for the lock period /// - keyper must approve the contract to spend the SHU before staking /// - this function will mint sSHU to the keyper - //// - sSHU is non-transferable + /// - sSHU is non-transferable /// - only keypers can stake /// @param amount The amount of SHU to stake /// @return stakeId The index of the stake @@ -155,10 +161,8 @@ contract Staking is BaseStaking { require(amount > 0, ZeroAmount()); - address user = msg.sender; - // Get the keyper stakes - EnumerableSetLib.Uint256Set storage stakesIds = userStakes[user]; + EnumerableSetLib.Uint256Set storage stakesIds = userStakes[msg.sender]; // If the keyper has no stakes, the first stake must be at least the minimum stake if (stakesIds.length() == 0) { @@ -168,16 +172,16 @@ contract Staking is BaseStaking { stakeId = nextStakeId++; // Add the stake id to the user stakes - userStakes[user].add(stakeId); + userStakes[msg.sender].add(stakeId); // Add the stake to the stakes mapping stakes[stakeId].amount = amount; stakes[stakeId].timestamp = block.timestamp; stakes[stakeId].lockPeriod = lockPeriod; - _deposit(user, amount); + _deposit(amount); - emit Staked(user, amount, lockPeriod); + emit Staked(msg.sender, amount, lockPeriod); } /// @notice Unstake SHU @@ -205,7 +209,7 @@ contract Staking is BaseStaking { address keyper, uint256 stakeId, uint256 _amount - ) external updateRewards returns (uint256 amount) { + ) public updateRewards returns (uint256 amount) { require( userStakes[keyper].contains(stakeId), StakeDoesNotBelongToUser() @@ -223,18 +227,23 @@ contract Staking is BaseStaking { // Only the keyper can unstake require(msg.sender == keyper, OnlyKeyper()); - // If the lock period is less than the global lock period, the stake - // must be locked for the lock period - // If the global lock period is greater than the stake lock period, - // the stake must be locked for the stake lock period + // If the stake lock period is greater than the global lock period, + // the stake must be locked for the global lock period + // If the stake lock period is less than the global lock period, the stake + // must be locked for the stake lock period uint256 lock = keyperStake.lockPeriod > lockPeriod ? lockPeriod : keyperStake.lockPeriod; - require( - block.timestamp > keyperStake.timestamp + lock, - StakeIsStillLocked() - ); + unchecked { + require( + block.timestamp > keyperStake.timestamp + lock, + StakeIsStillLocked() + ); + } + + uint256 maxWithdraw = keyperStake.amount - minStake; + require(amount <= maxWithdraw, WithdrawAmountTooHigh()); // The unstake can't never result in a keyper SHU staked < minStake require( @@ -318,15 +327,15 @@ contract Staking is BaseStaking { address keyper, uint256 unlockedAmount ) internal view virtual returns (uint256 amount) { - uint256 shares = balanceOf(keyper); - require(shares > 0, UserHasNoShares()); + uint256 assets = convertToAssets(balanceOf(keyper)); + require(assets > 0, UserHasNoShares()); - uint256 assets = convertToAssets(shares); + unchecked { + uint256 locked = totalLocked[keyper] - unlockedAmount; + uint256 compare = locked >= minStake ? locked : minStake; - uint256 locked = totalLocked[keyper] - unlockedAmount; - uint256 compare = locked >= minStake ? locked : minStake; - - // need the first branch as convertToAssets rounds down - amount = compare >= assets ? 0 : assets - compare; + // need the first branch as convertToAssets rounds down + amount = compare >= assets ? 0 : assets - compare; + } } } diff --git a/src/interfaces/IRewardsDistributor.sol b/src/interfaces/IRewardsDistributor.sol index 07217fd..c795bda 100644 --- a/src/interfaces/IRewardsDistributor.sol +++ b/src/interfaces/IRewardsDistributor.sol @@ -6,7 +6,7 @@ interface IRewardsDistributor { function collectRewardsTo(address receiver) external returns (uint256); - function withdrawFunds(address to, uint256 amount) external; + function withdrawFunds(address token, address to, uint256 amount) external; function setRewardConfiguration( address receiver, diff --git a/test/Staking.integration.t.sol b/test/Staking.integration.t.sol index 6849872..9c88166 100644 --- a/test/Staking.integration.t.sol +++ b/test/Staking.integration.t.sol @@ -27,7 +27,7 @@ contract StakingIntegrationTest is Test { vm.createSelectFork(vm.rpcUrl("mainnet"), 20254999); Deploy deployScript = new Deploy(); - (staking, rewardsDistributor) = deployScript.run(); + (staking, rewardsDistributor, ) = deployScript.run(); } function _boundRealisticTimeAhead(