diff --git a/src/Staking.sol b/src/Staking.sol index 8afa9bd..0f2471c 100644 --- a/src/Staking.sol +++ b/src/Staking.sol @@ -5,6 +5,7 @@ import {console} from "@forge-std/console.sol"; import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; import {ERC20} from "@openzeppelin/contracts/token/ERC20/ERC20.sol"; import "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; +import "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import {Ownable2StepUpgradeable} from "@openzeppelin-upgradeable/contracts/access/Ownable2StepUpgradeable.sol"; import {ERC20VotesUpgradeable} from "@openzeppelin-upgradeable/contracts/token/ERC20/extensions/ERC20VotesUpgradeable.sol"; import {FixedPointMathLib} from "@solmate/utils/FixedPointMathLib.sol"; @@ -19,6 +20,7 @@ contract Staking is ERC20VotesUpgradeable, Ownable2StepUpgradeable { //////////////////////////////////////////////////////////////*/ using SafeERC20 for IERC20; using FixedPointMathLib for uint256; + using EnumerableSet for EnumerableSet.UintSet; /*////////////////////////////////////////////////////////////// IMMUTABLES @@ -44,6 +46,9 @@ contract Staking is ERC20VotesUpgradeable, Ownable2StepUpgradeable { /// @dev only owner can change uint256 public minStake; + /// @notice Unique identifier that will be used for the next stake. + uint256 private nextStakeId; + /*////////////////////////////////////////////////////////////// STRUCTS //////////////////////////////////////////////////////////////*/ @@ -60,11 +65,15 @@ contract Staking is ERC20VotesUpgradeable, Ownable2StepUpgradeable { MAPPINGS/ARRAYS //////////////////////////////////////////////////////////////*/ - /// @notice the keyper stakes mapping - mapping(address keyper => Stake[]) public stakes; + /// @notice stores the metadata associated with a given stake + mapping(uint256 id => Stake) public stakes; + + // @notice stake ids belonging to a keyper + mapping(address keyper => EnumerableSet.UintSet stakeIds) + private keyperStakes; /// TODO when remove keyper also unstake the first stake - /// @notice the keypers mapping + /// @notice keypers mapping mapping(address keyper => bool isKeyper) public keypers; /// @notice how many SHU a keyper has locked @@ -152,10 +161,10 @@ contract Staking is ERC20VotesUpgradeable, Ownable2StepUpgradeable { address keyper = msg.sender; // Get the keyper stakes - Stake[] storage keyperStakes = stakes[keyper]; + EnumerableSet.UintSet storage stakesIds = keyperStakes[keyper]; // If the keyper has no stakes, the first stake must be at least the minimum stake - if (keyperStakes.length == 0) { + if (stakesIds.length() == 0) { require( amount >= minStake, "The first stake must be at least the minimum stake" @@ -177,12 +186,20 @@ contract Staking is ERC20VotesUpgradeable, Ownable2StepUpgradeable { // Lock the SHU in the contract STAKING_TOKEN.safeTransferFrom(keyper, address(this), amount); - // Record the new stake - keyperStakes.push(Stake(amount, block.timestamp, lockPeriod)); + // Get next stake id and increment it + uint256 stakeId = nextStakeId++; + + stakes[stakeId] = Stake({ + amount: amount, + timestamp: block.timestamp, + lockPeriod: lockPeriod + }); + + stakesIds.add(stakeId); emit Staked(keyper, amount, sharesToMint, lockPeriod); - return keyperStakes.length - 1; + return stakeId; } /// @notice Unstake SHU @@ -201,22 +218,26 @@ contract Staking is ERC20VotesUpgradeable, Ownable2StepUpgradeable { /// stake index, the contract will transfer the maximum amount available /// - amount must be specified in SHU, not shares /// @param keyper The keyper address - /// @param stakeIndex The index of the stake to unstake + /// @param stakeId The stake index /// @param amount The amount /// TODO check for reentrancy /// TODO unstake only principal /// TODO slippage protection function unstake( address keyper, - uint256 stakeIndex, + uint256 stakeId, uint256 amount ) external updateRewards { - console.log("stakes[keyper].length", stakes[keyper].length); /////////////////////////// CHECKS /////////////////////////////// - require(stakeIndex < stakes[keyper].length, "Invalid stake index"); + require( + keyperStakes[keyper].contains(stakeId), + "Stake does not belong to keyper" + ); - // Gets the keyper stake - Stake storage keyperStake = stakes[keyper][stakeIndex]; + require(stakes[stakeId].amount > 0, "Stake does not exist"); + + // Gets the stake + Stake storage keyperStake = stakes[stakeId]; uint256 maxWithdrawAmount; @@ -244,7 +265,6 @@ contract Staking is ERC20VotesUpgradeable, Ownable2StepUpgradeable { } maxWithdrawAmount = maxWithdraw(keyper, keyperStake.amount); - console.log("maxWithdrawAmount", maxWithdrawAmount); } else { // doesn't exclude the min stake and locked staked as the keyper is not a keyper anymore maxWithdrawAmount = convertToAssets(balanceOf(keyper)); @@ -281,11 +301,11 @@ contract Staking is ERC20VotesUpgradeable, Ownable2StepUpgradeable { // If the stake is empty, remove it if (keyperStake.amount == 0) { - // Remove the stake from the keyper's stake array - stakes[keyper][stakeIndex] = stakes[keyper][ - stakes[keyper].length - 1 - ]; - stakes[keyper].pop(); + // Remove the stake from the stakes mapping + delete stakes[stakeId]; + + // Remove the stake from the keyper stakes + keyperStakes[keyper].remove(stakeId); } /////////////////////////// INTERACTIONS /////////////////////////// @@ -505,4 +525,10 @@ contract Staking is ERC20VotesUpgradeable, Ownable2StepUpgradeable { function totalAssets() public view virtual returns (uint256) { return STAKING_TOKEN.balanceOf(address(this)); } + + function getKeyperStakeIds( + address keyper + ) public view returns (uint256[] memory) { + return keyperStakes[keyper].values(); + } } diff --git a/test/Staking.t.sol b/test/Staking.t.sol index c20d675..43e0ebd 100644 --- a/test/Staking.t.sol +++ b/test/Staking.t.sol @@ -94,7 +94,12 @@ contract StakingTest is Test { } function _mintGovToken(address _to, uint256 _amount) internal { - vm.assume(_to != address(0)); + vm.assume( + _to != address(0) && + _to != address(staking) && + _to != ProxyUtils.getAdminAddress(address(staking)) + ); + govToken.mint(_to, _amount); } @@ -109,7 +114,7 @@ contract StakingTest is Test { function _stake( address _keyper, uint256 _amount - ) internal returns (uint256 _depositId) { + ) internal returns (uint256 _stakeId) { vm.assume( _keyper != address(0) && uint160(_keyper) > 0x100 && // ignore precompiled address @@ -120,7 +125,7 @@ contract StakingTest is Test { vm.startPrank(_keyper); govToken.approve(address(staking), _amount); - _depositId = staking.stake(_amount); + _stakeId = staking.stake(_amount); vm.stopPrank(); } @@ -385,9 +390,9 @@ contract Stake is StakingTest { vm.assume(_depositor != address(0)); - uint256 depositIndex = _stake(_depositor, _amount); + uint256 stakeId = _stake(_depositor, _amount); - (uint256 amount, , ) = staking.stakes(_depositor, depositIndex); + (uint256 amount, , ) = staking.stakes(stakeId); assertEq(amount, _amount, "Wrong amount"); } @@ -403,9 +408,9 @@ contract Stake is StakingTest { vm.assume(_depositor != address(0)); - uint256 depositIndex = _stake(_depositor, _amount); + uint256 stakeId = _stake(_depositor, _amount); - (, uint256 timestamp, ) = staking.stakes(_depositor, depositIndex); + (, uint256 timestamp, ) = staking.stakes(stakeId); assertEq(timestamp, block.timestamp, "Wrong timestamp"); } @@ -421,9 +426,9 @@ contract Stake is StakingTest { vm.assume(_depositor != address(0)); - uint256 depositIndex = _stake(_depositor, _amount); + uint256 stakeId = _stake(_depositor, _amount); - (, , uint256 lockPeriod) = staking.stakes(_depositor, depositIndex); + (, , uint256 lockPeriod) = staking.stakes(stakeId); assertEq(lockPeriod, LOCK_PERIOD, "Wrong lock period"); } @@ -441,20 +446,14 @@ contract Stake is StakingTest { vm.assume(_depositor != address(0) && _depositor != address(this)); - uint256 depositIndex1 = _stake(_depositor, _amount1); + uint256 stakeId1 = _stake(_depositor, _amount1); - (uint256 amount1, uint256 timestamp, ) = staking.stakes( - _depositor, - depositIndex1 - ); + (uint256 amount1, uint256 timestamp, ) = staking.stakes(stakeId1); _jumpAhead(1); - uint256 depositIndex2 = _stake(_depositor, _amount2); - (uint256 amount2, uint256 timestamp2, ) = staking.stakes( - _depositor, - depositIndex2 - ); + uint256 stakeId2 = _stake(_depositor, _amount2); + (uint256 amount2, uint256 timestamp2, ) = staking.stakes(stakeId2); assertEq(amount1, _amount1, "Wrong amount"); assertEq(amount2, _amount2, "Wrong amount"); @@ -909,7 +908,7 @@ contract Unstake is StakingTest { ); } - function testFuzz_AnyoneCanUnstakeOnBehalfOfKeyperWhenKeyeprIsNotAKeyperAnymore( + function testFuzz_AnyoneCanUnstakeOnBehalfOfKeyperWhenKeyperIsNotAKeyperAnymore( address _depositor, address _anyone, uint256 _amount,