Skip to content

Commit

Permalink
harden caller checks for ComponentService.registerProduct/registerCom…
Browse files Browse the repository at this point in the history
…ponent
  • Loading branch information
matthiaszimmermann committed Sep 2, 2024
1 parent ebfcf1f commit d878d0f
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 160 deletions.
100 changes: 59 additions & 41 deletions contracts/shared/ComponentService.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ pragma solidity ^0.8.20;
import {IERC20Metadata} from "@openzeppelin/contracts/token/ERC20/extensions/IERC20Metadata.sol";

import {IAccountingService} from "../accounting/IAccountingService.sol";
import {IComponent} from "../shared/IComponent.sol";
import {IComponents} from "../instance/module/IComponents.sol";
import {IComponentService} from "./IComponentService.sol";
import {IInstance} from "../instance/IInstance.sol";
Expand All @@ -15,7 +14,6 @@ import {InstanceStore} from "../instance/InstanceStore.sol";
import {IInstanceService} from "../instance/IInstanceService.sol";
import {IPoolComponent} from "../pool/IPoolComponent.sol";
import {IProductComponent} from "../product/IProductComponent.sol";
import {IRegisterable} from "../shared/IRegisterable.sol";
import {IRegistry} from "../registry/IRegistry.sol";
import {IRegistryService} from "../registry/IRegistryService.sol";

Expand All @@ -29,7 +27,6 @@ import {ObjectType, ACCOUNTING, REGISTRY, COMPONENT, DISTRIBUTION, INSTANCE, ORA
import {Service} from "../shared/Service.sol";
import {TokenHandler} from "../shared/TokenHandler.sol";
import {TokenHandlerDeployerLib} from "../shared/TokenHandlerDeployerLib.sol";
import {VersionPart} from "../type/Version.sol";


contract ComponentService is
Expand Down Expand Up @@ -73,17 +70,31 @@ contract ComponentService is
virtual
returns (NftId componentNftId)
{
(NftId productNftId, IInstance instance) = _getAndVerifyActiveComponent(COMPONENT());
NftId instanceNftId = getRegistry().getNftIdForAddress(address(instance));
// checks
// check sender is registered product
IRegistry registry = getRegistry();
if (!registry.isObjectType(msg.sender, PRODUCT())) {
revert ErrorComponentServiceCallerNotProduct(msg.sender);
}

// check provided address is product contract
if (!ContractLib.isInstanceLinkedComponent(address(registry), componentAddress)) {
revert ErrorComponentServiceNotComponent(componentAddress);
}

NftId productNftId = registry.getNftIdForAddress(msg.sender);
IInstance instance = IInstance(
registry.getObjectAddress(
registry.getParentNftId(productNftId)));

componentNftId = _verifyAndRegister(
instanceNftId,
instance,
componentAddress,
productNftId,
address(0));
productNftId, // product is parent of component to be registered
address(0)); // token will be inhereited from product
}


/// @inheritdoc IComponentService
function approveTokenHandler(
IERC20Metadata token,
Expand All @@ -93,7 +104,7 @@ contract ComponentService is
virtual
{
// checks
(NftId componentNftId, IInstance instance) = _getAndVerifyActiveComponent(COMPONENT());
(NftId componentNftId, IInstance instance) = _getAndVerifyComponent(COMPONENT(), true);
TokenHandler tokenHandler = instance.getInstanceReader().getTokenHandler(
componentNftId);

Expand All @@ -102,12 +113,13 @@ contract ComponentService is
}


/// @inheritdoc IComponentService
function setWallet(address newWallet)
external
virtual
{
// checks
(NftId componentNftId, IInstance instance) = _getAndVerifyActiveComponent(COMPONENT());
(NftId componentNftId, IInstance instance) = _getAndVerifyComponent(COMPONENT(), true);
TokenHandler tokenHandler = instance.getInstanceReader().getTokenHandler(
componentNftId);

Expand All @@ -129,6 +141,7 @@ contract ComponentService is
locked);
}


