diff --git a/contracts/gatekeepers/TokenGateKeeper.sol b/contracts/gatekeepers/TokenGateKeeper.sol index 48945e0f..8bc76bde 100644 --- a/contracts/gatekeepers/TokenGateKeeper.sol +++ b/contracts/gatekeepers/TokenGateKeeper.sol @@ -4,13 +4,14 @@ pragma solidity 0.8.20; import { IGateKeeper } from "./IGateKeeper.sol"; import { ContributionRouter } from "../crowdfund/ContributionRouter.sol"; -/** - * @notice Compatible with both ER20s and ERC721s. - */ -interface Token { +interface IERC20orERC721 { function balanceOf(address owner) external view returns (uint256); } +interface IERC1155 { + function balanceOf(address owner, uint256 id) external view returns (uint256); +} + /** * @notice a contract that implements an token gatekeeper */ @@ -25,13 +26,14 @@ contract TokenGateKeeper is IGateKeeper { } struct TokenGate { - Token token; + address token; + uint256 tokenId; uint256 minimumBalance; } - event TokenGateCreated(Token token, uint256 minimumBalance); + event TokenGateCreated(address token, uint256 tokenId, uint256 minimumBalance); - /// @notice Get the information for a gate identifyied by it's `id`. + /// @notice Get the information for a gate identified by it's `id`. mapping(uint96 => TokenGate) public gateInfo; /// @inheritdoc IGateKeeper @@ -44,18 +46,43 @@ contract TokenGateKeeper is IGateKeeper { participant = ContributionRouter(payable(CONTRIBUTION_ROUTER)).caller(); } TokenGate memory _gate = gateInfo[uint96(id)]; - return _gate.token.balanceOf(participant) >= _gate.minimumBalance; + + if (_gate.tokenId == 0) { + return IERC20orERC721(_gate.token).balanceOf(participant) >= _gate.minimumBalance; + } else { + return + IERC1155(_gate.token).balanceOf(participant, _gate.tokenId) >= _gate.minimumBalance; + } } - /// @notice Creates a gate that requires a minimum balance of a token. + /// @notice Creates a gate that requires a minimum balance of an ERC721 or ERC20 token. /// @param token The token address (e.g. ERC20 or ERC721). /// @param minimumBalance The minimum balance allowed for participation. /// @return id The ID of the new gate. - function createGate(Token token, uint256 minimumBalance) external returns (bytes12 id) { + function createGate(address token, uint256 minimumBalance) external returns (bytes12 id) { + return createGate(token, 0, minimumBalance); + } + + /// @notice Creates a gate that requires a minimum balance of an ERC1155 token. + /// @param token The token address (ERC1155). + /// @param tokenId The token ID. + /// @param minimumBalance The minimum balance allowed for participation. + /// @return id The ID of the new gate. + function createGate( + address token, + uint256 tokenId, + uint256 minimumBalance + ) public returns (bytes12 id) { uint96 id_ = ++_lastId; id = bytes12(id_); - gateInfo[id_].token = token; - gateInfo[id_].minimumBalance = minimumBalance; - emit TokenGateCreated(token, minimumBalance); + + TokenGate memory gate = TokenGate({ + token: token, + tokenId: tokenId, + minimumBalance: minimumBalance + }); + gateInfo[id_] = gate; + + emit TokenGateCreated(token, tokenId, minimumBalance); } } diff --git a/deploy/Deploy.s.sol b/deploy/Deploy.s.sol index 7346b557..fa49e2d3 100644 --- a/deploy/Deploy.s.sol +++ b/deploy/Deploy.s.sol @@ -14,7 +14,7 @@ import "../contracts/crowdfund/ReraiseETHCrowdfund.sol"; import "../contracts/crowdfund/CrowdfundFactory.sol"; import "../contracts/distribution/TokenDistributor.sol"; import "../contracts/gatekeepers/AllowListGateKeeper.sol"; -import "../contracts/gatekeepers/TokenGateKeeper.sol"; +import { TokenGateKeeper } from "../contracts/gatekeepers/TokenGateKeeper.sol"; import "../contracts/gatekeepers/IGateKeeper.sol"; import "../contracts/globals/Globals.sol"; import "../contracts/globals/LibGlobals.sol"; diff --git a/test/authorities/SellPartyCardsAuthority.t.sol b/test/authorities/SellPartyCardsAuthority.t.sol index e63471f5..88a8308a 100644 --- a/test/authorities/SellPartyCardsAuthority.t.sol +++ b/test/authorities/SellPartyCardsAuthority.t.sol @@ -5,7 +5,7 @@ import { Party, SetupPartyHelper } from "../utils/SetupPartyHelper.sol"; import { SellPartyCardsAuthority } from "contracts/authorities/SellPartyCardsAuthority.sol"; import { IGateKeeper } from "contracts/gatekeepers/IGateKeeper.sol"; import { ContributionRouter } from "../../contracts/crowdfund/ContributionRouter.sol"; -import { TokenGateKeeper, Token } from "contracts/gatekeepers/TokenGateKeeper.sol"; +import { TokenGateKeeper } from "contracts/gatekeepers/TokenGateKeeper.sol"; import { DummyERC20 } from "../DummyERC20.sol"; contract SellPartyCardsAuthorityTest is SetupPartyHelper { @@ -472,7 +472,7 @@ contract SellPartyCardsAuthorityTest is SetupPartyHelper { address buyer = _randomAddress(); vm.deal(buyer, 2 ether); - bytes12 gatekeeperId = gatekeeper.createGate(Token(address(token)), 0.01 ether); + bytes12 gatekeeperId = gatekeeper.createGate(address(token), 0.01 ether); token.deal(buyer, 0.001 ether); SellPartyCardsAuthority.FixedMembershipSaleOpts memory opts = SellPartyCardsAuthority diff --git a/test/crowdfund/ContributionRouterIntegration.t.sol b/test/crowdfund/ContributionRouterIntegration.t.sol index a1007aad..b26f04de 100644 --- a/test/crowdfund/ContributionRouterIntegration.t.sol +++ b/test/crowdfund/ContributionRouterIntegration.t.sol @@ -6,7 +6,7 @@ import "../../contracts/party/PartyFactory.sol"; import "../../contracts/crowdfund/InitialETHCrowdfund.sol"; import "../../contracts/crowdfund/ContributionRouter.sol"; import "./TestableCrowdfund.sol"; -import { TokenGateKeeper, Token } from "../../contracts/gatekeepers/TokenGateKeeper.sol"; +import { TokenGateKeeper } from "../../contracts/gatekeepers/TokenGateKeeper.sol"; import { DummyERC20 } from "../DummyERC20.sol"; import "../TestUtils.sol"; @@ -34,7 +34,7 @@ contract ContributionRouterIntegrationTest is TestUtils { gateKeeper = new TokenGateKeeper(address(router)); gatekeepToken = new DummyERC20(); - bytes12 gateKeeperId = gateKeeper.createGate(Token(address(gatekeepToken)), 100); + bytes12 gateKeeperId = gateKeeper.createGate(address(gatekeepToken), 100); InitialETHCrowdfund initialETHCrowdfundImpl = new InitialETHCrowdfund(globals); diff --git a/test/crowdfund/CrowdfundFactory.t.sol b/test/crowdfund/CrowdfundFactory.t.sol index ba2b3d05..3a021ff8 100644 --- a/test/crowdfund/CrowdfundFactory.t.sol +++ b/test/crowdfund/CrowdfundFactory.t.sol @@ -8,7 +8,7 @@ import "contracts/crowdfund/AuctionCrowdfund.sol"; import "contracts/market-wrapper/IMarketWrapper.sol"; import "contracts/crowdfund/Crowdfund.sol"; import "contracts/gatekeepers/AllowListGateKeeper.sol"; -import "contracts/gatekeepers/TokenGateKeeper.sol"; +import { TokenGateKeeper } from "contracts/gatekeepers/TokenGateKeeper.sol"; import "contracts/tokens/IERC721.sol"; import "./MockMarketWrapper.sol"; import "contracts/globals/Globals.sol"; @@ -78,9 +78,10 @@ contract CrowdfundFactoryTest is Test, TestUtils { } if (x == 1) { // Use `TokenGateKeeper`. - createGateCallData = abi.encodeCall( - TokenGateKeeper.createGate, - (Token(_randomAddress()), _randomUint256()) + createGateCallData = abi.encodeWithSelector( + bytes4(keccak256("createGate(address,uint256)")), + _randomAddress(), + _randomUint256() ); return (IGateKeeper(address(tokenGateKeeper)), bytes12(0), createGateCallData); } diff --git a/test/gatekeepers/TokenGateKeeper.t.sol b/test/gatekeepers/TokenGateKeeper.t.sol index 6330dea8..3c3247c4 100644 --- a/test/gatekeepers/TokenGateKeeper.t.sol +++ b/test/gatekeepers/TokenGateKeeper.t.sol @@ -5,30 +5,33 @@ import { Test } from "../../lib/forge-std/src/Test.sol"; import { console } from "../../lib/forge-std/src/console.sol"; import { DummyERC20 } from "../DummyERC20.sol"; import { DummyERC721 } from "../DummyERC721.sol"; +import { DummyERC1155 } from "../DummyERC1155.sol"; import { TestUtils } from "../TestUtils.sol"; -import { TokenGateKeeper, Token } from "../../contracts/gatekeepers/TokenGateKeeper.sol"; +import { TokenGateKeeper } from "../../contracts/gatekeepers/TokenGateKeeper.sol"; import "../../contracts/utils/LibERC20Compat.sol"; contract TokenGateKeeperTest is Test, TestUtils { TokenGateKeeper gk; uint256 constant MIN_ERC20_BALANCE = 10e18; uint256 constant MIN_ERC721_BALANCE = 1; + uint256 constant MIN_ERC1155_BALANCE = 1; DummyERC20 dummyERC20 = new DummyERC20(); DummyERC721 dummyERC721 = new DummyERC721(); + DummyERC1155 dummyERC1155 = new DummyERC1155(); function setUp() public { gk = new TokenGateKeeper(address(0)); } function testUniqueGateIds() public { - bytes12 gateId1 = gk.createGate(Token(address(dummyERC20)), MIN_ERC20_BALANCE); - bytes12 gateId2 = gk.createGate(Token(address(dummyERC721)), MIN_ERC721_BALANCE); + bytes12 gateId1 = gk.createGate(address(dummyERC20), MIN_ERC20_BALANCE); + bytes12 gateId2 = gk.createGate(address(dummyERC721), MIN_ERC721_BALANCE); assertTrue(gateId1 != gateId2); } function testAboveMinimumBalance() public { - bytes12 ERC20gateId = gk.createGate(Token(address(dummyERC20)), MIN_ERC20_BALANCE); - bytes12 ERC721gateId = gk.createGate(Token(address(dummyERC721)), MIN_ERC721_BALANCE); + bytes12 ERC20gateId = gk.createGate(address(dummyERC20), MIN_ERC20_BALANCE); + bytes12 ERC721gateId = gk.createGate(address(dummyERC721), MIN_ERC721_BALANCE); address user = _randomAddress(); dummyERC20.deal(user, MIN_ERC20_BALANCE + 1); dummyERC721.mint(user); @@ -38,8 +41,8 @@ contract TokenGateKeeperTest is Test, TestUtils { } function testEqualToMinimumBalance() public { - bytes12 ERC20gateId = gk.createGate(Token(address(dummyERC20)), MIN_ERC20_BALANCE); - bytes12 ERC721gateId = gk.createGate(Token(address(dummyERC721)), MIN_ERC721_BALANCE); + bytes12 ERC20gateId = gk.createGate(address(dummyERC20), MIN_ERC20_BALANCE); + bytes12 ERC721gateId = gk.createGate(address(dummyERC721), MIN_ERC721_BALANCE); address user = _randomAddress(); dummyERC20.deal(user, MIN_ERC20_BALANCE); dummyERC721.mint(user); @@ -48,16 +51,16 @@ contract TokenGateKeeperTest is Test, TestUtils { } function testBelowMinimumBalance() public { - bytes12 ERC20gateId = gk.createGate(Token(address(dummyERC20)), MIN_ERC20_BALANCE); - bytes12 ERC721gateId = gk.createGate(Token(address(dummyERC721)), MIN_ERC721_BALANCE); + bytes12 ERC20gateId = gk.createGate(address(dummyERC20), MIN_ERC20_BALANCE); + bytes12 ERC721gateId = gk.createGate(address(dummyERC721), MIN_ERC721_BALANCE); address user = _randomAddress(); assertFalse(gk.isAllowed(user, ERC20gateId, "")); assertFalse(gk.isAllowed(user, ERC721gateId, "")); } function testSeparateGateAccess() public { - bytes12 ERC20gateId = gk.createGate(Token(address(dummyERC20)), MIN_ERC20_BALANCE); - bytes12 ERC721gateId = gk.createGate(Token(address(dummyERC721)), MIN_ERC721_BALANCE); + bytes12 ERC20gateId = gk.createGate(address(dummyERC20), MIN_ERC20_BALANCE); + bytes12 ERC721gateId = gk.createGate(address(dummyERC721), MIN_ERC721_BALANCE); address user1 = _randomAddress(); address user2 = _randomAddress(); dummyERC20.deal(user1, MIN_ERC20_BALANCE); @@ -67,4 +70,25 @@ contract TokenGateKeeperTest is Test, TestUtils { assertFalse(gk.isAllowed(user1, ERC721gateId, "")); assertFalse(gk.isAllowed(user2, ERC20gateId, "")); } + + function test1155Gate() public { + bytes12 ERC1155gateId = gk.createGate(address(dummyERC1155), 1, MIN_ERC1155_BALANCE); + address user = _randomAddress(); + dummyERC1155.deal(user, 1, 1); + assertTrue(gk.isAllowed(user, ERC1155gateId, "")); + } + + function test1155Gate_wrongId() public { + bytes12 ERC1155gateId = gk.createGate(address(dummyERC1155), 1, MIN_ERC1155_BALANCE); + address user = _randomAddress(); + dummyERC1155.deal(user, 2, 1); + assertFalse(gk.isAllowed(user, ERC1155gateId, "")); + } + + function test1155Gate_notEnough() public { + bytes12 ERC1155gateId = gk.createGate(address(dummyERC1155), 1, 3); + address user = _randomAddress(); + dummyERC1155.deal(user, 1, 2); + assertFalse(gk.isAllowed(user, ERC1155gateId, "")); + } }