Skip to content

Commit

Permalink
Fix reentrancy by protecting podBalanceOf() and balanceOf() from acce…
Browse files Browse the repository at this point in the history
…ss during updateBalances() loop
  • Loading branch information
k06a committed Nov 23, 2022
1 parent 5fc07f7 commit dc0ea03
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 3 deletions.
14 changes: 11 additions & 3 deletions contracts/ERC20Pods.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ import "@1inch/solidity-utils/contracts/libraries/AddressSet.sol";

import "./interfaces/IERC20Pods.sol";
import "./interfaces/IPod.sol";
import "./libs/ReentrancyGuard.sol";

abstract contract ERC20Pods is ERC20, IERC20Pods {
abstract contract ERC20Pods is ERC20, IERC20Pods, ReentrancyGuardExt {
using AddressSet for AddressSet.Data;
using AddressArray for AddressArray.Data;
using ReentrancyGuardLib for ReentrancyGuardLib.Data;

error PodAlreadyAdded();
error PodNotFound();
Expand All @@ -22,10 +24,12 @@ abstract contract ERC20Pods is ERC20, IERC20Pods {

uint256 public immutable podsLimit;

ReentrancyGuardLib.Data private _guard;
mapping(address => AddressSet.Data) private _pods;

constructor(uint256 podsLimit_) {
podsLimit = podsLimit_;
_guard.init();
}

function hasPod(address account, address pod) public view virtual returns(bool) {
Expand All @@ -44,7 +48,11 @@ abstract contract ERC20Pods is ERC20, IERC20Pods {
return _pods[account].items.get();
}

function podBalanceOf(address pod, address account) public view returns(uint256) {
function balanceOf(address account) public nonReentrantView(_guard) view override(IERC20, ERC20) returns(uint256) {
return super.balanceOf(account);
}

function podBalanceOf(address pod, address account) public nonReentrantView(_guard) view returns(uint256) {
if (hasPod(account, pod)) {
return balanceOf(account);
}
Expand Down Expand Up @@ -119,7 +127,7 @@ abstract contract ERC20Pods is ERC20, IERC20Pods {

// ERC20 Overrides

function _afterTokenTransfer(address from, address to, uint256 amount) internal override virtual {
function _afterTokenTransfer(address from, address to, uint256 amount) internal nonReentrant(_guard) override virtual {
super._afterTokenTransfer(from, to, amount);

unchecked {
Expand Down
47 changes: 47 additions & 0 deletions contracts/libs/ReentrancyGuard.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// SPDX-License-Identifier: MIT

pragma solidity ^0.8.0;

library ReentrancyGuardLib {
error ReentrantCall();

uint256 private constant _NOT_ENTERED = 1;
uint256 private constant _ENTERED = 2;

struct Data {
uint256 _status;
}

function init(Data storage self) internal {
self._status = _NOT_ENTERED;
}

function enter(Data storage self) internal {
if (self._status == _ENTERED) revert ReentrantCall();
self._status = _ENTERED;
}

function exit(Data storage self) internal {
self._status = _NOT_ENTERED;
}

function check(Data storage self) internal view returns (bool) {
return self._status == _ENTERED;
}
}

contract ReentrancyGuardExt {
using ReentrancyGuardLib for ReentrancyGuardLib.Data;
error AccessDenied();

modifier nonReentrant(ReentrancyGuardLib.Data storage self) {
self.enter();
_;
self.exit();
}

modifier nonReentrantView(ReentrancyGuardLib.Data storage self) {
if (self.check()) revert ReentrancyGuardLib.ReentrantCall();
_;
}
}

0 comments on commit dc0ea03

Please sign in to comment.