/// @inheritdoc IComponentService
function withdrawFees(Amount amount)
external
Expand All @@ -137,7 +150,7 @@ contract ComponentService is
returns (Amount withdrawnAmount)
{
// checks
(NftId componentNftId, IInstance instance) = _getAndVerifyActiveComponent(COMPONENT());
(NftId componentNftId, IInstance instance) = _getAndVerifyComponent(COMPONENT(), true);
InstanceReader instanceReader = instance.getInstanceReader();

// determine withdrawn amount
Expand Down Expand Up @@ -182,20 +195,35 @@ contract ComponentService is

//-------- product ------------------------------------------------------//

/// @inheritdoc IComponentService
function registerProduct(address productAddress, address token)
external
virtual
nonReentrant()
returns (NftId productNftId)
{
// TODO instance verification
//(NftId instanceNftId,, IInstance instance) = _getAndVerifyCallingInstance();
NftId instanceNftId = getRegistry().getNftIdForAddress(msg.sender);
IInstance instance = IInstance(msg.sender);
// checks
// check sender is registered instance
IRegistry registry = getRegistry();
if (!registry.isObjectType(msg.sender, INSTANCE())) {
revert ErrorComponentServiceCallerNotInstance(msg.sender);
}

// check provided address is product contract
if (!ContractLib.isProduct(address(registry), productAddress)) {
revert ErrorComponentServiceNotProduct(productAddress);
}

productNftId = _verifyAndRegister(instanceNftId, instance, productAddress, instanceNftId, token);
IInstance instance = IInstance(msg.sender);
productNftId = _verifyAndRegister(
instance,
productAddress,
instance.getNftId(), // instance is parent of product to be registered
token);
}


/// @inheritdoc IComponentService
function setProductFees(
Fee memory productFee, // product fee on net premium
Fee memory processingFee // product fee on payout amounts
Expand All @@ -204,7 +232,7 @@ contract ComponentService is
virtual
nonReentrant()
{
(NftId productNftId, IInstance instance) = _getAndVerifyActiveComponent(PRODUCT());
(NftId productNftId, IInstance instance) = _getAndVerifyComponent(PRODUCT(), true);
IComponents.FeeInfo memory feeInfo = instance.getInstanceReader().getFeeInfo(productNftId);
bool feesChanged = false;

Expand All @@ -228,6 +256,7 @@ contract ComponentService is
}
}


function _createProduct(
InstanceStore instanceStore,
NftId productNftId,
Expand Down Expand Up @@ -282,14 +311,15 @@ contract ComponentService is
instanceStore.updateProduct(productNftId, productInfo, KEEP_STATE());
}


function setDistributionFees(
Fee memory distributionFee, // distribution fee for sales that do not include commissions
Fee memory minDistributionOwnerFee // min fee required by distribution owner (not including commissions for distributors)
)
external
virtual
{
(NftId distributionNftId, IInstance instance) = _getAndVerifyActiveComponent(DISTRIBUTION());
(NftId distributionNftId, IInstance instance) = _getAndVerifyComponent(DISTRIBUTION(), true);
(NftId productNftId, IComponents.FeeInfo memory feeInfo) = _getLinkedFeeInfo(
instance.getInstanceReader(), distributionNftId);
bool feesChanged = false;
Expand Down Expand Up @@ -338,6 +368,7 @@ contract ComponentService is
productInfo.numberOfOracles++;
instanceStore.updateProduct(productNftId, productInfo, KEEP_STATE());
}

//-------- pool ---------------------------------------------------------//

function _createPool(
Expand Down Expand Up @@ -376,7 +407,7 @@ contract ComponentService is
external
virtual
{
(NftId poolNftId, IInstance instance) = _getAndVerifyActiveComponent(POOL());
(NftId poolNftId, IInstance instance) = _getAndVerifyComponent(POOL(), true);

(NftId productNftId, IComponents.FeeInfo memory feeInfo) = _getLinkedFeeInfo(
instance.getInstanceReader(), poolNftId);
Expand Down Expand Up @@ -409,9 +440,10 @@ contract ComponentService is
}
}


/// @dev Registers the component represented by the provided address.
/// The caller must ensure componentAddress is IInstanceLinkedComponent.
function _verifyAndRegister(
NftId instanceNftId,
IInstance instance,
address componentAddress,
NftId parentNftId,
Expand Down Expand Up @@ -486,16 +518,16 @@ contract ComponentService is
// authorize
instanceAdmin.initializeComponentAuthorization(componentAddress, componentType);

// TODO mostly repeats Registry log
emit LogComponentServiceRegistered(
instanceNftId,
instance.getNftId(),
componentNftId,
componentType,
address(component),
token,
objectInfo.initialOwner);
}


function _checkToken(IInstance instance, address token)
internal
view
Expand All @@ -513,6 +545,7 @@ contract ComponentService is
}
}


function _logUpdateFee(NftId productNftId, string memory name, Fee memory feeBefore, Fee memory feeAfter)
internal
virtual
Expand Down Expand Up @@ -543,9 +576,9 @@ contract ComponentService is
info = instanceReader.getFeeInfo(productNftId);
}


/// @dev Based on the provided component address required type the component
/// and related instance contract this function reverts iff:
/// - the component contract does not support IInstanceLinkedComponent
/// - the component parent does not match with the required parent
/// - the component release does not match with the service release
/// - the component has already been registered
Expand All @@ -560,11 +593,6 @@ contract ComponentService is
IRegistry.ObjectInfo memory info
)
{
// check component interface
if (!ContractLib.supportsInterface(componentAddress, type(IInstanceLinkedComponent).interfaceId)) {
revert ErrorComponentServiceNotInstanceLinkedComponent(componentAddress);
}

component = IInstanceLinkedComponent(componentAddress);
info = component.getInitialInfo();

Expand All @@ -573,9 +601,7 @@ contract ComponentService is
revert ErrorComponentServiceComponentParentInvalid(componentAddress, requiredParent, info.parentNftId);
}

// check component release
// TODO check version with registry
//if(info.version != getRelease()) {
// check component release (must match with service release)
if(component.getRelease() != getRelease()) {
revert ErrorComponentServiceComponentReleaseMismatch(componentAddress, getRelease(), component.getRelease());
}
Expand All @@ -586,20 +612,11 @@ contract ComponentService is
}
}


function _setLocked(InstanceAdmin instanceAdmin, address componentAddress, bool locked) internal {
instanceAdmin.setTargetLocked(componentAddress, locked);
}

function _getAndVerifyActiveComponent(ObjectType expectedType)
internal
view
returns (
NftId componentNftId,
IInstance instance
)
{
return _getAndVerifyComponent(expectedType, true); // only active
}

function _getAndVerifyComponent(ObjectType expectedType, bool isActive)
internal
Expand Down Expand Up @@ -630,6 +647,7 @@ contract ComponentService is
instance = IInstance(instanceAddress);
}


function _getDomain() internal pure virtual override returns(ObjectType) {
return COMPONENT();
}
Expand Down
Loading

0 comments on commit d878d0f

Please sign in to comment.