From 3a8883f1724578400659c280a0184229df1652f4 Mon Sep 17 00:00:00 2001 From: Andres Adjimann Date: Wed, 1 Nov 2023 14:50:49 -0300 Subject: [PATCH] feat: add a view function in RoyaltiesRegistry Add a function that return the royalties without touching the storage. The previous one that has a royalty type cache is named: getRoyaltiesWithTypeCache now --- .../contracts/RoyaltiesRegistry.sol | 62 +++++++++++++------ .../marketplace/contracts/TransferManager.sol | 2 +- .../interfaces/IRoyaltiesProvider.sol | 14 +++-- .../contracts/mocks/RoyaltiesProviderMock.sol | 4 ++ .../test/RoyaltiesRegistry.test.ts | 2 +- 5 files changed, 58 insertions(+), 26 deletions(-) diff --git a/packages/marketplace/contracts/RoyaltiesRegistry.sol b/packages/marketplace/contracts/RoyaltiesRegistry.sol index 59cd6cfa1a..668c2e7998 100644 --- a/packages/marketplace/contracts/RoyaltiesRegistry.sol +++ b/packages/marketplace/contracts/RoyaltiesRegistry.sol @@ -123,7 +123,21 @@ contract RoyaltiesRegistry is OwnableUpgradeable, IRoyaltiesProvider { /// @param token Address of the token. /// @param tokenId ID of the token. /// @return An array containing royalty parts. - function getRoyalties(address token, uint256 tokenId) external override returns (Part[] memory) { + function getRoyalties(address token, uint256 tokenId) external view override returns (Part[] memory) { + uint256 royaltiesProviderData = royaltiesProviders[token]; + address royaltiesProvider = address(uint160(royaltiesProviderData)); + RoyaltiesType royaltiesType = _getRoyaltiesType(royaltiesProviderData); + if (royaltiesType == RoyaltiesType.UNSET) { + royaltiesType = _calculateRoyaltiesType(token, royaltiesProvider); + } + return _getRoyalties(royaltiesType, token, tokenId, royaltiesProvider); + } + + /// @notice Fetches royalties for a given token and token ID (cache the type if needed). + /// @param token Address of the token. + /// @param tokenId ID of the token. + /// @return An array containing royalty parts. + function getRoyaltiesWithTypeCache(address token, uint256 tokenId) external override returns (Part[] memory) { uint256 royaltiesProviderData = royaltiesProviders[token]; address royaltiesProvider = address(uint160(royaltiesProviderData)); @@ -137,24 +151,7 @@ contract RoyaltiesRegistry is OwnableUpgradeable, IRoyaltiesProvider { //saving royalties type _setRoyaltiesType(token, royaltiesType, royaltiesProvider); } - - //case royaltiesType = 1, royalties are set in royaltiesByToken - if (royaltiesType == RoyaltiesType.BY_TOKEN) { - return royaltiesByToken[token].royalties; - } - - //case royaltiesType = 2, royalties from external provider - if (royaltiesType == RoyaltiesType.EXTERNAL_PROVIDER) { - return _providerExtractor(token, tokenId, royaltiesProvider); - } - - //case royaltiesType = 3, royalties EIP-2981 - if (royaltiesType == RoyaltiesType.EIP2981) { - return _getRoyaltiesEIP2981(token, tokenId); - } - - // case royaltiesType = 4, unknown/empty royalties - return new Part[](0); + return _getRoyalties(royaltiesType, token, tokenId, royaltiesProvider); } /// @notice Returns provider address for token contract from royaltiesProviders mapping @@ -267,7 +264,7 @@ contract RoyaltiesRegistry is OwnableUpgradeable, IRoyaltiesProvider { address token, uint256 tokenId, address providerAddress - ) internal returns (Part[] memory) { + ) internal view returns (Part[] memory) { try IRoyaltiesProvider(providerAddress).getRoyalties(token, tokenId) returns (Part[] memory result) { return result; } catch { @@ -292,6 +289,31 @@ contract RoyaltiesRegistry is OwnableUpgradeable, IRoyaltiesProvider { return result; } + /// @notice Fetches royalties for a given token and token ID. + /// @param royaltiesType type of royalty + /// @param token Address of the token. + /// @param tokenId ID of the token. + /// @param royaltiesProvider The address of the royalties provider. + /// @return An array containing royalty parts. + function _getRoyalties( + RoyaltiesType royaltiesType, + address token, + uint256 tokenId, + address royaltiesProvider + ) internal view returns (Part[] memory) { + if (royaltiesType == RoyaltiesType.BY_TOKEN) { + return royaltiesByToken[token].royalties; + } + if (royaltiesType == RoyaltiesType.EXTERNAL_PROVIDER) { + return _providerExtractor(token, tokenId, royaltiesProvider); + } + if (royaltiesType == RoyaltiesType.EIP2981) { + return _getRoyaltiesEIP2981(token, tokenId); + } + // case royaltiesType = 4, unknown/empty royalties + return new Part[](0); + } + // slither-disable-next-line unused-state uint256[50] private __gap; } diff --git a/packages/marketplace/contracts/TransferManager.sol b/packages/marketplace/contracts/TransferManager.sol index d8d570a0fa..a54ce49d94 100644 --- a/packages/marketplace/contracts/TransferManager.sol +++ b/packages/marketplace/contracts/TransferManager.sol @@ -176,7 +176,7 @@ abstract contract TransferManager is Initializable, ITransferManager { DealSide memory nftSide ) internal returns (uint256) { (address token, uint256 tokenId) = LibAsset.decodeToken(nftSide.asset.assetType); - IRoyaltiesProvider.Part[] memory royalties = royaltiesRegistry.getRoyalties(token, tokenId); + IRoyaltiesProvider.Part[] memory royalties = royaltiesRegistry.getRoyaltiesWithTypeCache(token, tokenId); uint256 totalRoyalties; uint256 len = royalties.length; for (uint256 i; i < len; i++) { diff --git a/packages/marketplace/contracts/interfaces/IRoyaltiesProvider.sol b/packages/marketplace/contracts/interfaces/IRoyaltiesProvider.sol index be37021617..e59e14bfe3 100644 --- a/packages/marketplace/contracts/interfaces/IRoyaltiesProvider.sol +++ b/packages/marketplace/contracts/interfaces/IRoyaltiesProvider.sol @@ -15,9 +15,15 @@ interface IRoyaltiesProvider { uint256 value; } - /// @notice Calculates all roaylties in token for tokenId + /// @notice Calculates all royalties in token for tokenId /// @param token Address of token - /// @param tokenId of the token we want to calculate royalites - /// @return A LibPart.Part with allroyalties for token - function getRoyalties(address token, uint256 tokenId) external returns (Part[] memory); + /// @param tokenId of the token we want to calculate royalties + /// @return A LibPart.Part with all royalties for token + function getRoyalties(address token, uint256 tokenId) external view returns (Part[] memory); + + /// @notice Calculates all royalties in token for tokenId as getRoyalties caching the royalty type + /// @param token Address of token + /// @param tokenId of the token we want to calculate royalties + /// @return A LibPart.Part with all royalties for token + function getRoyaltiesWithTypeCache(address token, uint256 tokenId) external returns (Part[] memory); } diff --git a/packages/marketplace/contracts/mocks/RoyaltiesProviderMock.sol b/packages/marketplace/contracts/mocks/RoyaltiesProviderMock.sol index 5ad7e89c3b..2b41686fd9 100644 --- a/packages/marketplace/contracts/mocks/RoyaltiesProviderMock.sol +++ b/packages/marketplace/contracts/mocks/RoyaltiesProviderMock.sol @@ -11,6 +11,10 @@ contract RoyaltiesProviderMock is IRoyaltiesProvider { return royaltiesTest[token][tokenId]; } + function getRoyaltiesWithTypeCache(address token, uint256 tokenId) external view override returns (Part[] memory) { + return royaltiesTest[token][tokenId]; + } + function initializeProvider(address token, uint256 tokenId, Part[] memory royalties) public { delete royaltiesTest[token][tokenId]; for (uint256 i = 0; i < royalties.length; ++i) { diff --git a/packages/marketplace/test/RoyaltiesRegistry.test.ts b/packages/marketplace/test/RoyaltiesRegistry.test.ts index f9aa3216c1..c200ff8249 100644 --- a/packages/marketplace/test/RoyaltiesRegistry.test.ts +++ b/packages/marketplace/test/RoyaltiesRegistry.test.ts @@ -284,7 +284,7 @@ describe('RoyaltiesRegistry.sol', function () { ) ).to.be.equal(0); - await RoyaltiesRegistryAsUser.getRoyalties( + await RoyaltiesRegistryAsUser.getRoyaltiesWithTypeCache( await ERC721WithRoyaltyV2981.getAddress(), 1 );